1
use std::collections::BTreeMap;
2

            
3
use sqlx::AssertSqlSafe;
4
use sqlx::MySqlConnection;
5
use sqlx::prelude::*;
6

            
7
use serde::{Deserialize, Serialize};
8

            
9
use crate::core::protocol::CompleteDatabaseNameResponse;
10
use crate::core::protocol::request_validation::GroupDenylist;
11
use crate::core::protocol::request_validation::validate_db_or_user_request;
12
use crate::core::types::DbOrUser;
13
use crate::core::types::MySQLDatabase;
14
use crate::core::types::MySQLUser;
15
use crate::{
16
    core::{
17
        common::UnixUser,
18
        protocol::{
19
            CreateDatabaseError, CreateDatabasesResponse, DropDatabaseError, DropDatabasesResponse,
20
            ListAllDatabasesError, ListAllDatabasesResponse, ListDatabasesError,
21
            ListDatabasesResponse,
22
        },
23
    },
24
    server::{common::create_user_group_matching_regex, sql::quote_identifier},
25
};
26

            
27
// NOTE: this function is unsafe because it does no input validation.
28
pub(super) async fn unsafe_database_exists(
29
    database_name: &str,
30
    connection: &mut MySqlConnection,
31
) -> Result<bool, sqlx::Error> {
32
    let result =
33
        sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
34
            .bind(database_name)
35
            .fetch_optional(connection)
36
            .await;
37

            
38
    if let Err(err) = &result {
39
        tracing::error!(
40
            "Failed to check if database '{}' exists: {:?}",
41
            &database_name,
42
            err
43
        );
44
    }
45

            
46
    Ok(result?.is_some())
47
}
48

            
49
pub async fn complete_database_name(
50
    database_prefix: &str,
51
    unix_user: &UnixUser,
52
    connection: &mut MySqlConnection,
53
    _db_is_mariadb: bool,
54
    group_denylist: &GroupDenylist,
55
) -> CompleteDatabaseNameResponse {
56
    let result = sqlx::query(
57
        r"
58
          SELECT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database`
59
          FROM `information_schema`.`SCHEMATA`
60
          WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
61
            AND `SCHEMA_NAME` REGEXP ?
62
            AND `SCHEMA_NAME` LIKE ?
63
        ",
64
    )
65
    .bind(create_user_group_matching_regex(unix_user, group_denylist))
66
    .bind(format!("{database_prefix}%"))
67
    .fetch_all(connection)
68
    .await;
69

            
70
    match result {
71
        Ok(rows) => rows
72
            .into_iter()
73
            .filter_map(|row| {
74
                let database: String = row.try_get("database").ok()?;
75
                Some(database.into())
76
            })
77
            .collect(),
78
        Err(err) => {
79
            tracing::error!(
80
                "Failed to complete database name for prefix '{}' and user '{}': {:?}",
81
                database_prefix,
82
                unix_user.username,
83
                err
84
            );
85
            vec![]
86
        }
87
    }
88
}
89

            
90
pub async fn create_databases(
91
    database_names: &[MySQLDatabase],
92
    unix_user: &UnixUser,
93
    connection: &mut MySqlConnection,
94
    _db_is_mariadb: bool,
95
    group_denylist: &GroupDenylist,
96
) -> CreateDatabasesResponse {
97
    let mut results = BTreeMap::new();
98

            
99
    for database_name in database_names.iter().cloned() {
100
        if let Err(err) = validate_db_or_user_request(
101
            &DbOrUser::Database(database_name.clone()),
102
            unix_user,
103
            group_denylist,
104
        )
105
        .map_err(CreateDatabaseError::ValidationError)
106
        {
107
            results.insert(database_name.clone(), Err(err));
108
            continue;
109
        }
110

            
111
        match unsafe_database_exists(&database_name, &mut *connection).await {
112
            Ok(true) => {
113
                results.insert(
114
                    database_name.clone(),
115
                    Err(CreateDatabaseError::DatabaseAlreadyExists),
116
                );
117
                continue;
118
            }
119
            Err(err) => {
120
                results.insert(
121
                    database_name.clone(),
122
                    Err(CreateDatabaseError::MySqlError(err.to_string())),
123
                );
124
                continue;
125
            }
126
            _ => {}
127
        }
128

            
129
        let statement = AssertSqlSafe(format!(
130
            "CREATE DATABASE {}",
131
            quote_identifier(&database_name)
132
        ));
133
        let result = sqlx::query(statement)
134
            .execute(&mut *connection)
135
            .await
136
            .map(|_| ())
137
            .map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
138

            
139
        if let Err(err) = &result {
140
            tracing::error!("Failed to create database '{}': {:?}", &database_name, err);
141
        }
142

            
143
        results.insert(database_name, result);
144
    }
145

            
146
    results
147
}
148

            
149
pub async fn drop_databases(
150
    database_names: &[MySQLDatabase],
151
    unix_user: &UnixUser,
152
    connection: &mut MySqlConnection,
153
    _db_is_mariadb: bool,
154
    group_denylist: &GroupDenylist,
155
) -> DropDatabasesResponse {
156
    let mut results = BTreeMap::new();
157

            
158
    for database_name in database_names.iter().cloned() {
159
        if let Err(err) = validate_db_or_user_request(
160
            &DbOrUser::Database(database_name.clone()),
161
            unix_user,
162
            group_denylist,
163
        )
164
        .map_err(DropDatabaseError::ValidationError)
165
        {
166
            results.insert(database_name.clone(), Err(err));
167
            continue;
168
        }
169

            
170
        match unsafe_database_exists(&database_name, &mut *connection).await {
171
            Ok(false) => {
172
                results.insert(
173
                    database_name.clone(),
174
                    Err(DropDatabaseError::DatabaseDoesNotExist),
175
                );
176
                continue;
177
            }
178
            Err(err) => {
179
                results.insert(
180
                    database_name.clone(),
181
                    Err(DropDatabaseError::MySqlError(err.to_string())),
182
                );
183
                continue;
184
            }
185
            _ => {}
186
        }
187

            
188
        let statement = AssertSqlSafe(format!(
189
            "DROP DATABASE {}",
190
            quote_identifier(&database_name)
191
        ));
192
        let result = sqlx::query(statement)
193
            .execute(&mut *connection)
194
            .await
195
            .map(|_| ())
196
            .map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
197

            
198
        if let Err(err) = &result {
199
            tracing::error!("Failed to drop database '{}': {:?}", &database_name, err);
200
        }
201

            
202
        results.insert(database_name, result);
203
    }
204

            
205
    results
206
}
207

            
208
/// Metadata for one database, as produced by the listing queries
/// `list_databases` and `list_all_databases_for_user`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabaseRow {
    /// Schema name (`SCHEMA_NAME`).
    pub database: MySQLDatabase,
    /// Table names in the schema, as returned by the query (`ORDER BY TABLE_NAME`);
    /// empty when the schema has no tables.
    ///
    /// NOTE(review): built by splitting a comma-joined `GROUP_CONCAT`, so a table
    /// name containing ',' would be split, and very long lists may be cut off by
    /// the server's `group_concat_max_len` — confirm this is acceptable.
    pub tables: Vec<String>,
    /// Distinct user names that have a row for this schema in `mysql.db`.
    pub users: Vec<MySQLUser>,
    /// The schema's `DEFAULT_COLLATION_NAME`.
    pub collation: Option<String>,
    /// The schema's `DEFAULT_CHARACTER_SET_NAME`.
    pub character_set: Option<String>,
    /// Sum of `DATA_LENGTH + INDEX_LENGTH` over the schema's tables; 0 when it has none.
    pub size_bytes: u64,
}
217

            
218
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
219
    fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
220
        Ok(DatabaseRow {
221
            database: row.try_get::<String, _>("database")?.into(),
222
            tables: {
223
                let s: Option<String> = row.try_get("tables")?;
224
                s.and_then(|s| {
225
                    if s.is_empty() {
226
                        None
227
                    } else {
228
                        Some(s.split(',').map(std::borrow::ToOwned::to_owned).collect())
229
                    }
230
                })
231
                .unwrap_or_default()
232
            },
233
            users: {
234
                let s: Option<String> = row.try_get("users")?;
235
                s.and_then(|s| {
236
                    if s.is_empty() {
237
                        None
238
                    } else {
239
                        Some(s.split(',').map(|s| s.to_owned().into()).collect())
240
                    }
241
                })
242
                .unwrap_or_default()
243
            },
244
            collation: row.try_get::<Option<String>, _>("collation")?,
245
            character_set: row.try_get::<Option<String>, _>("character_set")?,
246
            size_bytes: row.try_get::<u64, _>("size_bytes")?,
247
        })
248
    }
249
}
250

            
251
pub async fn list_databases(
252
    database_names: &[MySQLDatabase],
253
    unix_user: &UnixUser,
254
    connection: &mut MySqlConnection,
255
    _db_is_mariadb: bool,
256
    group_denylist: &GroupDenylist,
257
) -> ListDatabasesResponse {
258
    let mut results = BTreeMap::new();
259

            
260
    for database_name in database_names.iter().cloned() {
261
        if let Err(err) = validate_db_or_user_request(
262
            &DbOrUser::Database(database_name.clone()),
263
            unix_user,
264
            group_denylist,
265
        )
266
        .map_err(ListDatabasesError::ValidationError)
267
        {
268
            results.insert(database_name.clone(), Err(err));
269
            continue;
270
        }
271

            
272
        let result = sqlx::query_as::<_, DatabaseRow>(
273
            r"
274
                SELECT
275
                    CAST(s.SCHEMA_NAME AS CHAR(64)) AS `database`,
276
                    t.tables,
277
                    u.users,
278
                    s.DEFAULT_COLLATION_NAME AS `collation`,
279
                    s.DEFAULT_CHARACTER_SET_NAME AS `character_set`,
280
                    CAST(COALESCE(t.size_bytes, 0) AS UNSIGNED) AS `size_bytes`
281
                FROM information_schema.SCHEMATA s
282

            
283
                LEFT JOIN (
284
                    SELECT
285
                        TABLE_SCHEMA,
286
                        GROUP_CONCAT(
287
                            DISTINCT CAST(TABLE_NAME AS CHAR(64))
288
                            ORDER BY TABLE_NAME
289
                            SEPARATOR ','
290
                        ) AS tables,
291
                        SUM(DATA_LENGTH + INDEX_LENGTH) AS size_bytes
292
                    FROM information_schema.TABLES
293
                    WHERE TABLE_SCHEMA = ?
294
                    GROUP BY TABLE_SCHEMA
295
                ) t
296
                    ON t.TABLE_SCHEMA = s.SCHEMA_NAME
297

            
298
                LEFT JOIN (
299
                    SELECT
300
                        DB,
301
                        GROUP_CONCAT(
302
                            DISTINCT CAST(User AS CHAR(64))
303
                            ORDER BY User
304
                            SEPARATOR ','
305
                        ) AS users
306
                    FROM mysql.db
307
                    WHERE DB = ?
308
                    GROUP BY DB
309
                ) u
310
                    ON u.DB = s.SCHEMA_NAME
311

            
312
                WHERE s.SCHEMA_NAME = ?;
313
            ",
314
        )
315
        .bind(database_name.to_string())
316
        .bind(database_name.to_string())
317
        .bind(database_name.to_string())
318
        .fetch_optional(&mut *connection)
319
        .await
320
        .map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
321
        .and_then(|database| {
322
            database.map_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist), Ok)
323
        });
324

            
325
        if let Err(err) = &result {
326
            tracing::error!("Failed to list database '{}': {:?}", &database_name, err);
327
        }
328

            
329
        // TODO: should we assert that the users are also owned by the unix_user from the request?
330

            
331
        results.insert(database_name, result);
332
    }
333

            
334
    results
335
}
336

            
337
pub async fn list_all_databases_for_user(
338
    unix_user: &UnixUser,
339
    connection: &mut MySqlConnection,
340
    _db_is_mariadb: bool,
341
    group_denylist: &GroupDenylist,
342
) -> ListAllDatabasesResponse {
343
    let result = sqlx::query_as::<_, DatabaseRow>(
344
        r"
345
          SELECT
346
              CAST(s.SCHEMA_NAME AS CHAR(64)) AS `database`,
347
              t.tables,
348
              u.users,
349
              s.DEFAULT_COLLATION_NAME AS collation,
350
              s.DEFAULT_CHARACTER_SET_NAME AS character_set,
351
              CAST(COALESCE(t.size_bytes, 0) AS UNSIGNED) AS size_bytes
352
          FROM information_schema.SCHEMATA s
353

            
354
          LEFT JOIN (
355
              SELECT
356
                  TABLE_SCHEMA,
357
                  GROUP_CONCAT(
358
                      DISTINCT CAST(TABLE_NAME AS CHAR(64))
359
                      ORDER BY TABLE_NAME
360
                      SEPARATOR ','
361
                  ) AS tables,
362
                  SUM(DATA_LENGTH + INDEX_LENGTH) AS size_bytes
363
              FROM information_schema.TABLES
364
              WHERE TABLE_SCHEMA REGEXP ?
365
              GROUP BY TABLE_SCHEMA
366
          ) t
367
              ON t.TABLE_SCHEMA = s.SCHEMA_NAME
368

            
369
          LEFT JOIN (
370
              SELECT
371
                  DB,
372
                  GROUP_CONCAT(
373
                      DISTINCT CAST(User AS CHAR(64))
374
                      ORDER BY User
375
                      SEPARATOR ','
376
                  ) AS users
377
              FROM mysql.db
378
              WHERE DB REGEXP ?
379
              GROUP BY DB
380
          ) u
381
              ON u.DB = s.SCHEMA_NAME
382

            
383
          WHERE s.SCHEMA_NAME REGEXP ?
384
          AND s.SCHEMA_NAME NOT IN (
385
              'information_schema',
386
              'performance_schema',
387
              'mysql',
388
              'sys'
389
          )
390

            
391
          ORDER BY s.SCHEMA_NAME
392
        ",
393
    )
394
    .bind(create_user_group_matching_regex(unix_user, group_denylist))
395
    .bind(create_user_group_matching_regex(unix_user, group_denylist))
396
    .bind(create_user_group_matching_regex(unix_user, group_denylist))
397
    .fetch_all(connection)
398
    .await
399
    .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));
400

            
401
    // TODO: should we assert that the users are also owned by the unix_user from the request?
402

            
403
    if let Err(err) = &result {
404
        tracing::error!(
405
            "Failed to list databases for user '{}': {:?}",
406
            unix_user.username,
407
            err
408
        );
409
    }
410

            
411
    result
412
}