1
use indoc::formatdoc;
2
use itertools::Itertools;
3
use sqlx::AssertSqlSafe;
4
use std::collections::BTreeMap;
5

            
6
use serde::{Deserialize, Serialize};
7

            
8
use sqlx::MySqlConnection;
9
use sqlx::prelude::*;
10

            
11
use crate::core::protocol::request_validation::GroupDenylist;
12
use crate::core::protocol::request_validation::validate_db_or_user_request;
13
use crate::core::types::DbOrUser;
14
use crate::{
15
    core::{
16
        common::UnixUser,
17
        database_privileges::DATABASE_PRIVILEGE_FIELDS,
18
        protocol::{
19
            CreateUserError, CreateUsersResponse, DropUserError, DropUsersResponse,
20
            ListAllUsersError, ListAllUsersResponse, ListUsersError, ListUsersResponse,
21
            LockUserError, LockUsersResponse, SetPasswordError, SetUserPasswordResponse,
22
            UnlockUserError, UnlockUsersResponse,
23
        },
24
        types::MySQLUser,
25
    },
26
    server::{
27
        common::{create_user_group_matching_regex, try_get_with_binary_fallback},
28
        sql::quote_literal,
29
    },
30
};
31

            
32
// NOTE: this function is unsafe because it does no input validation.
33
pub(super) async fn unsafe_user_exists(
34
    db_user: &str,
35
    connection: &mut MySqlConnection,
36
) -> Result<bool, sqlx::Error> {
37
    let result = sqlx::query(
38
        r"
39
          SELECT EXISTS(
40
            SELECT 1
41
            FROM `mysql`.`user`
42
            WHERE `User` = ?
43
              AND `Host` = '%'
44
          )
45
        ",
46
    )
47
    .bind(db_user)
48
    .fetch_one(connection)
49
    .await
50
    .map(|row| row.get::<bool, _>(0));
51

            
52
    if let Err(err) = &result {
53
        tracing::error!("Failed to check if database user exists: {:?}", err);
54
    }
55

            
56
    result
57
}
58

            
59
pub async fn complete_user_name(
60
    user_prefix: &str,
61
    unix_user: &UnixUser,
62
    connection: &mut MySqlConnection,
63
    _db_is_mariadb: bool,
64
    group_denylist: &GroupDenylist,
65
) -> Vec<MySQLUser> {
66
    let result = sqlx::query(
67
        r"
68
          SELECT `User` AS `user`
69
          FROM `mysql`.`user`
70
          WHERE `User` REGEXP ?
71
            AND `User` LIKE ?
72
            AND `Host` = '%'
73
        ",
74
    )
75
    .bind(create_user_group_matching_regex(unix_user, group_denylist))
76
    .bind(format!("{user_prefix}%"))
77
    .fetch_all(connection)
78
    .await;
79

            
80
    match result {
81
        Ok(rows) => rows
82
            .into_iter()
83
            .filter_map(|row| {
84
                let user: String = try_get_with_binary_fallback(&row, "user").ok()?;
85
                Some(user.into())
86
            })
87
            .collect(),
88
        Err(err) => {
89
            tracing::error!(
90
                "Failed to complete user name for prefix '{}' and user '{}': {:?}",
91
                user_prefix,
92
                unix_user.username,
93
                err
94
            );
95
            vec![]
96
        }
97
    }
98
}
99

            
100
pub async fn create_database_users(
101
    db_users: &[MySQLUser],
102
    unix_user: &UnixUser,
103
    connection: &mut MySqlConnection,
104
    _db_is_mariadb: bool,
105
    group_denylist: &GroupDenylist,
106
) -> CreateUsersResponse {
107
    let mut results = BTreeMap::new();
108

            
109
    for db_user in db_users.iter().cloned() {
110
        if let Err(err) =
111
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
112
                .map_err(CreateUserError::ValidationError)
113
        {
114
            results.insert(db_user, Err(err));
115
            continue;
116
        }
117

            
118
        match unsafe_user_exists(&db_user, &mut *connection).await {
119
            Ok(true) => {
120
                results.insert(db_user, Err(CreateUserError::UserAlreadyExists));
121
                continue;
122
            }
123
            Err(err) => {
124
                results.insert(db_user, Err(CreateUserError::MySqlError(err.to_string())));
125
                continue;
126
            }
127
            _ => {}
128
        }
129

            
130
        let statement = AssertSqlSafe(format!("CREATE USER {}@'%'", quote_literal(&db_user),));
131
        let result = sqlx::query(statement)
132
            .execute(&mut *connection)
133
            .await
134
            .map(|_| ())
135
            .map_err(|err| CreateUserError::MySqlError(err.to_string()));
136

            
137
        if let Err(err) = &result {
138
            tracing::error!("Failed to create database user '{}': {:?}", &db_user, err);
139
        }
140

            
141
        results.insert(db_user, result);
142
    }
143

            
144
    results
145
}
146

            
147
pub async fn drop_database_users(
148
    db_users: &[MySQLUser],
149
    unix_user: &UnixUser,
150
    connection: &mut MySqlConnection,
151
    _db_is_mariadb: bool,
152
    group_denylist: &GroupDenylist,
153
) -> DropUsersResponse {
154
    let mut results = BTreeMap::new();
155

            
156
    for db_user in db_users.iter().cloned() {
157
        if let Err(err) =
158
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
159
                .map_err(DropUserError::ValidationError)
160
        {
161
            results.insert(db_user, Err(err));
162
            continue;
163
        }
164

            
165
        match unsafe_user_exists(&db_user, &mut *connection).await {
166
            Ok(false) => {
167
                results.insert(db_user, Err(DropUserError::UserDoesNotExist));
168
                continue;
169
            }
170
            Err(err) => {
171
                results.insert(db_user, Err(DropUserError::MySqlError(err.to_string())));
172
                continue;
173
            }
174
            _ => {}
175
        }
176

            
177
        let statement = AssertSqlSafe(format!("DROP USER {}@'%'", quote_literal(&db_user),));
178
        let result = sqlx::query(statement)
179
            .execute(&mut *connection)
180
            .await
181
            .map(|_| ())
182
            .map_err(|err| DropUserError::MySqlError(err.to_string()));
183

            
184
        if let Err(err) = &result {
185
            tracing::error!("Failed to drop database user '{}': {:?}", &db_user, err);
186
        }
187

            
188
        results.insert(db_user, result);
189
    }
190

            
191
    results
192
}
193

            
194
pub async fn set_password_for_database_user(
195
    db_user: &MySQLUser,
196
    password: &str,
197
    unix_user: &UnixUser,
198
    connection: &mut MySqlConnection,
199
    _db_is_mariadb: bool,
200
    group_denylist: &GroupDenylist,
201
) -> SetUserPasswordResponse {
202
    validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
203
        .map_err(SetPasswordError::ValidationError)?;
204

            
205
    match unsafe_user_exists(db_user, &mut *connection).await {
206
        Ok(false) => return Err(SetPasswordError::UserDoesNotExist),
207
        Err(err) => return Err(SetPasswordError::MySqlError(err.to_string())),
208
        _ => {}
209
    }
210

            
211
    let statement = AssertSqlSafe(format!(
212
        "ALTER USER {}@'%' IDENTIFIED BY {}",
213
        quote_literal(db_user),
214
        quote_literal(password).as_str(),
215
    ));
216
    let result = sqlx::query(statement)
217
        .execute(&mut *connection)
218
        .await
219
        .map(|_| ())
220
        .map_err(|err| SetPasswordError::MySqlError(err.to_string()));
221

            
222
    if result.is_err() {
223
        tracing::error!(
224
            "Failed to set password for database user '{}': <REDACTED>",
225
            &db_user,
226
        );
227
    }
228

            
229
    result
230
}
231

            
232
const DATABASE_USER_LOCK_STATUS_QUERY_MARIADB: &str = r#"
233
    SELECT COALESCE(
234
        JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
235
        'false'
236
    ) != 'false'
237
    FROM `mysql`.`global_priv`
238
    WHERE `User` = ?
239
    AND `Host` = '%'
240
"#;
241

            
242
const DATABASE_USER_LOCK_STATUS_QUERY_MYSQL: &str = r"
243
    SELECT `mysql`.`user`.`account_locked` = 'Y'
244
    FROM `mysql`.`user`
245
    WHERE `User` = ?
246
    AND `Host` = '%'
247
";
248

            
249
// NOTE: this function is unsafe because it does no input validation.
250
async fn database_user_is_locked_unsafe(
251
    db_user: &str,
252
    connection: &mut MySqlConnection,
253
    db_is_mariadb: bool,
254
) -> Result<bool, sqlx::Error> {
255
    let result = sqlx::query(if db_is_mariadb {
256
        DATABASE_USER_LOCK_STATUS_QUERY_MARIADB
257
    } else {
258
        DATABASE_USER_LOCK_STATUS_QUERY_MYSQL
259
    })
260
    .bind(db_user)
261
    .fetch_one(connection)
262
    .await
263
    .map(|row| row.try_get(0))
264
    .and_then(|res| res);
265

            
266
    if let Err(err) = &result {
267
        tracing::error!(
268
            "Failed to check if database user is locked '{}': {:?}",
269
            &db_user,
270
            err
271
        );
272
    }
273

            
274
    result
275
}
276

            
277
pub async fn lock_database_users(
278
    db_users: &[MySQLUser],
279
    unix_user: &UnixUser,
280
    connection: &mut MySqlConnection,
281
    db_is_mariadb: bool,
282
    group_denylist: &GroupDenylist,
283
) -> LockUsersResponse {
284
    let mut results = BTreeMap::new();
285

            
286
    for db_user in db_users.iter().cloned() {
287
        if let Err(err) =
288
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
289
                .map_err(LockUserError::ValidationError)
290
        {
291
            results.insert(db_user, Err(err));
292
            continue;
293
        }
294

            
295
        match unsafe_user_exists(&db_user, &mut *connection).await {
296
            Ok(true) => {}
297
            Ok(false) => {
298
                results.insert(db_user, Err(LockUserError::UserDoesNotExist));
299
                continue;
300
            }
301
            Err(err) => {
302
                results.insert(db_user, Err(LockUserError::MySqlError(err.to_string())));
303
                continue;
304
            }
305
        }
306

            
307
        match database_user_is_locked_unsafe(&db_user, &mut *connection, db_is_mariadb).await {
308
            Ok(false) => {}
309
            Ok(true) => {
310
                results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked));
311
                continue;
312
            }
313
            Err(err) => {
314
                results.insert(db_user, Err(LockUserError::MySqlError(err.to_string())));
315
                continue;
316
            }
317
        }
318

            
319
        let statement = AssertSqlSafe(format!(
320
            "ALTER USER {}@'%' ACCOUNT LOCK",
321
            quote_literal(&db_user),
322
        ));
323
        let result = sqlx::query(statement)
324
            .execute(&mut *connection)
325
            .await
326
            .map(|_| ())
327
            .map_err(|err| LockUserError::MySqlError(err.to_string()));
328

            
329
        if let Err(err) = &result {
330
            tracing::error!("Failed to lock database user '{}': {:?}", &db_user, err);
331
        }
332

            
333
        results.insert(db_user, result);
334
    }
335

            
336
    results
337
}
338

            
339
pub async fn unlock_database_users(
340
    db_users: &[MySQLUser],
341
    unix_user: &UnixUser,
342
    connection: &mut MySqlConnection,
343
    db_is_mariadb: bool,
344
    group_denylist: &GroupDenylist,
345
) -> UnlockUsersResponse {
346
    let mut results = BTreeMap::new();
347

            
348
    for db_user in db_users.iter().cloned() {
349
        if let Err(err) =
350
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
351
                .map_err(UnlockUserError::ValidationError)
352
        {
353
            results.insert(db_user, Err(err));
354
            continue;
355
        }
356

            
357
        match unsafe_user_exists(&db_user, &mut *connection).await {
358
            Ok(false) => {
359
                results.insert(db_user, Err(UnlockUserError::UserDoesNotExist));
360
                continue;
361
            }
362
            Err(err) => {
363
                results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string())));
364
                continue;
365
            }
366
            _ => {}
367
        }
368

            
369
        match database_user_is_locked_unsafe(&db_user, &mut *connection, db_is_mariadb).await {
370
            Ok(false) => {
371
                results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked));
372
                continue;
373
            }
374
            Err(err) => {
375
                results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string())));
376
                continue;
377
            }
378
            _ => {}
379
        }
380

            
381
        let statement = AssertSqlSafe(format!(
382
            "ALTER USER {}@'%' ACCOUNT UNLOCK",
383
            quote_literal(&db_user),
384
        ));
385
        let result = sqlx::query(statement)
386
            .execute(&mut *connection)
387
            .await
388
            .map(|_| ())
389
            .map_err(|err| UnlockUserError::MySqlError(err.to_string()));
390

            
391
        if let Err(err) = &result {
392
            tracing::error!("Failed to unlock database user '{}': {:?}", &db_user, err);
393
        }
394

            
395
        results.insert(db_user, result);
396
    }
397

            
398
    results
399
}
400

            
401
/// This struct contains information about a database user.
/// This can be extended if we need more information in the future.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabaseUser {
    /// Name of the database user (read from the `User` column).
    pub user: MySQLUser,
    /// Host part of the account (read from the `Host` column).
    /// Excluded from serialization and deserialization via `#[serde(skip)]`.
    #[serde(skip)]
    pub host: String,
    /// Whether the account has a password (read from the `has_password` column).
    pub has_password: bool,
    /// Whether the account is locked (read from the `account_locked` column).
    pub is_locked: bool,
    /// Databases associated with this user. Left empty by `from_row`;
    /// `set_databases_where_user_has_privileges` fills it in afterwards.
    pub databases: Vec<String>,
}
412

            
413
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser {
414
    fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
415
        Ok(Self {
416
            user: try_get_with_binary_fallback(row, "User")?.into(),
417
            host: try_get_with_binary_fallback(row, "Host")?,
418
            has_password: row.try_get("has_password")?,
419
            is_locked: row.try_get("account_locked")?,
420
            databases: Vec::new(),
421
        })
422
    }
423
}
424

            
425
const DB_USER_SELECT_STATEMENT_MARIADB: &str = r#"
426
SELECT
427
  `user`.`User`,
428
  `user`.`Host`,
429
  `user`.`Password` != '' OR `user`.`authentication_string` != '' AS `has_password`,
430
  COALESCE(
431
    JSON_EXTRACT(`global_priv`.`priv`, "$.account_locked"),
432
    'false'
433
  ) != 'false' AS `account_locked`
434
FROM `user`
435
JOIN `global_priv` ON
436
  `user`.`User` = `global_priv`.`User`
437
  AND `user`.`Host` = `global_priv`.`Host`
438
"#;
439

            
440
const DB_USER_SELECT_STATEMENT_MYSQL: &str = r"
441
SELECT
442
  `user`.`User`,
443
  `user`.`Host`,
444
  `user`.`authentication_string` != '' AS `has_password`,
445
  `user`.`account_locked` = 'Y' AS `account_locked`
446
FROM `user`
447
";
448

            
449
pub async fn list_database_users(
450
    db_users: &[MySQLUser],
451
    unix_user: &UnixUser,
452
    connection: &mut MySqlConnection,
453
    db_is_mariadb: bool,
454
    group_denylist: &GroupDenylist,
455
) -> ListUsersResponse {
456
    let mut results = BTreeMap::new();
457

            
458
    for db_user in db_users.iter().cloned() {
459
        if let Err(err) =
460
            validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
461
                .map_err(ListUsersError::ValidationError)
462
        {
463
            results.insert(db_user, Err(err));
464
            continue;
465
        }
466

            
467
        let statement = AssertSqlSafe(
468
            if db_is_mariadb {
469
                DB_USER_SELECT_STATEMENT_MARIADB.to_string()
470
            } else {
471
                DB_USER_SELECT_STATEMENT_MYSQL.to_string()
472
            } + "WHERE `mysql`.`user`.`User` = ? AND `mysql`.`user`.`Host` = '%'",
473
        );
474
        let mut result = sqlx::query_as::<_, DatabaseUser>(statement)
475
            .bind(db_user.as_str())
476
            .fetch_optional(&mut *connection)
477
            .await;
478

            
479
        if let Err(err) = &result {
480
            tracing::error!("Failed to list database user '{}': {:?}", &db_user, err);
481
        }
482

            
483
        if let Ok(Some(user)) = result.as_mut()
484
            && let Err(err) = set_databases_where_user_has_privileges(user, &mut *connection).await
485
        {
486
            result = Err(err);
487
        }
488

            
489
        match result {
490
            Ok(Some(user)) => results.insert(db_user, Ok(user)),
491
            Ok(None) => results.insert(db_user, Err(ListUsersError::UserDoesNotExist)),
492
            Err(err) => results.insert(db_user, Err(ListUsersError::MySqlError(err.to_string()))),
493
        };
494
    }
495

            
496
    results
497
}
498

            
499
pub async fn list_all_database_users_for_unix_user(
500
    unix_user: &UnixUser,
501
    connection: &mut MySqlConnection,
502
    db_is_mariadb: bool,
503
    group_denylist: &GroupDenylist,
504
) -> ListAllUsersResponse {
505
    let statement = AssertSqlSafe(
506
        if db_is_mariadb {
507
            DB_USER_SELECT_STATEMENT_MARIADB.to_string()
508
        } else {
509
            DB_USER_SELECT_STATEMENT_MYSQL.to_string()
510
        } + "WHERE `user`.`User` REGEXP ? AND `user`.`Host` = '%'",
511
    );
512
    let mut result = sqlx::query_as::<_, DatabaseUser>(statement)
513
        .bind(create_user_group_matching_regex(unix_user, group_denylist))
514
        .fetch_all(&mut *connection)
515
        .await
516
        .map_err(|err| ListAllUsersError::MySqlError(err.to_string()));
517

            
518
    if let Err(err) = &result {
519
        tracing::error!("Failed to list all database users: {:?}", err);
520
    }
521

            
522
    if let Ok(users) = result.as_mut() {
523
        for user in users {
524
            if let Err(mysql_error) =
525
                set_databases_where_user_has_privileges(user, &mut *connection).await
526
            {
527
                return Err(ListAllUsersError::MySqlError(mysql_error.to_string()));
528
            }
529
        }
530
    }
531

            
532
    result
533
}
534

            
535
/// This function sets the `databases` field of the given `DatabaseUser`
536
/// where the user has any privileges.
537
pub async fn set_databases_where_user_has_privileges(
538
    db_user: &mut DatabaseUser,
539
    connection: &mut MySqlConnection,
540
) -> Result<(), sqlx::Error> {
541
    let statement = AssertSqlSafe(formatdoc!(
542
        r"
543
            SELECT `Db` AS `database`
544
            FROM `db`
545
            WHERE `User` = ?  AND `Host` = '%' AND ({})
546
        ",
547
        DATABASE_PRIVILEGE_FIELDS
548
            .iter()
549
            .map(|field| format!("`{field}` = 'Y'"))
550
            .join(" OR "),
551
    ));
552
    let database_list = sqlx::query(statement)
553
        .bind(db_user.user.as_str())
554
        .fetch_all(&mut *connection)
555
        .await;
556

            
557
    if let Err(err) = &database_list {
558
        tracing::error!(
559
            "Failed to list databases for user '{}': {:?}",
560
            &db_user.user,
561
            err
562
        );
563
    }
564

            
565
    db_user.databases = database_list.and_then(|rows| {
566
        rows.into_iter()
567
            .map(|row| try_get_with_binary_fallback(&row, "database"))
568
            .collect::<Result<Vec<String>, sqlx::Error>>()
569
    })?;
570

            
571
    Ok(())
572
}