1use std::{
2 collections::BTreeSet,
3 fs,
4 os::unix::{io::FromRawFd, net::UnixListener as StdUnixListener},
5 path::PathBuf,
6};
7
8use anyhow::Context;
9use futures_util::{SinkExt, StreamExt};
10use indoc::concatdoc;
11use tokio::net::{UnixListener as TokioUnixListener, UnixStream as TokioUnixStream};
12
13use sqlx::MySqlConnection;
14use sqlx::prelude::*;
15
16use crate::core::protocol::SetPasswordError;
17use crate::server::sql::database_operations::list_databases;
18use crate::{
19 core::{
20 common::{DEFAULT_SOCKET_PATH, UnixUser},
21 protocol::{
22 Request, Response, ServerToClientMessageStream, create_server_to_client_message_stream,
23 },
24 },
25 server::{
26 config::{ServerConfig, create_mysql_connection_from_config},
27 sql::{
28 database_operations::{create_databases, drop_databases, list_all_databases_for_user},
29 database_privilege_operations::{
30 apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
31 },
32 user_operations::{
33 create_database_users, drop_database_users, list_all_database_users_for_unix_user,
34 list_database_users, lock_database_users, set_password_for_database_user,
35 unlock_database_users,
36 },
37 },
38 },
39};
40
41pub async fn listen_for_incoming_connections_with_socket_path(
44 socket_path: Option<PathBuf>,
45 config: ServerConfig,
46) -> anyhow::Result<()> {
47 let socket_path = socket_path.unwrap_or(PathBuf::from(DEFAULT_SOCKET_PATH));
48
49 let parent_directory = socket_path.parent().unwrap();
50 if !parent_directory.exists() {
51 log::debug!("Creating directory {:?}", parent_directory);
52 fs::create_dir_all(parent_directory)?;
53 }
54
55 log::info!("Listening on socket {:?}", socket_path);
56
57 match fs::remove_file(socket_path.as_path()) {
58 Ok(_) => {}
59 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
60 Err(e) => return Err(e.into()),
61 }
62
63 let listener = TokioUnixListener::bind(socket_path)?;
64
65 listen_for_incoming_connections_with_listener(listener, config).await
66}
67
68pub async fn listen_for_incoming_connections_with_systemd_socket(
69 config: ServerConfig,
70) -> anyhow::Result<()> {
71 let fd = sd_notify::listen_fds()
72 .context("Failed to get file descriptors from systemd")?
73 .next()
74 .context("No file descriptors received from systemd")?;
75
76 debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
77
78 log::debug!(
79 "Received file descriptor from systemd with id: '{}', assuming socket",
80 fd
81 );
82
83 let std_unix_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
84 let listener = TokioUnixListener::from_std(std_unix_listener)?;
85 listen_for_incoming_connections_with_listener(listener, config).await
86}
87
88pub async fn listen_for_incoming_connections_with_listener(
89 listener: TokioUnixListener,
90 config: ServerConfig,
91) -> anyhow::Result<()> {
92 sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
93
94 while let Ok((conn, _addr)) = listener.accept().await {
95 let uid = match conn.peer_cred() {
96 Ok(cred) => cred.uid(),
97 Err(e) => {
98 log::error!("Failed to get peer credentials from socket: {}", e);
99 let mut message_stream = create_server_to_client_message_stream(conn);
100 message_stream
101 .send(Response::Error(
102 (concatdoc! {
103 "Server failed to get peer credentials from socket\n",
104 "Please check the server logs or contact the system administrators"
105 })
106 .to_string(),
107 ))
108 .await
109 .ok();
110 continue;
111 }
112 };
113
114 log::debug!("Accepted connection from uid {}", uid);
115
116 let unix_user = match UnixUser::from_uid(uid) {
117 Ok(user) => user,
118 Err(e) => {
119 log::error!("Failed to get username from uid: {}", e);
120 let mut message_stream = create_server_to_client_message_stream(conn);
121 message_stream
122 .send(Response::Error(
123 (concatdoc! {
124 "Server failed to get user data from the system\n",
125 "Please check the server logs or contact the system administrators"
126 })
127 .to_string(),
128 ))
129 .await
130 .ok();
131 continue;
132 }
133 };
134
135 log::info!("Accepted connection from {}", unix_user.username);
136
137 match handle_requests_for_single_session(conn, &unix_user, &config).await {
138 Ok(()) => {}
139 Err(e) => {
140 log::error!("Failed to run server: {}", e);
141 }
142 }
143 }
144
145 Ok(())
146}
147
148async fn close_or_ignore_db_connection(db_connection: MySqlConnection) {
149 if let Err(e) = db_connection.close().await {
150 log::error!("Failed to close database connection: {}", e);
151 log::error!("{}", e);
152 log::error!("Ignoring...");
153 }
154}
155
156pub async fn handle_requests_for_single_session(
157 socket: TokioUnixStream,
158 unix_user: &UnixUser,
159 config: &ServerConfig,
160) -> anyhow::Result<()> {
161 let mut message_stream = create_server_to_client_message_stream(socket);
162
163 log::debug!("Opening connection to database");
164
165 let mut db_connection = match create_mysql_connection_from_config(&config.mysql).await {
166 Ok(connection) => connection,
167 Err(err) => {
168 message_stream
169 .send(Response::Error(
170 (concatdoc! {
171 "Server failed to connect to database\n",
172 "Please check the server logs or contact the system administrators"
173 })
174 .to_string(),
175 ))
176 .await?;
177 message_stream.flush().await?;
178 return Err(err);
179 }
180 };
181
182 log::debug!("Verifying that database connection is valid");
183
184 if let Err(e) = db_connection.ping().await {
185 log::error!("Failed to ping database: {}", e);
186 message_stream
187 .send(Response::Error(
188 (concatdoc! {
189 "Server failed to connect to database\n",
190 "Please check the server logs or contact the system administrators"
191 })
192 .to_string(),
193 ))
194 .await?;
195 message_stream.flush().await?;
196 close_or_ignore_db_connection(db_connection).await;
197 return Err(e.into());
198 }
199
200 log::debug!("Successfully connected to database");
201
202 let result = handle_requests_for_single_session_with_db_connection(
203 message_stream,
204 unix_user,
205 &mut db_connection,
206 )
207 .await;
208
209 close_or_ignore_db_connection(db_connection).await;
210
211 result
212}
213
214async fn handle_requests_for_single_session_with_db_connection(
218 mut stream: ServerToClientMessageStream,
219 unix_user: &UnixUser,
220 db_connection: &mut MySqlConnection,
221) -> anyhow::Result<()> {
222 stream.send(Response::Ready).await?;
223 loop {
224 let request = match stream.next().await {
226 Some(Ok(request)) => request,
227 Some(Err(e)) => return Err(e.into()),
228 None => {
229 log::warn!("Client disconnected without sending an exit message");
230 break;
231 }
232 };
233
234 let request_to_display = match &request {
236 Request::PasswdUser((db_user, _)) => {
237 Request::PasswdUser((db_user.to_owned(), "<REDACTED>".to_string()))
238 }
239 request => request.to_owned(),
240 };
241 log::info!("Received request: {:#?}", request_to_display);
242
243 let response = match request {
244 Request::CreateDatabases(databases_names) => {
245 let result = create_databases(databases_names, unix_user, db_connection).await;
246 Response::CreateDatabases(result)
247 }
248 Request::DropDatabases(databases_names) => {
249 let result = drop_databases(databases_names, unix_user, db_connection).await;
250 Response::DropDatabases(result)
251 }
252 Request::ListDatabases(database_names) => match database_names {
253 Some(database_names) => {
254 let result = list_databases(database_names, unix_user, db_connection).await;
255 Response::ListDatabases(result)
256 }
257 None => {
258 let result = list_all_databases_for_user(unix_user, db_connection).await;
259 Response::ListAllDatabases(result)
260 }
261 },
262 Request::ListPrivileges(database_names) => match database_names {
263 Some(database_names) => {
264 let privilege_data =
265 get_databases_privilege_data(database_names, unix_user, db_connection)
266 .await;
267 Response::ListPrivileges(privilege_data)
268 }
269 None => {
270 let privilege_data =
271 get_all_database_privileges(unix_user, db_connection).await;
272 Response::ListAllPrivileges(privilege_data)
273 }
274 },
275 Request::ModifyPrivileges(database_privilege_diffs) => {
276 let result = apply_privilege_diffs(
277 BTreeSet::from_iter(database_privilege_diffs),
278 unix_user,
279 db_connection,
280 )
281 .await;
282 Response::ModifyPrivileges(result)
283 }
284 Request::CreateUsers(db_users) => {
285 let result = create_database_users(db_users, unix_user, db_connection).await;
286 Response::CreateUsers(result)
287 }
288 Request::DropUsers(db_users) => {
289 let result = drop_database_users(db_users, unix_user, db_connection).await;
290 Response::DropUsers(result)
291 }
292 Request::PasswdUser((db_user, password)) => {
293 let result =
294 set_password_for_database_user(&db_user, &password, unix_user, db_connection)
295 .await;
296 Response::SetUserPassword(result)
297 }
298 Request::ListUsers(db_users) => match db_users {
299 Some(db_users) => {
300 let result = list_database_users(db_users, unix_user, db_connection).await;
301 Response::ListUsers(result)
302 }
303 None => {
304 let result =
305 list_all_database_users_for_unix_user(unix_user, db_connection).await;
306 Response::ListAllUsers(result)
307 }
308 },
309 Request::LockUsers(db_users) => {
310 let result = lock_database_users(db_users, unix_user, db_connection).await;
311 Response::LockUsers(result)
312 }
313 Request::UnlockUsers(db_users) => {
314 let result = unlock_database_users(db_users, unix_user, db_connection).await;
315 Response::UnlockUsers(result)
316 }
317 Request::Exit => {
318 break;
319 }
320 };
321
322 let response_to_display = match &response {
324 Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
325 Response::SetUserPassword(Err(SetPasswordError::MySqlError(
326 "<REDACTED>".to_string(),
327 )))
328 }
329 response => response.to_owned(),
330 };
331 log::info!("Response: {:#?}", response_to_display);
332
333 stream.send(response).await?;
334 stream.flush().await?;
335 log::debug!("Successfully processed request");
336 }
337
338 Ok(())
339}