1
use indoc::formatdoc;
2
use itertools::Itertools;
3
use std::collections::BTreeMap;
4

            
5
use serde::{Deserialize, Serialize};
6

            
7
use sqlx::MySqlConnection;
8
use sqlx::prelude::*;
9

            
10
use crate::core::protocol::request_validation::GroupDenylist;
11
use crate::core::protocol::request_validation::validate_db_or_user_request;
12
use crate::core::types::DbOrUser;
13
use crate::{
14
    core::{
15
        common::UnixUser,
16
        database_privileges::DATABASE_PRIVILEGE_FIELDS,
17
        protocol::{
18
            CreateUserError, CreateUsersResponse, DropUserError, DropUsersResponse,
19
            ListAllUsersError, ListAllUsersResponse, ListUsersError, ListUsersResponse,
20
            LockUserError, LockUsersResponse, SetPasswordError, SetUserPasswordResponse,
21
            UnlockUserError, UnlockUsersResponse,
22
        },
23
        types::MySQLUser,
24
    },
25
    server::{
26
        common::{create_user_group_matching_regex, try_get_with_binary_fallback},
27
        sql::quote_literal,
28
    },
29
};
30

            
31
// NOTE: this function is unsafe because it does no input validation.
32
pub(super) async fn unsafe_user_exists(
33
    db_user: &str,
34
    connection: &mut MySqlConnection,
35
) -> Result<bool, sqlx::Error> {
36
    let result = sqlx::query(
37
        r"
38
          SELECT EXISTS(
39
            SELECT 1
40
            FROM `mysql`.`user`
41
            WHERE `User` = ?
42
          )
43
        ",
44
    )
45
    .bind(db_user)
46
    .fetch_one(connection)
47
    .await
48
    .map(|row| row.get::<bool, _>(0));
49

            
50
    if let Err(err) = &result {
51
        tracing::error!("Failed to check if database user exists: {:?}", err);
52
    }
53

            
54
    result
55
}
56

            
57
pub async fn complete_user_name(
58
    user_prefix: &str,
59
    unix_user: &UnixUser,
60
    connection: &mut MySqlConnection,
61
    _db_is_mariadb: bool,
62
    group_denylist: &GroupDenylist,
63
) -> Vec<MySQLUser> {
64
    let result = sqlx::query(
65
        r"
66
          SELECT `User` AS `user`
67
          FROM `mysql`.`user`
68
          WHERE `User` REGEXP ?
69
            AND `User` LIKE ?
70
        ",
71
    )
72
    .bind(create_user_group_matching_regex(unix_user, group_denylist))
73
    .bind(format!("{user_prefix}%"))
74
    .fetch_all(connection)
75
    .await;
76

            
77
    match result {
78
        Ok(rows) => rows
79
            .into_iter()
80
            .filter_map(|row| {
81
                let user: String = try_get_with_binary_fallback(&row, "user").ok()?;
82
                Some(user.into())
83
            })
84
            .collect(),
85
        Err(err) => {
86
            tracing::error!(
87
                "Failed to complete user name for prefix '{}' and user '{}': {:?}",
88
                user_prefix,
89
                unix_user.username,
90
                err
91
            );
92
            vec![]
93
        }
94
    }
95
}
96

            
97
pub async fn create_database_users(
98
    db_users: &[MySQLUser],
99
    unix_user: &UnixUser,
100
    connection: &mut MySqlConnection,
101
    _db_is_mariadb: bool,
102
    group_denylist: &GroupDenylist,
103
) -> CreateUsersResponse {
104
    let mut results = BTreeMap::new();
105

            
106
    for db_user in db_users.iter().cloned() {
107
        if let Err(err) =
108
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
109
                .map_err(CreateUserError::ValidationError)
110
        {
111
            results.insert(db_user, Err(err));
112
            continue;
113
        }
114

            
115
        match unsafe_user_exists(&db_user, &mut *connection).await {
116
            Ok(true) => {
117
                results.insert(db_user, Err(CreateUserError::UserAlreadyExists));
118
                continue;
119
            }
120
            Err(err) => {
121
                results.insert(db_user, Err(CreateUserError::MySqlError(err.to_string())));
122
                continue;
123
            }
124
            _ => {}
125
        }
126

            
127
        let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str())
128
            .execute(&mut *connection)
129
            .await
130
            .map(|_| ())
131
            .map_err(|err| CreateUserError::MySqlError(err.to_string()));
132

            
133
        if let Err(err) = &result {
134
            tracing::error!("Failed to create database user '{}': {:?}", &db_user, err);
135
        }
136

            
137
        results.insert(db_user, result);
138
    }
139

            
140
    results
141
}
142

            
143
pub async fn drop_database_users(
144
    db_users: &[MySQLUser],
145
    unix_user: &UnixUser,
146
    connection: &mut MySqlConnection,
147
    _db_is_mariadb: bool,
148
    group_denylist: &GroupDenylist,
149
) -> DropUsersResponse {
150
    let mut results = BTreeMap::new();
151

            
152
    for db_user in db_users.iter().cloned() {
153
        if let Err(err) =
154
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
155
                .map_err(DropUserError::ValidationError)
156
        {
157
            results.insert(db_user, Err(err));
158
            continue;
159
        }
160

            
161
        match unsafe_user_exists(&db_user, &mut *connection).await {
162
            Ok(false) => {
163
                results.insert(db_user, Err(DropUserError::UserDoesNotExist));
164
                continue;
165
            }
166
            Err(err) => {
167
                results.insert(db_user, Err(DropUserError::MySqlError(err.to_string())));
168
                continue;
169
            }
170
            _ => {}
171
        }
172

            
173
        let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str())
174
            .execute(&mut *connection)
175
            .await
176
            .map(|_| ())
177
            .map_err(|err| DropUserError::MySqlError(err.to_string()));
178

            
179
        if let Err(err) = &result {
180
            tracing::error!("Failed to drop database user '{}': {:?}", &db_user, err);
181
        }
182

            
183
        results.insert(db_user, result);
184
    }
185

            
186
    results
187
}
188

            
189
pub async fn set_password_for_database_user(
190
    db_user: &MySQLUser,
191
    password: &str,
192
    unix_user: &UnixUser,
193
    connection: &mut MySqlConnection,
194
    _db_is_mariadb: bool,
195
    group_denylist: &GroupDenylist,
196
) -> SetUserPasswordResponse {
197
    validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
198
        .map_err(SetPasswordError::ValidationError)?;
199

            
200
    match unsafe_user_exists(db_user, &mut *connection).await {
201
        Ok(false) => return Err(SetPasswordError::UserDoesNotExist),
202
        Err(err) => return Err(SetPasswordError::MySqlError(err.to_string())),
203
        _ => {}
204
    }
205

            
206
    let result = sqlx::query(
207
        format!(
208
            "ALTER USER {}@'%' IDENTIFIED BY {}",
209
            quote_literal(db_user),
210
            quote_literal(password).as_str(),
211
        )
212
        .as_str(),
213
    )
214
    .execute(&mut *connection)
215
    .await
216
    .map(|_| ())
217
    .map_err(|err| SetPasswordError::MySqlError(err.to_string()));
218

            
219
    if result.is_err() {
220
        tracing::error!(
221
            "Failed to set password for database user '{}': <REDACTED>",
222
            &db_user,
223
        );
224
    }
225

            
226
    result
227
}
228

            
229
const DATABASE_USER_LOCK_STATUS_QUERY_MARIADB: &str = r#"
230
    SELECT COALESCE(
231
        JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
232
        'false'
233
    ) != 'false'
234
    FROM `mysql`.`global_priv`
235
    WHERE `User` = ?
236
    AND `Host` = '%'
237
"#;
238

            
239
const DATABASE_USER_LOCK_STATUS_QUERY_MYSQL: &str = r"
240
    SELECT `mysql`.`user`.`account_locked` = 'Y'
241
    FROM `mysql`.`user`
242
    WHERE `User` = ?
243
    AND `Host` = '%'
244
";
245

            
246
// NOTE: this function is unsafe because it does no input validation.
247
async fn database_user_is_locked_unsafe(
248
    db_user: &str,
249
    connection: &mut MySqlConnection,
250
    db_is_mariadb: bool,
251
) -> Result<bool, sqlx::Error> {
252
    let result = sqlx::query(if db_is_mariadb {
253
        DATABASE_USER_LOCK_STATUS_QUERY_MARIADB
254
    } else {
255
        DATABASE_USER_LOCK_STATUS_QUERY_MYSQL
256
    })
257
    .bind(db_user)
258
    .fetch_one(connection)
259
    .await
260
    .map(|row| row.try_get(0))
261
    .and_then(|res| res);
262

            
263
    if let Err(err) = &result {
264
        tracing::error!(
265
            "Failed to check if database user is locked '{}': {:?}",
266
            &db_user,
267
            err
268
        );
269
    }
270

            
271
    result
272
}
273

            
274
pub async fn lock_database_users(
275
    db_users: &[MySQLUser],
276
    unix_user: &UnixUser,
277
    connection: &mut MySqlConnection,
278
    db_is_mariadb: bool,
279
    group_denylist: &GroupDenylist,
280
) -> LockUsersResponse {
281
    let mut results = BTreeMap::new();
282

            
283
    for db_user in db_users.iter().cloned() {
284
        if let Err(err) =
285
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
286
                .map_err(LockUserError::ValidationError)
287
        {
288
            results.insert(db_user, Err(err));
289
            continue;
290
        }
291

            
292
        match unsafe_user_exists(&db_user, &mut *connection).await {
293
            Ok(true) => {}
294
            Ok(false) => {
295
                results.insert(db_user, Err(LockUserError::UserDoesNotExist));
296
                continue;
297
            }
298
            Err(err) => {
299
                results.insert(db_user, Err(LockUserError::MySqlError(err.to_string())));
300
                continue;
301
            }
302
        }
303

            
304
        match database_user_is_locked_unsafe(&db_user, &mut *connection, db_is_mariadb).await {
305
            Ok(false) => {}
306
            Ok(true) => {
307
                results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked));
308
                continue;
309
            }
310
            Err(err) => {
311
                results.insert(db_user, Err(LockUserError::MySqlError(err.to_string())));
312
                continue;
313
            }
314
        }
315

            
316
        let result = sqlx::query(
317
            format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(),
318
        )
319
        .execute(&mut *connection)
320
        .await
321
        .map(|_| ())
322
        .map_err(|err| LockUserError::MySqlError(err.to_string()));
323

            
324
        if let Err(err) = &result {
325
            tracing::error!("Failed to lock database user '{}': {:?}", &db_user, err);
326
        }
327

            
328
        results.insert(db_user, result);
329
    }
330

            
331
    results
332
}
333

            
334
pub async fn unlock_database_users(
335
    db_users: &[MySQLUser],
336
    unix_user: &UnixUser,
337
    connection: &mut MySqlConnection,
338
    db_is_mariadb: bool,
339
    group_denylist: &GroupDenylist,
340
) -> UnlockUsersResponse {
341
    let mut results = BTreeMap::new();
342

            
343
    for db_user in db_users.iter().cloned() {
344
        if let Err(err) =
345
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
346
                .map_err(UnlockUserError::ValidationError)
347
        {
348
            results.insert(db_user, Err(err));
349
            continue;
350
        }
351

            
352
        match unsafe_user_exists(&db_user, &mut *connection).await {
353
            Ok(false) => {
354
                results.insert(db_user, Err(UnlockUserError::UserDoesNotExist));
355
                continue;
356
            }
357
            Err(err) => {
358
                results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string())));
359
                continue;
360
            }
361
            _ => {}
362
        }
363

            
364
        match database_user_is_locked_unsafe(&db_user, &mut *connection, db_is_mariadb).await {
365
            Ok(false) => {
366
                results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked));
367
                continue;
368
            }
369
            Err(err) => {
370
                results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string())));
371
                continue;
372
            }
373
            _ => {}
374
        }
375

            
376
        let result = sqlx::query(
377
            format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(),
378
        )
379
        .execute(&mut *connection)
380
        .await
381
        .map(|_| ())
382
        .map_err(|err| UnlockUserError::MySqlError(err.to_string()));
383

            
384
        if let Err(err) = &result {
385
            tracing::error!("Failed to unlock database user '{}': {:?}", &db_user, err);
386
        }
387

            
388
        results.insert(db_user, result);
389
    }
390

            
391
    results
392
}
393

            
394
/// This struct contains information about a database user.
395
/// This can be extended if we need more information in the future.
396
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
397
pub struct DatabaseUser {
398
    pub user: MySQLUser,
399
    #[serde(skip)]
400
    pub host: String,
401
    pub has_password: bool,
402
    pub is_locked: bool,
403
    pub databases: Vec<String>,
404
}
405

            
406
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser {
407
    fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
408
        Ok(Self {
409
            user: try_get_with_binary_fallback(row, "User")?.into(),
410
            host: try_get_with_binary_fallback(row, "Host")?,
411
            has_password: row.try_get("has_password")?,
412
            is_locked: row.try_get("account_locked")?,
413
            databases: Vec::new(),
414
        })
415
    }
416
}
417

            
418
const DB_USER_SELECT_STATEMENT_MARIADB: &str = r#"
419
SELECT
420
  `user`.`User`,
421
  `user`.`Host`,
422
  `user`.`Password` != '' OR `user`.`authentication_string` != '' AS `has_password`,
423
  COALESCE(
424
    JSON_EXTRACT(`global_priv`.`priv`, "$.account_locked"),
425
    'false'
426
  ) != 'false' AS `account_locked`
427
FROM `user`
428
JOIN `global_priv` ON
429
  `user`.`User` = `global_priv`.`User`
430
  AND `user`.`Host` = `global_priv`.`Host`
431
"#;
432

            
433
const DB_USER_SELECT_STATEMENT_MYSQL: &str = r"
434
SELECT
435
  `user`.`User`,
436
  `user`.`Host`,
437
  `user`.`authentication_string` != '' AS `has_password`,
438
  `user`.`account_locked` = 'Y' AS `account_locked`
439
FROM `user`
440
";
441

            
442
pub async fn list_database_users(
443
    db_users: &[MySQLUser],
444
    unix_user: &UnixUser,
445
    connection: &mut MySqlConnection,
446
    db_is_mariadb: bool,
447
    group_denylist: &GroupDenylist,
448
) -> ListUsersResponse {
449
    let mut results = BTreeMap::new();
450

            
451
    for db_user in db_users.iter().cloned() {
452
        if let Err(err) =
453
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
454
                .map_err(ListUsersError::ValidationError)
455
        {
456
            results.insert(db_user, Err(err));
457
            continue;
458
        }
459

            
460
        let mut result = sqlx::query_as::<_, DatabaseUser>(
461
            &(if db_is_mariadb {
462
                DB_USER_SELECT_STATEMENT_MARIADB.to_string()
463
            } else {
464
                DB_USER_SELECT_STATEMENT_MYSQL.to_string()
465
            } + "WHERE `mysql`.`user`.`User` = ?"),
466
        )
467
        .bind(db_user.as_str())
468
        .fetch_optional(&mut *connection)
469
        .await;
470

            
471
        if let Err(err) = &result {
472
            tracing::error!("Failed to list database user '{}': {:?}", &db_user, err);
473
        }
474

            
475
        if let Ok(Some(user)) = result.as_mut()
476
            && let Err(err) = set_databases_where_user_has_privileges(user, &mut *connection).await
477
        {
478
            result = Err(err);
479
        }
480

            
481
        match result {
482
            Ok(Some(user)) => results.insert(db_user, Ok(user)),
483
            Ok(None) => results.insert(db_user, Err(ListUsersError::UserDoesNotExist)),
484
            Err(err) => results.insert(db_user, Err(ListUsersError::MySqlError(err.to_string()))),
485
        };
486
    }
487

            
488
    results
489
}
490

            
491
pub async fn list_all_database_users_for_unix_user(
492
    unix_user: &UnixUser,
493
    connection: &mut MySqlConnection,
494
    db_is_mariadb: bool,
495
    group_denylist: &GroupDenylist,
496
) -> ListAllUsersResponse {
497
    let mut result = sqlx::query_as::<_, DatabaseUser>(
498
        &(if db_is_mariadb {
499
            DB_USER_SELECT_STATEMENT_MARIADB.to_string()
500
        } else {
501
            DB_USER_SELECT_STATEMENT_MYSQL.to_string()
502
        } + "WHERE `user`.`User` REGEXP ?"),
503
    )
504
    .bind(create_user_group_matching_regex(unix_user, group_denylist))
505
    .fetch_all(&mut *connection)
506
    .await
507
    .map_err(|err| ListAllUsersError::MySqlError(err.to_string()));
508

            
509
    if let Err(err) = &result {
510
        tracing::error!("Failed to list all database users: {:?}", err);
511
    }
512

            
513
    if let Ok(users) = result.as_mut() {
514
        for user in users {
515
            if let Err(mysql_error) =
516
                set_databases_where_user_has_privileges(user, &mut *connection).await
517
            {
518
                return Err(ListAllUsersError::MySqlError(mysql_error.to_string()));
519
            }
520
        }
521
    }
522

            
523
    result
524
}
525

            
526
/// This function sets the `databases` field of the given `DatabaseUser`
527
/// where the user has any privileges.
528
pub async fn set_databases_where_user_has_privileges(
529
    db_user: &mut DatabaseUser,
530
    connection: &mut MySqlConnection,
531
) -> Result<(), sqlx::Error> {
532
    let database_list = sqlx::query(
533
        formatdoc!(
534
            r"
535
                SELECT `Db` AS `database`
536
                FROM `db`
537
                WHERE `User` = ? AND ({})
538
            ",
539
            DATABASE_PRIVILEGE_FIELDS
540
                .iter()
541
                .map(|field| format!("`{field}` = 'Y'"))
542
                .join(" OR "),
543
        )
544
        .as_str(),
545
    )
546
    .bind(db_user.user.as_str())
547
    .fetch_all(&mut *connection)
548
    .await;
549

            
550
    if let Err(err) = &database_list {
551
        tracing::error!(
552
            "Failed to list databases for user '{}': {:?}",
553
            &db_user.user,
554
            err
555
        );
556
    }
557

            
558
    db_user.databases = database_list.and_then(|rows| {
559
        rows.into_iter()
560
            .map(|row| try_get_with_binary_fallback(&row, "database"))
561
            .collect::<Result<Vec<String>, sqlx::Error>>()
562
    })?;
563

            
564
    Ok(())
565
}