mysqladm/server/sql/database_operations.rs

1use std::collections::BTreeMap;
2
3use sqlx::MySqlConnection;
4use sqlx::prelude::*;
5
6use serde::{Deserialize, Serialize};
7
8use crate::core::types::MySQLDatabase;
9use crate::{
10    core::{
11        common::UnixUser,
12        protocol::{
13            CreateDatabaseError, CreateDatabasesResponse, DropDatabaseError, DropDatabasesResponse,
14            ListAllDatabasesError, ListAllDatabasesResponse, ListDatabasesError,
15            ListDatabasesResponse,
16        },
17    },
18    server::{
19        common::create_user_group_matching_regex,
20        input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user},
21    },
22};
23
24// NOTE: this function is unsafe because it does no input validation.
25pub(super) async fn unsafe_database_exists(
26    database_name: &str,
27    connection: &mut MySqlConnection,
28) -> Result<bool, sqlx::Error> {
29    let result =
30        sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
31            .bind(database_name)
32            .fetch_optional(connection)
33            .await;
34
35    if let Err(err) = &result {
36        log::error!(
37            "Failed to check if database '{}' exists: {:?}",
38            &database_name,
39            err
40        );
41    }
42
43    Ok(result?.is_some())
44}
45
46pub async fn create_databases(
47    database_names: Vec<MySQLDatabase>,
48    unix_user: &UnixUser,
49    connection: &mut MySqlConnection,
50) -> CreateDatabasesResponse {
51    let mut results = BTreeMap::new();
52
53    for database_name in database_names {
54        if let Err(err) = validate_name(&database_name) {
55            results.insert(
56                database_name.to_owned(),
57                Err(CreateDatabaseError::SanitizationError(err)),
58            );
59            continue;
60        }
61
62        if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
63            results.insert(
64                database_name.to_owned(),
65                Err(CreateDatabaseError::OwnershipError(err)),
66            );
67            continue;
68        }
69
70        match unsafe_database_exists(&database_name, &mut *connection).await {
71            Ok(true) => {
72                results.insert(
73                    database_name.to_owned(),
74                    Err(CreateDatabaseError::DatabaseAlreadyExists),
75                );
76                continue;
77            }
78            Err(err) => {
79                results.insert(
80                    database_name.to_owned(),
81                    Err(CreateDatabaseError::MySqlError(err.to_string())),
82                );
83                continue;
84            }
85            _ => {}
86        }
87
88        let result =
89            sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str())
90                .execute(&mut *connection)
91                .await
92                .map(|_| ())
93                .map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
94
95        if let Err(err) = &result {
96            log::error!("Failed to create database '{}': {:?}", &database_name, err);
97        }
98
99        results.insert(database_name, result);
100    }
101
102    results
103}
104
105pub async fn drop_databases(
106    database_names: Vec<MySQLDatabase>,
107    unix_user: &UnixUser,
108    connection: &mut MySqlConnection,
109) -> DropDatabasesResponse {
110    let mut results = BTreeMap::new();
111
112    for database_name in database_names {
113        if let Err(err) = validate_name(&database_name) {
114            results.insert(
115                database_name.to_owned(),
116                Err(DropDatabaseError::SanitizationError(err)),
117            );
118            continue;
119        }
120
121        if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
122            results.insert(
123                database_name.to_owned(),
124                Err(DropDatabaseError::OwnershipError(err)),
125            );
126            continue;
127        }
128
129        match unsafe_database_exists(&database_name, &mut *connection).await {
130            Ok(false) => {
131                results.insert(
132                    database_name.to_owned(),
133                    Err(DropDatabaseError::DatabaseDoesNotExist),
134                );
135                continue;
136            }
137            Err(err) => {
138                results.insert(
139                    database_name.to_owned(),
140                    Err(DropDatabaseError::MySqlError(err.to_string())),
141                );
142                continue;
143            }
144            _ => {}
145        }
146
147        let result =
148            sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str())
149                .execute(&mut *connection)
150                .await
151                .map(|_| ())
152                .map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
153
154        if let Err(err) = &result {
155            log::error!("Failed to drop database '{}': {:?}", &database_name, err);
156        }
157
158        results.insert(database_name, result);
159    }
160
161    results
162}
163
164#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
165pub struct DatabaseRow {
166    pub database: MySQLDatabase,
167}
168
169impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
170    fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
171        Ok(DatabaseRow {
172            database: row.try_get::<String, _>("database")?.into(),
173        })
174    }
175}
176
177pub async fn list_databases(
178    database_names: Vec<MySQLDatabase>,
179    unix_user: &UnixUser,
180    connection: &mut MySqlConnection,
181) -> ListDatabasesResponse {
182    let mut results = BTreeMap::new();
183
184    for database_name in database_names {
185        if let Err(err) = validate_name(&database_name) {
186            results.insert(
187                database_name.to_owned(),
188                Err(ListDatabasesError::SanitizationError(err)),
189            );
190            continue;
191        }
192
193        if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
194            results.insert(
195                database_name.to_owned(),
196                Err(ListDatabasesError::OwnershipError(err)),
197            );
198            continue;
199        }
200
201        let result = sqlx::query_as::<_, DatabaseRow>(
202            r#"
203          SELECT `SCHEMA_NAME` AS `database`
204          FROM `information_schema`.`SCHEMATA`
205          WHERE `SCHEMA_NAME` = ?
206        "#,
207        )
208        .bind(database_name.to_string())
209        .fetch_optional(&mut *connection)
210        .await
211        .map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
212        .and_then(|database| {
213            database
214                .map(Ok)
215                .unwrap_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist))
216        });
217
218        if let Err(err) = &result {
219            log::error!("Failed to list database '{}': {:?}", &database_name, err);
220        }
221
222        results.insert(database_name, result);
223    }
224
225    results
226}
227
228pub async fn list_all_databases_for_user(
229    unix_user: &UnixUser,
230    connection: &mut MySqlConnection,
231) -> ListAllDatabasesResponse {
232    let result = sqlx::query_as::<_, DatabaseRow>(
233        r#"
234          SELECT `SCHEMA_NAME` AS `database`
235          FROM `information_schema`.`SCHEMATA`
236          WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
237            AND `SCHEMA_NAME` REGEXP ?
238        "#,
239    )
240    .bind(create_user_group_matching_regex(unix_user))
241    .fetch_all(connection)
242    .await
243    .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));
244
245    if let Err(err) = &result {
246        log::error!(
247            "Failed to list databases for user '{}': {:?}",
248            unix_user.username,
249            err
250        );
251    }
252
253    result
254}