1
use std::collections::BTreeMap;
2

            
3
use sqlx::MySqlConnection;
4
use sqlx::prelude::*;
5

            
6
use serde::{Deserialize, Serialize};
7

            
8
use crate::core::types::MySQLDatabase;
9
use crate::{
10
    core::{
11
        common::UnixUser,
12
        protocol::{
13
            CreateDatabaseError, CreateDatabasesResponse, DropDatabaseError, DropDatabasesResponse,
14
            ListAllDatabasesError, ListAllDatabasesResponse, ListDatabasesError,
15
            ListDatabasesResponse,
16
        },
17
    },
18
    server::{
19
        common::create_user_group_matching_regex,
20
        input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user},
21
    },
22
};
23

            
24
// NOTE: this function is unsafe because it does no input validation.
25
pub(super) async fn unsafe_database_exists(
26
    database_name: &str,
27
    connection: &mut MySqlConnection,
28
) -> Result<bool, sqlx::Error> {
29
    let result =
30
        sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
31
            .bind(database_name)
32
            .fetch_optional(connection)
33
            .await;
34

            
35
    if let Err(err) = &result {
36
        log::error!(
37
            "Failed to check if database '{}' exists: {:?}",
38
            &database_name,
39
            err
40
        );
41
    }
42

            
43
    Ok(result?.is_some())
44
}
45

            
46
pub async fn create_databases(
47
    database_names: Vec<MySQLDatabase>,
48
    unix_user: &UnixUser,
49
    connection: &mut MySqlConnection,
50
) -> CreateDatabasesResponse {
51
    let mut results = BTreeMap::new();
52

            
53
    for database_name in database_names {
54
        if let Err(err) = validate_name(&database_name) {
55
            results.insert(
56
                database_name.to_owned(),
57
                Err(CreateDatabaseError::SanitizationError(err)),
58
            );
59
            continue;
60
        }
61

            
62
        if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
63
            results.insert(
64
                database_name.to_owned(),
65
                Err(CreateDatabaseError::OwnershipError(err)),
66
            );
67
            continue;
68
        }
69

            
70
        match unsafe_database_exists(&database_name, &mut *connection).await {
71
            Ok(true) => {
72
                results.insert(
73
                    database_name.to_owned(),
74
                    Err(CreateDatabaseError::DatabaseAlreadyExists),
75
                );
76
                continue;
77
            }
78
            Err(err) => {
79
                results.insert(
80
                    database_name.to_owned(),
81
                    Err(CreateDatabaseError::MySqlError(err.to_string())),
82
                );
83
                continue;
84
            }
85
            _ => {}
86
        }
87

            
88
        let result =
89
            sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str())
90
                .execute(&mut *connection)
91
                .await
92
                .map(|_| ())
93
                .map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
94

            
95
        if let Err(err) = &result {
96
            log::error!("Failed to create database '{}': {:?}", &database_name, err);
97
        }
98

            
99
        results.insert(database_name, result);
100
    }
101

            
102
    results
103
}
104

            
105
pub async fn drop_databases(
106
    database_names: Vec<MySQLDatabase>,
107
    unix_user: &UnixUser,
108
    connection: &mut MySqlConnection,
109
) -> DropDatabasesResponse {
110
    let mut results = BTreeMap::new();
111

            
112
    for database_name in database_names {
113
        if let Err(err) = validate_name(&database_name) {
114
            results.insert(
115
                database_name.to_owned(),
116
                Err(DropDatabaseError::SanitizationError(err)),
117
            );
118
            continue;
119
        }
120

            
121
        if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
122
            results.insert(
123
                database_name.to_owned(),
124
                Err(DropDatabaseError::OwnershipError(err)),
125
            );
126
            continue;
127
        }
128

            
129
        match unsafe_database_exists(&database_name, &mut *connection).await {
130
            Ok(false) => {
131
                results.insert(
132
                    database_name.to_owned(),
133
                    Err(DropDatabaseError::DatabaseDoesNotExist),
134
                );
135
                continue;
136
            }
137
            Err(err) => {
138
                results.insert(
139
                    database_name.to_owned(),
140
                    Err(DropDatabaseError::MySqlError(err.to_string())),
141
                );
142
                continue;
143
            }
144
            _ => {}
145
        }
146

            
147
        let result =
148
            sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str())
149
                .execute(&mut *connection)
150
                .await
151
                .map(|_| ())
152
                .map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
153

            
154
        if let Err(err) = &result {
155
            log::error!("Failed to drop database '{}': {:?}", &database_name, err);
156
        }
157

            
158
        results.insert(database_name, result);
159
    }
160

            
161
    results
162
}
163

            
164
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
165
pub struct DatabaseRow {
166
    pub database: MySQLDatabase,
167
}
168

            
169
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
170
    fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
171
        Ok(DatabaseRow {
172
            database: row.try_get::<String, _>("database")?.into(),
173
        })
174
    }
175
}
176

            
177
pub async fn list_databases(
178
    database_names: Vec<MySQLDatabase>,
179
    unix_user: &UnixUser,
180
    connection: &mut MySqlConnection,
181
) -> ListDatabasesResponse {
182
    let mut results = BTreeMap::new();
183

            
184
    for database_name in database_names {
185
        if let Err(err) = validate_name(&database_name) {
186
            results.insert(
187
                database_name.to_owned(),
188
                Err(ListDatabasesError::SanitizationError(err)),
189
            );
190
            continue;
191
        }
192

            
193
        if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
194
            results.insert(
195
                database_name.to_owned(),
196
                Err(ListDatabasesError::OwnershipError(err)),
197
            );
198
            continue;
199
        }
200

            
201
        let result = sqlx::query_as::<_, DatabaseRow>(
202
            r#"
203
          SELECT `SCHEMA_NAME` AS `database`
204
          FROM `information_schema`.`SCHEMATA`
205
          WHERE `SCHEMA_NAME` = ?
206
        "#,
207
        )
208
        .bind(database_name.to_string())
209
        .fetch_optional(&mut *connection)
210
        .await
211
        .map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
212
        .and_then(|database| {
213
            database
214
                .map(Ok)
215
                .unwrap_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist))
216
        });
217

            
218
        if let Err(err) = &result {
219
            log::error!("Failed to list database '{}': {:?}", &database_name, err);
220
        }
221

            
222
        results.insert(database_name, result);
223
    }
224

            
225
    results
226
}
227

            
228
pub async fn list_all_databases_for_user(
229
    unix_user: &UnixUser,
230
    connection: &mut MySqlConnection,
231
) -> ListAllDatabasesResponse {
232
    let result = sqlx::query_as::<_, DatabaseRow>(
233
        r#"
234
          SELECT `SCHEMA_NAME` AS `database`
235
          FROM `information_schema`.`SCHEMATA`
236
          WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
237
            AND `SCHEMA_NAME` REGEXP ?
238
        "#,
239
    )
240
    .bind(create_user_group_matching_regex(unix_user))
241
    .fetch_all(connection)
242
    .await
243
    .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));
244

            
245
    if let Err(err) = &result {
246
        log::error!(
247
            "Failed to list databases for user '{}': {:?}",
248
            unix_user.username,
249
            err
250
        );
251
    }
252

            
253
    result
254
}