1
// TODO: fix comment
2
//! Database privilege operations
//!
//! This module contains the server-side functions for querying and
//! modifying database privileges, i.e. the rows of the `db` table.
//!
//! The privilege editor and the comparison functionality are not part of
//! this module. It receives already-computed `DatabasePrivilegesDiff`s
//! (defined in `crate::core::database_privileges`), checks each one against
//! the requesting user's permissions and the current contents of the `db`
//! table, and applies the ones that are valid.
16

            
17
use std::collections::{BTreeMap, BTreeSet};
18

            
19
use indoc::indoc;
20
use itertools::Itertools;
21
use sqlx::{AssertSqlSafe, MySqlConnection, mysql::MySqlRow, prelude::*};
22

            
23
use crate::{
24
    core::{
25
        common::{UnixUser, rev_yn, yn},
26
        database_privileges::{
27
            DATABASE_PRIVILEGE_FIELDS, DatabasePrivilegeChange, DatabasePrivilegeRow,
28
            DatabasePrivilegesDiff,
29
        },
30
        protocol::{
31
            DiffDoesNotApplyError, ListAllPrivilegesError, ListAllPrivilegesResponse,
32
            ListPrivilegesError, ListPrivilegesResponse, ModifyDatabasePrivilegesError,
33
            ModifyPrivilegesResponse,
34
            request_validation::{GroupDenylist, validate_db_or_user_request},
35
        },
36
        types::{DbOrUser, MySQLDatabase, MySQLUser},
37
    },
38
    server::{
39
        common::{create_user_group_matching_regex, try_get_with_binary_fallback},
40
        sql::{
41
            database_operations::unsafe_database_exists, quote_identifier,
42
            user_operations::unsafe_user_exists,
43
        },
44
    },
45
};
46

            
47
// TODO: get by name instead of row tuple position
48

            
49
#[inline]
50
fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
51
    let field = DATABASE_PRIVILEGE_FIELDS[position];
52
    let value = row.try_get(position)?;
53
    if let Some(val) = rev_yn(value) {
54
        Ok(val)
55
    } else {
56
        tracing::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
57
        Ok(false)
58
    }
59
}
60

            
61
impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
62
    fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
63
        Ok(Self {
64
            db: try_get_with_binary_fallback(row, "Db")?.into(),
65
            user: try_get_with_binary_fallback(row, "User")?.into(),
66
            select_priv: get_mysql_row_priv_field(row, 2)?,
67
            insert_priv: get_mysql_row_priv_field(row, 3)?,
68
            update_priv: get_mysql_row_priv_field(row, 4)?,
69
            delete_priv: get_mysql_row_priv_field(row, 5)?,
70
            create_priv: get_mysql_row_priv_field(row, 6)?,
71
            drop_priv: get_mysql_row_priv_field(row, 7)?,
72
            alter_priv: get_mysql_row_priv_field(row, 8)?,
73
            index_priv: get_mysql_row_priv_field(row, 9)?,
74
            create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?,
75
            lock_tables_priv: get_mysql_row_priv_field(row, 11)?,
76
            references_priv: get_mysql_row_priv_field(row, 12)?,
77
        })
78
    }
79
}
80

            
81
// NOTE: this function is unsafe because it does no input validation.
82
/// Get all users + privileges for a single database.
83
async fn unsafe_get_database_privileges(
84
    database_name: &str,
85
    connection: &mut MySqlConnection,
86
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
87
    let statement = AssertSqlSafe(format!(
88
        "SELECT {} FROM `db` WHERE `Db` = ?",
89
        DATABASE_PRIVILEGE_FIELDS
90
            .iter()
91
            .map(|field| quote_identifier(field))
92
            .join(","),
93
    ));
94
    let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement)
95
        .bind(database_name)
96
        .fetch_all(connection)
97
        .await;
98

            
99
    if let Err(e) = &result {
100
        tracing::error!(
101
            "Failed to get database privileges for '{}': {}",
102
            &database_name,
103
            e
104
        );
105
    }
106

            
107
    result
108
}
109

            
110
// NOTE: this function is unsafe because it does no input validation.
111
/// Get all users + privileges for a single database-user pair.
112
pub async fn unsafe_get_database_privileges_for_db_user_pair(
113
    database_name: &MySQLDatabase,
114
    user_name: &MySQLUser,
115
    connection: &mut MySqlConnection,
116
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
117
    let statement = AssertSqlSafe(format!(
118
        "SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = '%'",
119
        DATABASE_PRIVILEGE_FIELDS
120
            .iter()
121
            .map(|field| quote_identifier(field))
122
            .join(","),
123
    ));
124
    let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement)
125
        .bind(database_name.as_str())
126
        .bind(user_name.as_str())
127
        .fetch_optional(connection)
128
        .await;
129

            
130
    if let Err(e) = &result {
131
        tracing::error!(
132
            "Failed to get database privileges for '{}.{}': {}",
133
            &database_name,
134
            &user_name,
135
            e
136
        );
137
    }
138

            
139
    result
140
}
141

            
142
pub async fn get_databases_privilege_data(
143
    database_names: &[MySQLDatabase],
144
    unix_user: &UnixUser,
145
    connection: &mut MySqlConnection,
146
    _db_is_mariadb: bool,
147
    group_denylist: &GroupDenylist,
148
) -> ListPrivilegesResponse {
149
    let mut results = BTreeMap::new();
150

            
151
    for database_name in database_names.iter().cloned() {
152
        if let Err(err) = validate_db_or_user_request(
153
            &DbOrUser::Database(database_name.to_owned()),
154
            unix_user,
155
            group_denylist,
156
        )
157
        .map_err(ListPrivilegesError::ValidationError)
158
        {
159
            results.insert(database_name, Err(err));
160
            continue;
161
        }
162

            
163
        match unsafe_database_exists(&database_name, connection).await {
164
            Ok(false) => {
165
                results.insert(
166
                    database_name.to_owned(),
167
                    Err(ListPrivilegesError::DatabaseDoesNotExist),
168
                );
169
                continue;
170
            }
171
            Err(e) => {
172
                results.insert(
173
                    database_name.to_owned(),
174
                    Err(ListPrivilegesError::MySqlError(e.to_string())),
175
                );
176
                continue;
177
            }
178
            Ok(true) => {}
179
        }
180

            
181
        let result = unsafe_get_database_privileges(&database_name, connection)
182
            .await
183
            .map_err(|e| ListPrivilegesError::MySqlError(e.to_string()));
184

            
185
        results.insert(database_name.to_owned(), result);
186
    }
187

            
188
    debug_assert!(database_names.len() == results.len());
189

            
190
    results
191
}
192

            
193
/// TODO: make this constant
194
fn get_all_db_privs_query() -> AssertSqlSafe<String> {
195
    AssertSqlSafe(format!(
196
        indoc! {r"
197
            SELECT {} FROM `db` WHERE `db` IN
198
            (SELECT DISTINCT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database`
199
              FROM `information_schema`.`SCHEMATA`
200
              WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
201
                AND `SCHEMA_NAME` REGEXP ?)
202
        "},
203
        DATABASE_PRIVILEGE_FIELDS
204
            .iter()
205
            .map(|field| quote_identifier(field))
206
            .join(","),
207
    ))
208
}
209

            
210
/// Get all database + user + privileges pairs that are owned by the current user.
211
pub async fn get_all_database_privileges(
212
    unix_user: &UnixUser,
213
    connection: &mut MySqlConnection,
214
    _db_is_mariadb: bool,
215
    group_denylist: &GroupDenylist,
216
) -> ListAllPrivilegesResponse {
217
    let result = sqlx::query_as::<_, DatabasePrivilegeRow>(get_all_db_privs_query())
218
        .bind(create_user_group_matching_regex(unix_user, group_denylist))
219
        .fetch_all(connection)
220
        .await
221
        .map_err(|e| ListAllPrivilegesError::MySqlError(e.to_string()));
222

            
223
    if let Err(e) = &result {
224
        tracing::error!("Failed to get all database privileges: {:?}", e);
225
    }
226

            
227
    result
228
}
229

            
230
// TODO: make these queries constant strings.
231
async fn unsafe_apply_privilege_diff(
232
    database_privilege_diff: &DatabasePrivilegesDiff,
233
    connection: &mut MySqlConnection,
234
) -> Result<(), sqlx::Error> {
235
    let result = match database_privilege_diff {
236
        DatabasePrivilegesDiff::New(p) => {
237
            let tables = DATABASE_PRIVILEGE_FIELDS
238
                .iter()
239
                .chain(&["Host"])
240
                .map(|field| quote_identifier(field))
241
                .join(",");
242

            
243
            let question_marks =
244
                std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len() + 1).join(",");
245

            
246
            let statement = AssertSqlSafe(format!(
247
                "INSERT INTO `db` ({tables}) VALUES ({question_marks})"
248
            ));
249
            sqlx::query(statement)
250
                .bind(p.db.to_string())
251
                .bind(p.user.to_string())
252
                .bind(yn(p.select_priv))
253
                .bind(yn(p.insert_priv))
254
                .bind(yn(p.update_priv))
255
                .bind(yn(p.delete_priv))
256
                .bind(yn(p.create_priv))
257
                .bind(yn(p.drop_priv))
258
                .bind(yn(p.alter_priv))
259
                .bind(yn(p.index_priv))
260
                .bind(yn(p.create_tmp_table_priv))
261
                .bind(yn(p.lock_tables_priv))
262
                .bind(yn(p.references_priv))
263
                .bind("%")
264
                .execute(connection)
265
                .await
266
                .map(|_| ())
267
        }
268
        DatabasePrivilegesDiff::Modified(p) => {
269
            let changes = DATABASE_PRIVILEGE_FIELDS
270
                .iter()
271
                .skip(2) // Skip Db and User fields
272
                .map(|field| {
273
                    format!(
274
                        "{} = COALESCE(?, {})",
275
                        quote_identifier(field),
276
                        quote_identifier(field)
277
                    )
278
                })
279
                .join(",");
280

            
281
            fn change_to_yn(change: DatabasePrivilegeChange) -> &'static str {
282
                match change {
283
                    DatabasePrivilegeChange::YesToNo => "N",
284
                    DatabasePrivilegeChange::NoToYes => "Y",
285
                }
286
            }
287

            
288
            let statement = AssertSqlSafe(format!(
289
                "UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ? AND `Host` = ?"
290
            ));
291
            sqlx::query(statement)
292
                .bind(p.select_priv.map(change_to_yn))
293
                .bind(p.insert_priv.map(change_to_yn))
294
                .bind(p.update_priv.map(change_to_yn))
295
                .bind(p.delete_priv.map(change_to_yn))
296
                .bind(p.create_priv.map(change_to_yn))
297
                .bind(p.drop_priv.map(change_to_yn))
298
                .bind(p.alter_priv.map(change_to_yn))
299
                .bind(p.index_priv.map(change_to_yn))
300
                .bind(p.create_tmp_table_priv.map(change_to_yn))
301
                .bind(p.lock_tables_priv.map(change_to_yn))
302
                .bind(p.references_priv.map(change_to_yn))
303
                .bind(p.db.to_string())
304
                .bind(p.user.to_string())
305
                .bind("%")
306
                .execute(connection)
307
                .await
308
                .map(|_| ())
309
        }
310
        DatabasePrivilegesDiff::Deleted(p) => {
311
            sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = ?")
312
                .bind(p.db.to_string())
313
                .bind(p.user.to_string())
314
                .bind("%")
315
                .execute(connection)
316
                .await
317
                .map(|_| ())
318
        }
319
        DatabasePrivilegesDiff::Noop { .. } => Ok(()),
320
    };
321

            
322
    if let Err(e) = &result {
323
        tracing::error!("Failed to apply database privilege diff: {}", e);
324
    }
325

            
326
    result
327
}
328

            
329
async fn validate_diff(
330
    diff: &DatabasePrivilegesDiff,
331
    connection: &mut MySqlConnection,
332
) -> Result<(), ModifyDatabasePrivilegesError> {
333
    let privilege_row = unsafe_get_database_privileges_for_db_user_pair(
334
        diff.get_database_name(),
335
        diff.get_user_name(),
336
        connection,
337
    )
338
    .await;
339

            
340
    let privilege_row = match privilege_row {
341
        Ok(privilege_row) => privilege_row,
342
        Err(e) => return Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())),
343
    };
344

            
345
    match diff {
346
        DatabasePrivilegesDiff::New(_) => {
347
            if privilege_row.is_some() {
348
                Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
349
                    DiffDoesNotApplyError::RowAlreadyExists(
350
                        diff.get_database_name().to_owned(),
351
                        diff.get_user_name().to_owned(),
352
                    ),
353
                ))
354
            } else {
355
                Ok(())
356
            }
357
        }
358
        DatabasePrivilegesDiff::Modified(_) if privilege_row.is_none() => {
359
            Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
360
                DiffDoesNotApplyError::RowDoesNotExist(
361
                    diff.get_database_name().to_owned(),
362
                    diff.get_user_name().to_owned(),
363
                ),
364
            ))
365
        }
366
        DatabasePrivilegesDiff::Modified(row_diff) => {
367
            let row = privilege_row.unwrap();
368

            
369
            let error_exists = DATABASE_PRIVILEGE_FIELDS
370
                .iter()
371
                .skip(2) // Skip Db and User fields
372
                .any(
373
                    |field| match row_diff.get_privilege_change_by_name(field).unwrap() {
374
                        Some(DatabasePrivilegeChange::YesToNo) => {
375
                            !row.get_privilege_by_name(field).unwrap()
376
                        }
377
                        Some(DatabasePrivilegeChange::NoToYes) => {
378
                            row.get_privilege_by_name(field).unwrap()
379
                        }
380
                        None => false,
381
                    },
382
                );
383

            
384
            if error_exists {
385
                Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
386
                    DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(row_diff.to_owned(), row),
387
                ))
388
            } else {
389
                Ok(())
390
            }
391
        }
392
        DatabasePrivilegesDiff::Deleted(_) => {
393
            if privilege_row.is_none() {
394
                Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
395
                    DiffDoesNotApplyError::RowDoesNotExist(
396
                        diff.get_database_name().to_owned(),
397
                        diff.get_user_name().to_owned(),
398
                    ),
399
                ))
400
            } else {
401
                Ok(())
402
            }
403
        }
404
        DatabasePrivilegesDiff::Noop { .. } => {
405
            tracing::warn!(
406
                "Server got sent a noop database privilege diff to validate, is the client buggy?"
407
            );
408
            Ok(())
409
        }
410
    }
411
}
412

            
413
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
414
pub async fn apply_privilege_diffs(
415
    database_privilege_diffs: &BTreeSet<DatabasePrivilegesDiff>,
416
    unix_user: &UnixUser,
417
    connection: &mut MySqlConnection,
418
    _db_is_mariadb: bool,
419
    group_denylist: &GroupDenylist,
420
) -> ModifyPrivilegesResponse {
421
    let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new();
422

            
423
    for diff in database_privilege_diffs {
424
        let key = (
425
            diff.get_database_name().to_owned(),
426
            diff.get_user_name().to_owned(),
427
        );
428
        if let Err(err) = validate_db_or_user_request(
429
            &DbOrUser::Database(diff.get_database_name().to_owned()),
430
            unix_user,
431
            group_denylist,
432
        )
433
        .map_err(ModifyDatabasePrivilegesError::UserValidationError)
434
        {
435
            results.insert(key, Err(err));
436
            continue;
437
        }
438

            
439
        if let Err(err) = validate_db_or_user_request(
440
            &DbOrUser::User(diff.get_user_name().to_owned()),
441
            unix_user,
442
            group_denylist,
443
        )
444
        .map_err(ModifyDatabasePrivilegesError::UserValidationError)
445
        {
446
            results.insert(key, Err(err));
447
            continue;
448
        }
449

            
450
        match unsafe_database_exists(diff.get_database_name(), connection).await {
451
            Ok(false) => {
452
                results.insert(
453
                    key,
454
                    Err(ModifyDatabasePrivilegesError::DatabaseDoesNotExist),
455
                );
456
                continue;
457
            }
458
            Err(e) => {
459
                results.insert(
460
                    key,
461
                    Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())),
462
                );
463
                continue;
464
            }
465
            Ok(true) => {}
466
        }
467

            
468
        match unsafe_user_exists(diff.get_user_name(), connection).await {
469
            Ok(false) => {
470
                results.insert(key, Err(ModifyDatabasePrivilegesError::UserDoesNotExist));
471
                continue;
472
            }
473
            Err(e) => {
474
                results.insert(
475
                    key,
476
                    Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())),
477
                );
478
                continue;
479
            }
480
            Ok(true) => {}
481
        }
482

            
483
        if let Err(err) = validate_diff(diff, connection).await {
484
            results.insert(key, Err(err));
485
            continue;
486
        }
487

            
488
        let result = unsafe_apply_privilege_diff(diff, connection)
489
            .await
490
            .map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string()));
491

            
492
        results.insert(key, result);
493
    }
494

            
495
    if let Err(err) = connection.execute("FLUSH PRIVILEGES").await {
496
        tracing::error!("Failed to flush privileges: {}", err);
497
    }
498

            
499
    results
500
        .into_iter()
501
        .map(|((k1, k2), v)| (k1, (k2, v)))
502
        .into_group_map()
503
        .into_iter()
504
        .map(|(k1, pairs)| {
505
            let inner = pairs.into_iter().collect::<BTreeMap<_, _>>();
506
            (k1, inner)
507
        })
508
        .collect()
509
}