1use std::collections::BTreeMap;
2
3use sqlx::MySqlConnection;
4use sqlx::prelude::*;
5
6use serde::{Deserialize, Serialize};
7
8use crate::core::types::MySQLDatabase;
9use crate::{
10 core::{
11 common::UnixUser,
12 protocol::{
13 CreateDatabaseError, CreateDatabasesResponse, DropDatabaseError, DropDatabasesResponse,
14 ListAllDatabasesError, ListAllDatabasesResponse, ListDatabasesError,
15 ListDatabasesResponse,
16 },
17 },
18 server::{
19 common::create_user_group_matching_regex,
20 input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user},
21 },
22};
23
24pub(super) async fn unsafe_database_exists(
26 database_name: &str,
27 connection: &mut MySqlConnection,
28) -> Result<bool, sqlx::Error> {
29 let result =
30 sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
31 .bind(database_name)
32 .fetch_optional(connection)
33 .await;
34
35 if let Err(err) = &result {
36 log::error!(
37 "Failed to check if database '{}' exists: {:?}",
38 &database_name,
39 err
40 );
41 }
42
43 Ok(result?.is_some())
44}
45
46pub async fn create_databases(
47 database_names: Vec<MySQLDatabase>,
48 unix_user: &UnixUser,
49 connection: &mut MySqlConnection,
50) -> CreateDatabasesResponse {
51 let mut results = BTreeMap::new();
52
53 for database_name in database_names {
54 if let Err(err) = validate_name(&database_name) {
55 results.insert(
56 database_name.to_owned(),
57 Err(CreateDatabaseError::SanitizationError(err)),
58 );
59 continue;
60 }
61
62 if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
63 results.insert(
64 database_name.to_owned(),
65 Err(CreateDatabaseError::OwnershipError(err)),
66 );
67 continue;
68 }
69
70 match unsafe_database_exists(&database_name, &mut *connection).await {
71 Ok(true) => {
72 results.insert(
73 database_name.to_owned(),
74 Err(CreateDatabaseError::DatabaseAlreadyExists),
75 );
76 continue;
77 }
78 Err(err) => {
79 results.insert(
80 database_name.to_owned(),
81 Err(CreateDatabaseError::MySqlError(err.to_string())),
82 );
83 continue;
84 }
85 _ => {}
86 }
87
88 let result =
89 sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str())
90 .execute(&mut *connection)
91 .await
92 .map(|_| ())
93 .map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
94
95 if let Err(err) = &result {
96 log::error!("Failed to create database '{}': {:?}", &database_name, err);
97 }
98
99 results.insert(database_name, result);
100 }
101
102 results
103}
104
105pub async fn drop_databases(
106 database_names: Vec<MySQLDatabase>,
107 unix_user: &UnixUser,
108 connection: &mut MySqlConnection,
109) -> DropDatabasesResponse {
110 let mut results = BTreeMap::new();
111
112 for database_name in database_names {
113 if let Err(err) = validate_name(&database_name) {
114 results.insert(
115 database_name.to_owned(),
116 Err(DropDatabaseError::SanitizationError(err)),
117 );
118 continue;
119 }
120
121 if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
122 results.insert(
123 database_name.to_owned(),
124 Err(DropDatabaseError::OwnershipError(err)),
125 );
126 continue;
127 }
128
129 match unsafe_database_exists(&database_name, &mut *connection).await {
130 Ok(false) => {
131 results.insert(
132 database_name.to_owned(),
133 Err(DropDatabaseError::DatabaseDoesNotExist),
134 );
135 continue;
136 }
137 Err(err) => {
138 results.insert(
139 database_name.to_owned(),
140 Err(DropDatabaseError::MySqlError(err.to_string())),
141 );
142 continue;
143 }
144 _ => {}
145 }
146
147 let result =
148 sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str())
149 .execute(&mut *connection)
150 .await
151 .map(|_| ())
152 .map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
153
154 if let Err(err) = &result {
155 log::error!("Failed to drop database '{}': {:?}", &database_name, err);
156 }
157
158 results.insert(database_name, result);
159 }
160
161 results
162}
163
164#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
165pub struct DatabaseRow {
166 pub database: MySQLDatabase,
167}
168
169impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
170 fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
171 Ok(DatabaseRow {
172 database: row.try_get::<String, _>("database")?.into(),
173 })
174 }
175}
176
177pub async fn list_databases(
178 database_names: Vec<MySQLDatabase>,
179 unix_user: &UnixUser,
180 connection: &mut MySqlConnection,
181) -> ListDatabasesResponse {
182 let mut results = BTreeMap::new();
183
184 for database_name in database_names {
185 if let Err(err) = validate_name(&database_name) {
186 results.insert(
187 database_name.to_owned(),
188 Err(ListDatabasesError::SanitizationError(err)),
189 );
190 continue;
191 }
192
193 if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
194 results.insert(
195 database_name.to_owned(),
196 Err(ListDatabasesError::OwnershipError(err)),
197 );
198 continue;
199 }
200
201 let result = sqlx::query_as::<_, DatabaseRow>(
202 r#"
203 SELECT `SCHEMA_NAME` AS `database`
204 FROM `information_schema`.`SCHEMATA`
205 WHERE `SCHEMA_NAME` = ?
206 "#,
207 )
208 .bind(database_name.to_string())
209 .fetch_optional(&mut *connection)
210 .await
211 .map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
212 .and_then(|database| {
213 database
214 .map(Ok)
215 .unwrap_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist))
216 });
217
218 if let Err(err) = &result {
219 log::error!("Failed to list database '{}': {:?}", &database_name, err);
220 }
221
222 results.insert(database_name, result);
223 }
224
225 results
226}
227
228pub async fn list_all_databases_for_user(
229 unix_user: &UnixUser,
230 connection: &mut MySqlConnection,
231) -> ListAllDatabasesResponse {
232 let result = sqlx::query_as::<_, DatabaseRow>(
233 r#"
234 SELECT `SCHEMA_NAME` AS `database`
235 FROM `information_schema`.`SCHEMATA`
236 WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
237 AND `SCHEMA_NAME` REGEXP ?
238 "#,
239 )
240 .bind(create_user_group_matching_regex(unix_user))
241 .fetch_all(connection)
242 .await
243 .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));
244
245 if let Err(err) = &result {
246 log::error!(
247 "Failed to list databases for user '{}': {:?}",
248 unix_user.username,
249 err
250 );
251 }
252
253 result
254}