1
use std::{
2
    collections::BTreeSet,
3
    fs,
4
    os::unix::{io::FromRawFd, net::UnixListener as StdUnixListener},
5
    path::PathBuf,
6
    sync::Arc,
7
    time::Duration,
8
};
9

            
10
use anyhow::Context;
11
use futures_util::{SinkExt, StreamExt};
12
use indoc::concatdoc;
13
use tokio::{
14
    net::{UnixListener as TokioUnixListener, UnixStream as TokioUnixStream},
15
    time::interval,
16
};
17

            
18
use sqlx::MySqlConnection;
19
use sqlx::prelude::*;
20

            
21
use crate::core::protocol::SetPasswordError;
22
use crate::server::sql::database_operations::list_databases;
23
use crate::{
24
    core::{
25
        common::{DEFAULT_SOCKET_PATH, UnixUser},
26
        protocol::{
27
            Request, Response, ServerToClientMessageStream, create_server_to_client_message_stream,
28
        },
29
    },
30
    server::{
31
        config::{ServerConfig, create_mysql_connection_from_config},
32
        sql::{
33
            database_operations::{create_databases, drop_databases, list_all_databases_for_user},
34
            database_privilege_operations::{
35
                apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
36
            },
37
            user_operations::{
38
                create_database_users, drop_database_users, list_all_database_users_for_unix_user,
39
                list_database_users, lock_database_users, set_password_for_database_user,
40
                unlock_database_users,
41
            },
42
        },
43
    },
44
};
45

            
46
// TODO: consider using a connection pool
47

            
48
pub async fn listen_for_incoming_connections_with_socket_path(
49
    socket_path: Option<PathBuf>,
50
    config: ServerConfig,
51
) -> anyhow::Result<()> {
52
    let socket_path = socket_path.unwrap_or(PathBuf::from(DEFAULT_SOCKET_PATH));
53

            
54
    let parent_directory = socket_path.parent().unwrap();
55
    if !parent_directory.exists() {
56
        log::debug!("Creating directory {:?}", parent_directory);
57
        fs::create_dir_all(parent_directory)?;
58
    }
59

            
60
    log::info!("Listening on socket {:?}", socket_path);
61

            
62
    match fs::remove_file(socket_path.as_path()) {
63
        Ok(_) => {}
64
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
65
        Err(e) => return Err(e.into()),
66
    }
67

            
68
    let listener = TokioUnixListener::bind(socket_path)?;
69

            
70
    listen_for_incoming_connections_with_listener(listener, config).await
71
}
72

            
73
pub async fn listen_for_incoming_connections_with_systemd_socket(
74
    config: ServerConfig,
75
) -> anyhow::Result<()> {
76
    let fd = sd_notify::listen_fds()
77
        .context("Failed to get file descriptors from systemd")?
78
        .next()
79
        .context("No file descriptors received from systemd")?;
80

            
81
    debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
82

            
83
    log::debug!(
84
        "Received file descriptor from systemd with id: '{}', assuming socket",
85
        fd
86
    );
87

            
88
    let std_unix_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
89
    let listener = TokioUnixListener::from_std(std_unix_listener)?;
90
    listen_for_incoming_connections_with_listener(listener, config).await
91
}
92

            
93
pub async fn listen_for_incoming_connections_with_listener(
94
    listener: TokioUnixListener,
95
    config: ServerConfig,
96
) -> anyhow::Result<()> {
97
    let connection_counter = Arc::new(());
98
    let connection_counter_for_log = Arc::clone(&connection_counter);
99
    tokio::spawn(async move {
100
        let mut interval = interval(Duration::from_secs(1));
101
        loop {
102
            interval.tick().await;
103
            let count = Arc::strong_count(&connection_counter_for_log) - 2;
104
            let message = if count > 0 {
105
                format!("Handling {} connections", count)
106
            } else {
107
                "Waiting for connections".to_string()
108
            };
109
            sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())]).ok();
110
        }
111
    });
112

            
113
    sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
114

            
115
    while let Ok((conn, _addr)) = listener.accept().await {
116
        let uid = match conn.peer_cred() {
117
            Ok(cred) => cred.uid(),
118
            Err(e) => {
119
                log::error!("Failed to get peer credentials from socket: {}", e);
120
                let mut message_stream = create_server_to_client_message_stream(conn);
121
                message_stream
122
                    .send(Response::Error(
123
                        (concatdoc! {
124
                            "Server failed to get peer credentials from socket\n",
125
                            "Please check the server logs or contact the system administrators"
126
                        })
127
                        .to_string(),
128
                    ))
129
                    .await
130
                    .ok();
131
                continue;
132
            }
133
        };
134

            
135
        let _connection_counter_guard = Arc::clone(&connection_counter);
136

            
137
        log::debug!("Accepted connection from uid {}", uid);
138

            
139
        let unix_user = match UnixUser::from_uid(uid) {
140
            Ok(user) => user,
141
            Err(e) => {
142
                log::error!("Failed to get username from uid: {}", e);
143
                let mut message_stream = create_server_to_client_message_stream(conn);
144
                message_stream
145
                    .send(Response::Error(
146
                        (concatdoc! {
147
                            "Server failed to get user data from the system\n",
148
                            "Please check the server logs or contact the system administrators"
149
                        })
150
                        .to_string(),
151
                    ))
152
                    .await
153
                    .ok();
154
                continue;
155
            }
156
        };
157

            
158
        log::info!("Accepted connection from {}", unix_user.username);
159

            
160
        match handle_requests_for_single_session(conn, &unix_user, &config).await {
161
            Ok(()) => {}
162
            Err(e) => {
163
                log::error!("Failed to run server: {}", e);
164
            }
165
        }
166
    }
167

            
168
    Ok(())
169
}
170

            
171
async fn close_or_ignore_db_connection(db_connection: MySqlConnection) {
172
    if let Err(e) = db_connection.close().await {
173
        log::error!("Failed to close database connection: {}", e);
174
        log::error!("{}", e);
175
        log::error!("Ignoring...");
176
    }
177
}
178

            
179
pub async fn handle_requests_for_single_session(
180
    socket: TokioUnixStream,
181
    unix_user: &UnixUser,
182
    config: &ServerConfig,
183
) -> anyhow::Result<()> {
184
    let mut message_stream = create_server_to_client_message_stream(socket);
185

            
186
    log::debug!("Opening connection to database");
187

            
188
    let mut db_connection = match create_mysql_connection_from_config(&config.mysql).await {
189
        Ok(connection) => connection,
190
        Err(err) => {
191
            message_stream
192
                .send(Response::Error(
193
                    (concatdoc! {
194
                        "Server failed to connect to database\n",
195
                        "Please check the server logs or contact the system administrators"
196
                    })
197
                    .to_string(),
198
                ))
199
                .await?;
200
            message_stream.flush().await?;
201
            return Err(err);
202
        }
203
    };
204

            
205
    log::debug!("Verifying that database connection is valid");
206

            
207
    if let Err(e) = db_connection.ping().await {
208
        log::error!("Failed to ping database: {}", e);
209
        message_stream
210
            .send(Response::Error(
211
                (concatdoc! {
212
                    "Server failed to connect to database\n",
213
                    "Please check the server logs or contact the system administrators"
214
                })
215
                .to_string(),
216
            ))
217
            .await?;
218
        message_stream.flush().await?;
219
        close_or_ignore_db_connection(db_connection).await;
220
        return Err(e.into());
221
    }
222

            
223
    log::debug!("Successfully connected to database");
224

            
225
    let result = handle_requests_for_single_session_with_db_connection(
226
        message_stream,
227
        unix_user,
228
        &mut db_connection,
229
    )
230
    .await;
231

            
232
    close_or_ignore_db_connection(db_connection).await;
233

            
234
    result
235
}
236

            
237
// TODO: callers of this function own the lifecycle of `db_connection`
//       (e.g. closing it once the session ends); make sure every caller handles it properly.
239

            
240
async fn handle_requests_for_single_session_with_db_connection(
241
    mut stream: ServerToClientMessageStream,
242
    unix_user: &UnixUser,
243
    db_connection: &mut MySqlConnection,
244
) -> anyhow::Result<()> {
245
    stream.send(Response::Ready).await?;
246
    loop {
247
        // TODO: better error handling
248
        let request = match stream.next().await {
249
            Some(Ok(request)) => request,
250
            Some(Err(e)) => return Err(e.into()),
251
            None => {
252
                log::warn!("Client disconnected without sending an exit message");
253
                break;
254
            }
255
        };
256

            
257
        // TODO: don't clone the request
258
        let request_to_display = match &request {
259
            Request::PasswdUser((db_user, _)) => {
260
                Request::PasswdUser((db_user.to_owned(), "<REDACTED>".to_string()))
261
            }
262
            request => request.to_owned(),
263
        };
264
        log::info!("Received request: {:#?}", request_to_display);
265

            
266
        let response = match request {
267
            Request::CreateDatabases(databases_names) => {
268
                let result = create_databases(databases_names, unix_user, db_connection).await;
269
                Response::CreateDatabases(result)
270
            }
271
            Request::DropDatabases(databases_names) => {
272
                let result = drop_databases(databases_names, unix_user, db_connection).await;
273
                Response::DropDatabases(result)
274
            }
275
            Request::ListDatabases(database_names) => match database_names {
276
                Some(database_names) => {
277
                    let result = list_databases(database_names, unix_user, db_connection).await;
278
                    Response::ListDatabases(result)
279
                }
280
                None => {
281
                    let result = list_all_databases_for_user(unix_user, db_connection).await;
282
                    Response::ListAllDatabases(result)
283
                }
284
            },
285
            Request::ListPrivileges(database_names) => match database_names {
286
                Some(database_names) => {
287
                    let privilege_data =
288
                        get_databases_privilege_data(database_names, unix_user, db_connection)
289
                            .await;
290
                    Response::ListPrivileges(privilege_data)
291
                }
292
                None => {
293
                    let privilege_data =
294
                        get_all_database_privileges(unix_user, db_connection).await;
295
                    Response::ListAllPrivileges(privilege_data)
296
                }
297
            },
298
            Request::ModifyPrivileges(database_privilege_diffs) => {
299
                let result = apply_privilege_diffs(
300
                    BTreeSet::from_iter(database_privilege_diffs),
301
                    unix_user,
302
                    db_connection,
303
                )
304
                .await;
305
                Response::ModifyPrivileges(result)
306
            }
307
            Request::CreateUsers(db_users) => {
308
                let result = create_database_users(db_users, unix_user, db_connection).await;
309
                Response::CreateUsers(result)
310
            }
311
            Request::DropUsers(db_users) => {
312
                let result = drop_database_users(db_users, unix_user, db_connection).await;
313
                Response::DropUsers(result)
314
            }
315
            Request::PasswdUser((db_user, password)) => {
316
                let result =
317
                    set_password_for_database_user(&db_user, &password, unix_user, db_connection)
318
                        .await;
319
                Response::SetUserPassword(result)
320
            }
321
            Request::ListUsers(db_users) => match db_users {
322
                Some(db_users) => {
323
                    let result = list_database_users(db_users, unix_user, db_connection).await;
324
                    Response::ListUsers(result)
325
                }
326
                None => {
327
                    let result =
328
                        list_all_database_users_for_unix_user(unix_user, db_connection).await;
329
                    Response::ListAllUsers(result)
330
                }
331
            },
332
            Request::LockUsers(db_users) => {
333
                let result = lock_database_users(db_users, unix_user, db_connection).await;
334
                Response::LockUsers(result)
335
            }
336
            Request::UnlockUsers(db_users) => {
337
                let result = unlock_database_users(db_users, unix_user, db_connection).await;
338
                Response::UnlockUsers(result)
339
            }
340
            Request::Exit => {
341
                break;
342
            }
343
        };
344

            
345
        // TODO: don't clone the response
346
        let response_to_display = match &response {
347
            Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
348
                Response::SetUserPassword(Err(SetPasswordError::MySqlError(
349
                    "<REDACTED>".to_string(),
350
                )))
351
            }
352
            response => response.to_owned(),
353
        };
354
        log::info!("Response: {:#?}", response_to_display);
355

            
356
        stream.send(response).await?;
357
        stream.flush().await?;
358
        log::debug!("Successfully processed request");
359
    }
360

            
361
    Ok(())
362
}