1
use std::collections::BTreeMap;
2

            
3
use sqlx::MySqlConnection;
4
use sqlx::prelude::*;
5

            
6
use serde::{Deserialize, Serialize};
7

            
8
use crate::core::protocol::CompleteDatabaseNameResponse;
9
use crate::core::protocol::request_validation::GroupDenylist;
10
use crate::core::protocol::request_validation::validate_db_or_user_request;
11
use crate::core::types::DbOrUser;
12
use crate::core::types::MySQLDatabase;
13
use crate::core::types::MySQLUser;
14
use crate::{
15
    core::{
16
        common::UnixUser,
17
        protocol::{
18
            CreateDatabaseError, CreateDatabasesResponse, DropDatabaseError, DropDatabasesResponse,
19
            ListAllDatabasesError, ListAllDatabasesResponse, ListDatabasesError,
20
            ListDatabasesResponse,
21
        },
22
    },
23
    server::{common::create_user_group_matching_regex, sql::quote_identifier},
24
};
25

            
26
// NOTE: this function is unsafe because it does no input validation.
27
pub(super) async fn unsafe_database_exists(
28
    database_name: &str,
29
    connection: &mut MySqlConnection,
30
) -> Result<bool, sqlx::Error> {
31
    let result =
32
        sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
33
            .bind(database_name)
34
            .fetch_optional(connection)
35
            .await;
36

            
37
    if let Err(err) = &result {
38
        tracing::error!(
39
            "Failed to check if database '{}' exists: {:?}",
40
            &database_name,
41
            err
42
        );
43
    }
44

            
45
    Ok(result?.is_some())
46
}
47

            
48
pub async fn complete_database_name(
49
    database_prefix: &str,
50
    unix_user: &UnixUser,
51
    connection: &mut MySqlConnection,
52
    _db_is_mariadb: bool,
53
    group_denylist: &GroupDenylist,
54
) -> CompleteDatabaseNameResponse {
55
    let result = sqlx::query(
56
        r"
57
          SELECT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database`
58
          FROM `information_schema`.`SCHEMATA`
59
          WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
60
            AND `SCHEMA_NAME` REGEXP ?
61
            AND `SCHEMA_NAME` LIKE ?
62
        ",
63
    )
64
    .bind(create_user_group_matching_regex(unix_user, group_denylist))
65
    .bind(format!("{database_prefix}%"))
66
    .fetch_all(connection)
67
    .await;
68

            
69
    match result {
70
        Ok(rows) => rows
71
            .into_iter()
72
            .filter_map(|row| {
73
                let database: String = row.try_get("database").ok()?;
74
                Some(database.into())
75
            })
76
            .collect(),
77
        Err(err) => {
78
            tracing::error!(
79
                "Failed to complete database name for prefix '{}' and user '{}': {:?}",
80
                database_prefix,
81
                unix_user.username,
82
                err
83
            );
84
            vec![]
85
        }
86
    }
87
}
88

            
89
pub async fn create_databases(
90
    database_names: &[MySQLDatabase],
91
    unix_user: &UnixUser,
92
    connection: &mut MySqlConnection,
93
    _db_is_mariadb: bool,
94
    group_denylist: &GroupDenylist,
95
) -> CreateDatabasesResponse {
96
    let mut results = BTreeMap::new();
97

            
98
    for database_name in database_names.iter().cloned() {
99
        if let Err(err) = validate_db_or_user_request(
100
            &DbOrUser::Database(database_name.clone()),
101
            unix_user,
102
            group_denylist,
103
        )
104
        .map_err(CreateDatabaseError::ValidationError)
105
        {
106
            results.insert(database_name.clone(), Err(err));
107
            continue;
108
        }
109

            
110
        match unsafe_database_exists(&database_name, &mut *connection).await {
111
            Ok(true) => {
112
                results.insert(
113
                    database_name.clone(),
114
                    Err(CreateDatabaseError::DatabaseAlreadyExists),
115
                );
116
                continue;
117
            }
118
            Err(err) => {
119
                results.insert(
120
                    database_name.clone(),
121
                    Err(CreateDatabaseError::MySqlError(err.to_string())),
122
                );
123
                continue;
124
            }
125
            _ => {}
126
        }
127

            
128
        let result =
129
            sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str())
130
                .execute(&mut *connection)
131
                .await
132
                .map(|_| ())
133
                .map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
134

            
135
        if let Err(err) = &result {
136
            tracing::error!("Failed to create database '{}': {:?}", &database_name, err);
137
        }
138

            
139
        results.insert(database_name, result);
140
    }
141

            
142
    results
143
}
144

            
145
pub async fn drop_databases(
146
    database_names: &[MySQLDatabase],
147
    unix_user: &UnixUser,
148
    connection: &mut MySqlConnection,
149
    _db_is_mariadb: bool,
150
    group_denylist: &GroupDenylist,
151
) -> DropDatabasesResponse {
152
    let mut results = BTreeMap::new();
153

            
154
    for database_name in database_names.iter().cloned() {
155
        if let Err(err) = validate_db_or_user_request(
156
            &DbOrUser::Database(database_name.clone()),
157
            unix_user,
158
            group_denylist,
159
        )
160
        .map_err(DropDatabaseError::ValidationError)
161
        {
162
            results.insert(database_name.clone(), Err(err));
163
            continue;
164
        }
165

            
166
        match unsafe_database_exists(&database_name, &mut *connection).await {
167
            Ok(false) => {
168
                results.insert(
169
                    database_name.clone(),
170
                    Err(DropDatabaseError::DatabaseDoesNotExist),
171
                );
172
                continue;
173
            }
174
            Err(err) => {
175
                results.insert(
176
                    database_name.clone(),
177
                    Err(DropDatabaseError::MySqlError(err.to_string())),
178
                );
179
                continue;
180
            }
181
            _ => {}
182
        }
183

            
184
        let result =
185
            sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str())
186
                .execute(&mut *connection)
187
                .await
188
                .map(|_| ())
189
                .map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
190

            
191
        if let Err(err) = &result {
192
            tracing::error!("Failed to drop database '{}': {:?}", &database_name, err);
193
        }
194

            
195
        results.insert(database_name, result);
196
    }
197

            
198
    results
199
}
200

            
201
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
202
pub struct DatabaseRow {
203
    pub database: MySQLDatabase,
204
    pub tables: Vec<String>,
205
    pub users: Vec<MySQLUser>,
206
    pub collation: Option<String>,
207
    pub character_set: Option<String>,
208
    pub size_bytes: u64,
209
}
210

            
211
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
212
    fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
213
        Ok(DatabaseRow {
214
            database: row.try_get::<String, _>("database")?.into(),
215
            tables: {
216
                let s: Option<String> = row.try_get("tables")?;
217
                s.and_then(|s| {
218
                    if s.is_empty() {
219
                        None
220
                    } else {
221
                        Some(s.split(',').map(std::borrow::ToOwned::to_owned).collect())
222
                    }
223
                })
224
                .unwrap_or_default()
225
            },
226
            users: {
227
                let s: Option<String> = row.try_get("users")?;
228
                s.and_then(|s| {
229
                    if s.is_empty() {
230
                        None
231
                    } else {
232
                        Some(s.split(',').map(|s| s.to_owned().into()).collect())
233
                    }
234
                })
235
                .unwrap_or_default()
236
            },
237
            collation: row.try_get::<Option<String>, _>("collation")?,
238
            character_set: row.try_get::<Option<String>, _>("character_set")?,
239
            size_bytes: row.try_get::<u64, _>("size_bytes")?,
240
        })
241
    }
242
}
243

            
244
pub async fn list_databases(
245
    database_names: &[MySQLDatabase],
246
    unix_user: &UnixUser,
247
    connection: &mut MySqlConnection,
248
    _db_is_mariadb: bool,
249
    group_denylist: &GroupDenylist,
250
) -> ListDatabasesResponse {
251
    let mut results = BTreeMap::new();
252

            
253
    for database_name in database_names.iter().cloned() {
254
        if let Err(err) = validate_db_or_user_request(
255
            &DbOrUser::Database(database_name.clone()),
256
            unix_user,
257
            group_denylist,
258
        )
259
        .map_err(ListDatabasesError::ValidationError)
260
        {
261
            results.insert(database_name.clone(), Err(err));
262
            continue;
263
        }
264

            
265
        let result = sqlx::query_as::<_, DatabaseRow>(
266
            r"
267
                SELECT
268
                  CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`,
269
                  GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`,
270
                  GROUP_CONCAT(DISTINCT CAST(`mysql`.`db`.`User` AS CHAR(64)) SEPARATOR ',') AS `users`,
271
                  MAX(`information_schema`.`SCHEMATA`.`DEFAULT_COLLATION_NAME`) AS `collation`,
272
                  MAX(`information_schema`.`SCHEMATA`.`DEFAULT_CHARACTER_SET_NAME`) AS `character_set`,
273
                  CAST(IFNULL(
274
                    SUM(`information_schema`.`TABLES`.`DATA_LENGTH` + `information_schema`.`TABLES`.`INDEX_LENGTH`),
275
                    0
276
                  ) AS UNSIGNED INTEGER) AS `size_bytes`
277
                FROM `information_schema`.`SCHEMATA`
278
                LEFT OUTER JOIN `information_schema`.`TABLES`
279
                  ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `TABLES`.`TABLE_SCHEMA`
280
                LEFT OUTER JOIN `mysql`.`db`
281
                  ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `mysql`.`db`.`DB`
282
                WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = ?
283
                GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME`
284
            ",
285

            
286
        )
287
        .bind(database_name.to_string())
288
        .fetch_optional(&mut *connection)
289
        .await
290
        .map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
291
        .and_then(|database| {
292
            database.map_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist), Ok)
293
        });
294

            
295
        if let Err(err) = &result {
296
            tracing::error!("Failed to list database '{}': {:?}", &database_name, err);
297
        }
298

            
299
        // TODO: should we assert that the users are also owned by the unix_user from the request?
300

            
301
        results.insert(database_name, result);
302
    }
303

            
304
    results
305
}
306

            
307
pub async fn list_all_databases_for_user(
308
    unix_user: &UnixUser,
309
    connection: &mut MySqlConnection,
310
    _db_is_mariadb: bool,
311
    group_denylist: &GroupDenylist,
312
) -> ListAllDatabasesResponse {
313
    let result = sqlx::query_as::<_, DatabaseRow>(
314
        r"
315
          SELECT
316
            CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`,
317
            GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`,
318
            GROUP_CONCAT(DISTINCT CAST(`mysql`.`db`.`User` AS CHAR(64)) SEPARATOR ',') AS `users`,
319
            MAX(`information_schema`.`SCHEMATA`.`DEFAULT_COLLATION_NAME`) AS `collation`,
320
            MAX(`information_schema`.`SCHEMATA`.`DEFAULT_CHARACTER_SET_NAME`) AS `character_set`,
321
            CAST(IFNULL(
322
              SUM(`information_schema`.`TABLES`.`DATA_LENGTH` + `information_schema`.`TABLES`.`INDEX_LENGTH`),
323
              0
324
            ) AS UNSIGNED INTEGER) AS `size_bytes`
325
          FROM `information_schema`.`SCHEMATA`
326
          LEFT OUTER JOIN `information_schema`.`TABLES`
327
            ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `TABLES`.`TABLE_SCHEMA`
328
          LEFT OUTER JOIN `mysql`.`db`
329
            ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `mysql`.`db`.`DB`
330
          WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
331
            AND `information_schema`.`SCHEMATA`.`SCHEMA_NAME` REGEXP ?
332
          GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME`
333
        ",
334
    )
335
    .bind(create_user_group_matching_regex(unix_user, group_denylist))
336
    .fetch_all(connection)
337
    .await
338
    .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));
339

            
340
    // TODO: should we assert that the users are also owned by the unix_user from the request?
341

            
342
    if let Err(err) = &result {
343
        tracing::error!(
344
            "Failed to list databases for user '{}': {:?}",
345
            unix_user.username,
346
            err
347
        );
348
    }
349

            
350
    result
351
}