mysqladm/server/
server_loop.rs

1use std::{
2    collections::BTreeSet,
3    fs,
4    os::unix::{io::FromRawFd, net::UnixListener as StdUnixListener},
5    path::PathBuf,
6};
7
8use anyhow::Context;
9use futures_util::{SinkExt, StreamExt};
10use indoc::concatdoc;
11use tokio::net::{UnixListener as TokioUnixListener, UnixStream as TokioUnixStream};
12
13use sqlx::MySqlConnection;
14use sqlx::prelude::*;
15
16use crate::core::protocol::SetPasswordError;
17use crate::server::sql::database_operations::list_databases;
18use crate::{
19    core::{
20        common::{DEFAULT_SOCKET_PATH, UnixUser},
21        protocol::{
22            Request, Response, ServerToClientMessageStream, create_server_to_client_message_stream,
23        },
24    },
25    server::{
26        config::{ServerConfig, create_mysql_connection_from_config},
27        sql::{
28            database_operations::{create_databases, drop_databases, list_all_databases_for_user},
29            database_privilege_operations::{
30                apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
31            },
32            user_operations::{
33                create_database_users, drop_database_users, list_all_database_users_for_unix_user,
34                list_database_users, lock_database_users, set_password_for_database_user,
35                unlock_database_users,
36            },
37        },
38    },
39};
40
41// TODO: consider using a connection pool
42
43pub async fn listen_for_incoming_connections_with_socket_path(
44    socket_path: Option<PathBuf>,
45    config: ServerConfig,
46) -> anyhow::Result<()> {
47    let socket_path = socket_path.unwrap_or(PathBuf::from(DEFAULT_SOCKET_PATH));
48
49    let parent_directory = socket_path.parent().unwrap();
50    if !parent_directory.exists() {
51        log::debug!("Creating directory {:?}", parent_directory);
52        fs::create_dir_all(parent_directory)?;
53    }
54
55    log::info!("Listening on socket {:?}", socket_path);
56
57    match fs::remove_file(socket_path.as_path()) {
58        Ok(_) => {}
59        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
60        Err(e) => return Err(e.into()),
61    }
62
63    let listener = TokioUnixListener::bind(socket_path)?;
64
65    listen_for_incoming_connections_with_listener(listener, config).await
66}
67
68pub async fn listen_for_incoming_connections_with_systemd_socket(
69    config: ServerConfig,
70) -> anyhow::Result<()> {
71    let fd = sd_notify::listen_fds()
72        .context("Failed to get file descriptors from systemd")?
73        .next()
74        .context("No file descriptors received from systemd")?;
75
76    debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
77
78    log::debug!(
79        "Received file descriptor from systemd with id: '{}', assuming socket",
80        fd
81    );
82
83    let std_unix_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
84    let listener = TokioUnixListener::from_std(std_unix_listener)?;
85    listen_for_incoming_connections_with_listener(listener, config).await
86}
87
88pub async fn listen_for_incoming_connections_with_listener(
89    listener: TokioUnixListener,
90    config: ServerConfig,
91) -> anyhow::Result<()> {
92    sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
93
94    while let Ok((conn, _addr)) = listener.accept().await {
95        let uid = match conn.peer_cred() {
96            Ok(cred) => cred.uid(),
97            Err(e) => {
98                log::error!("Failed to get peer credentials from socket: {}", e);
99                let mut message_stream = create_server_to_client_message_stream(conn);
100                message_stream
101                    .send(Response::Error(
102                        (concatdoc! {
103                            "Server failed to get peer credentials from socket\n",
104                            "Please check the server logs or contact the system administrators"
105                        })
106                        .to_string(),
107                    ))
108                    .await
109                    .ok();
110                continue;
111            }
112        };
113
114        log::debug!("Accepted connection from uid {}", uid);
115
116        let unix_user = match UnixUser::from_uid(uid) {
117            Ok(user) => user,
118            Err(e) => {
119                log::error!("Failed to get username from uid: {}", e);
120                let mut message_stream = create_server_to_client_message_stream(conn);
121                message_stream
122                    .send(Response::Error(
123                        (concatdoc! {
124                            "Server failed to get user data from the system\n",
125                            "Please check the server logs or contact the system administrators"
126                        })
127                        .to_string(),
128                    ))
129                    .await
130                    .ok();
131                continue;
132            }
133        };
134
135        log::info!("Accepted connection from {}", unix_user.username);
136
137        match handle_requests_for_single_session(conn, &unix_user, &config).await {
138            Ok(()) => {}
139            Err(e) => {
140                log::error!("Failed to run server: {}", e);
141            }
142        }
143    }
144
145    Ok(())
146}
147
148async fn close_or_ignore_db_connection(db_connection: MySqlConnection) {
149    if let Err(e) = db_connection.close().await {
150        log::error!("Failed to close database connection: {}", e);
151        log::error!("{}", e);
152        log::error!("Ignoring...");
153    }
154}
155
156pub async fn handle_requests_for_single_session(
157    socket: TokioUnixStream,
158    unix_user: &UnixUser,
159    config: &ServerConfig,
160) -> anyhow::Result<()> {
161    let mut message_stream = create_server_to_client_message_stream(socket);
162
163    log::debug!("Opening connection to database");
164
165    let mut db_connection = match create_mysql_connection_from_config(&config.mysql).await {
166        Ok(connection) => connection,
167        Err(err) => {
168            message_stream
169                .send(Response::Error(
170                    (concatdoc! {
171                        "Server failed to connect to database\n",
172                        "Please check the server logs or contact the system administrators"
173                    })
174                    .to_string(),
175                ))
176                .await?;
177            message_stream.flush().await?;
178            return Err(err);
179        }
180    };
181
182    log::debug!("Verifying that database connection is valid");
183
184    if let Err(e) = db_connection.ping().await {
185        log::error!("Failed to ping database: {}", e);
186        message_stream
187            .send(Response::Error(
188                (concatdoc! {
189                    "Server failed to connect to database\n",
190                    "Please check the server logs or contact the system administrators"
191                })
192                .to_string(),
193            ))
194            .await?;
195        message_stream.flush().await?;
196        close_or_ignore_db_connection(db_connection).await;
197        return Err(e.into());
198    }
199
200    log::debug!("Successfully connected to database");
201
202    let result = handle_requests_for_single_session_with_db_connection(
203        message_stream,
204        unix_user,
205        &mut db_connection,
206    )
207    .await;
208
209    close_or_ignore_db_connection(db_connection).await;
210
211    result
212}
213
214// TODO: ensure proper db_connection hygiene for functions that invoke
215//       this function
216
217async fn handle_requests_for_single_session_with_db_connection(
218    mut stream: ServerToClientMessageStream,
219    unix_user: &UnixUser,
220    db_connection: &mut MySqlConnection,
221) -> anyhow::Result<()> {
222    stream.send(Response::Ready).await?;
223    loop {
224        // TODO: better error handling
225        let request = match stream.next().await {
226            Some(Ok(request)) => request,
227            Some(Err(e)) => return Err(e.into()),
228            None => {
229                log::warn!("Client disconnected without sending an exit message");
230                break;
231            }
232        };
233
234        // TODO: don't clone the request
235        let request_to_display = match &request {
236            Request::PasswdUser((db_user, _)) => {
237                Request::PasswdUser((db_user.to_owned(), "<REDACTED>".to_string()))
238            }
239            request => request.to_owned(),
240        };
241        log::info!("Received request: {:#?}", request_to_display);
242
243        let response = match request {
244            Request::CreateDatabases(databases_names) => {
245                let result = create_databases(databases_names, unix_user, db_connection).await;
246                Response::CreateDatabases(result)
247            }
248            Request::DropDatabases(databases_names) => {
249                let result = drop_databases(databases_names, unix_user, db_connection).await;
250                Response::DropDatabases(result)
251            }
252            Request::ListDatabases(database_names) => match database_names {
253                Some(database_names) => {
254                    let result = list_databases(database_names, unix_user, db_connection).await;
255                    Response::ListDatabases(result)
256                }
257                None => {
258                    let result = list_all_databases_for_user(unix_user, db_connection).await;
259                    Response::ListAllDatabases(result)
260                }
261            },
262            Request::ListPrivileges(database_names) => match database_names {
263                Some(database_names) => {
264                    let privilege_data =
265                        get_databases_privilege_data(database_names, unix_user, db_connection)
266                            .await;
267                    Response::ListPrivileges(privilege_data)
268                }
269                None => {
270                    let privilege_data =
271                        get_all_database_privileges(unix_user, db_connection).await;
272                    Response::ListAllPrivileges(privilege_data)
273                }
274            },
275            Request::ModifyPrivileges(database_privilege_diffs) => {
276                let result = apply_privilege_diffs(
277                    BTreeSet::from_iter(database_privilege_diffs),
278                    unix_user,
279                    db_connection,
280                )
281                .await;
282                Response::ModifyPrivileges(result)
283            }
284            Request::CreateUsers(db_users) => {
285                let result = create_database_users(db_users, unix_user, db_connection).await;
286                Response::CreateUsers(result)
287            }
288            Request::DropUsers(db_users) => {
289                let result = drop_database_users(db_users, unix_user, db_connection).await;
290                Response::DropUsers(result)
291            }
292            Request::PasswdUser((db_user, password)) => {
293                let result =
294                    set_password_for_database_user(&db_user, &password, unix_user, db_connection)
295                        .await;
296                Response::SetUserPassword(result)
297            }
298            Request::ListUsers(db_users) => match db_users {
299                Some(db_users) => {
300                    let result = list_database_users(db_users, unix_user, db_connection).await;
301                    Response::ListUsers(result)
302                }
303                None => {
304                    let result =
305                        list_all_database_users_for_unix_user(unix_user, db_connection).await;
306                    Response::ListAllUsers(result)
307                }
308            },
309            Request::LockUsers(db_users) => {
310                let result = lock_database_users(db_users, unix_user, db_connection).await;
311                Response::LockUsers(result)
312            }
313            Request::UnlockUsers(db_users) => {
314                let result = unlock_database_users(db_users, unix_user, db_connection).await;
315                Response::UnlockUsers(result)
316            }
317            Request::Exit => {
318                break;
319            }
320        };
321
322        // TODO: don't clone the response
323        let response_to_display = match &response {
324            Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
325                Response::SetUserPassword(Err(SetPasswordError::MySqlError(
326                    "<REDACTED>".to_string(),
327                )))
328            }
329            response => response.to_owned(),
330        };
331        log::info!("Response: {:#?}", response_to_display);
332
333        stream.send(response).await?;
334        stream.flush().await?;
335        log::debug!("Successfully processed request");
336    }
337
338    Ok(())
339}