1
use std::sync::Arc;
2

            
3
use futures_util::{SinkExt, StreamExt};
4
use indoc::concatdoc;
5
use itertools::Itertools;
6
use sqlx::{MySqlConnection, MySqlPool};
7
use tokio::{net::UnixStream, sync::RwLock};
8
use tracing::Instrument;
9

            
10
use crate::{
11
    core::{
12
        common::UnixUser,
13
        protocol::{
14
            Request, Response, ServerToClientMessageStream, SetPasswordError,
15
            create_server_to_client_message_stream, request_validation::GroupDenylist,
16
        },
17
    },
18
    server::{
19
        authorization::check_authorization,
20
        common::get_user_filtered_groups,
21
        sql::{
22
            database_operations::{
23
                complete_database_name, create_databases, drop_databases,
24
                list_all_databases_for_user, list_databases,
25
            },
26
            database_privilege_operations::{
27
                apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
28
            },
29
            user_operations::{
30
                complete_user_name, create_database_users, drop_database_users,
31
                list_all_database_users_for_unix_user, list_database_users, lock_database_users,
32
                set_password_for_database_user, unlock_database_users,
33
            },
34
        },
35
    },
36
};
37

            
38
/// Opaque numeric identifier attached to one client session.
///
/// It is used to correlate log lines and tracing spans that belong to the same connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId(u64);

impl SessionId {
    /// Wrap a raw numeric identifier.
    pub fn new(id: u64) -> Self {
        Self(id)
    }

    /// Return the raw numeric identifier wrapped by this value.
    pub fn inner(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
50

            
51
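// TODO: only acquire a database connection when a request actually needs one
//       (e.g. `Exit`, `CheckAuthorization` and `ListValidNamePrefixes` never use it).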
52

            
53
pub async fn session_handler(
54
    socket: UnixStream,
55
    session_id: SessionId,
56
    db_pool: Arc<RwLock<MySqlPool>>,
57
    db_is_mariadb: bool,
58
    group_denylist: &GroupDenylist,
59
) -> anyhow::Result<()> {
60
    let uid = match socket.peer_cred() {
61
        Ok(cred) => cred.uid(),
62
        Err(e) => {
63
            tracing::error!("Failed to get peer credentials from socket: {}", e);
64
            let mut message_stream = create_server_to_client_message_stream(socket);
65
            message_stream
66
                .send(Response::Error(
67
                    (concatdoc! {
68
                        "Server failed to get peer credentials from socket\n",
69
                        "Please check the server logs or contact the system administrators"
70
                    })
71
                    .to_string(),
72
                ))
73
                .await
74
                .ok();
75
            anyhow::bail!("Failed to get peer credentials from socket");
76
        }
77
    };
78

            
79
    tracing::trace!("Validated peer UID: {}", uid);
80

            
81
    let unix_user = match UnixUser::from_uid(uid) {
82
        Ok(user) => user,
83
        Err(e) => {
84
            tracing::error!("Failed to get username from uid: {}", e);
85
            let mut message_stream = create_server_to_client_message_stream(socket);
86
            message_stream
87
                .send(Response::Error(
88
                    (concatdoc! {
89
                        "Server failed to get user data from the system\n",
90
                        "Please check the server logs or contact the system administrators"
91
                    })
92
                    .to_string(),
93
                ))
94
                .await
95
                .ok();
96
            anyhow::bail!("Failed to get username from uid: {e}");
97
        }
98
    };
99

            
100
    let span = tracing::info_span!(
101
        "user_session",
102
        session_id = session_id.inner(),
103
        user = %unix_user,
104
    );
105

            
106
    (async move {
107
        tracing::debug!("Accepted connection from user: {}", unix_user);
108

            
109
        let result = session_handler_with_unix_user(
110
            socket,
111
            session_id,
112
            &unix_user,
113
            db_pool,
114
            db_is_mariadb,
115
            group_denylist,
116
        )
117
        .await;
118

            
119
        tracing::debug!(
120
            "Finished handling requests for connection from user: {}",
121
            unix_user,
122
        );
123

            
124
        result
125
    })
126
    .instrument(span)
127
    .await
128
}
129

            
130
pub async fn session_handler_with_unix_user(
131
    socket: UnixStream,
132
    session_id: SessionId,
133
    unix_user: &UnixUser,
134
    db_pool: Arc<RwLock<MySqlPool>>,
135
    db_is_mariadb: bool,
136
    group_denylist: &GroupDenylist,
137
) -> anyhow::Result<()> {
138
    let mut message_stream = create_server_to_client_message_stream(socket);
139

            
140
    tracing::trace!("Requesting database connection from pool");
141
    let mut db_connection = match db_pool.read().await.acquire().await {
142
        Ok(connection) => connection,
143
        Err(err) => {
144
            message_stream
145
                .send(Response::Error(
146
                    (concatdoc! {
147
                        "Server failed to connect to database\n",
148
                        "Please check the server logs or contact the system administrators"
149
                    })
150
                    .to_string(),
151
                ))
152
                .await?;
153
            message_stream.flush().await?;
154
            return Err(err.into());
155
        }
156
    };
157
    tracing::trace!("Successfully acquired database connection from pool");
158

            
159
    let result = session_handler_with_db_connection(
160
        message_stream,
161
        session_id,
162
        unix_user,
163
        &mut db_connection,
164
        db_is_mariadb,
165
        group_denylist,
166
    )
167
    .await;
168

            
169
    tracing::trace!("Releasing database connection back to pool");
170

            
171
    result
172
}
173

            
174
// TODO: ensure proper db_connection hygiene for functions that invoke
175
//       this function
176

            
177
async fn session_handler_with_db_connection(
178
    mut stream: ServerToClientMessageStream,
179
    session_id: SessionId,
180
    unix_user: &UnixUser,
181
    db_connection: &mut MySqlConnection,
182
    db_is_mariadb: bool,
183
    group_denylist: &GroupDenylist,
184
) -> anyhow::Result<()> {
185
    stream.send(Response::Ready).await?;
186
    loop {
187
        // TODO: better error handling
188
        // TODO: timeout for receiving requests
189
        // TODO: cancel on request by supervisor
190
        let request = match stream.next().await {
191
            Some(Ok(request)) => request,
192
            Some(Err(e)) => return Err(e.into()),
193
            None => {
194
                tracing::warn!("Client disconnected without sending an exit message");
195
                break;
196
            }
197
        };
198

            
199
        let request_span = tracing::info_span!("request", command = request.command_name());
200

            
201
        if !handle_request(
202
            request,
203
            session_id,
204
            unix_user,
205
            db_connection,
206
            db_is_mariadb,
207
            group_denylist,
208
            &mut stream,
209
        )
210
        .instrument(request_span)
211
        .await?
212
        {
213
            break;
214
        }
215
    }
216

            
217
    Ok(())
218
}
219

            
220
/// Handle a single request from a client.
221
///
222
/// If the function returns `true`, the session should continue.
223
async fn handle_request(
224
    request: Request,
225
    session_id: SessionId,
226
    unix_user: &UnixUser,
227
    db_connection: &mut MySqlConnection,
228
    db_is_mariadb: bool,
229
    group_denylist: &GroupDenylist,
230
    stream: &mut ServerToClientMessageStream,
231
) -> anyhow::Result<bool> {
232
    match &request {
233
        Request::Exit => tracing::debug!("Request: exit"),
234
        Request::PasswdUser((db_user, _)) => tracing::debug!(
235
            "Request:\n{}",
236
            serde_json::to_string_pretty(&Request::PasswdUser((
237
                db_user.to_owned(),
238
                "<REDACTED>".to_string()
239
            )))?
240
        ),
241
        request => tracing::debug!("Request:\n{}", serde_json::to_string_pretty(request)?),
242
    }
243

            
244
    let affected_dbs = request.affected_databases();
245
    if !affected_dbs.is_empty() {
246
        tracing::trace!(
247
            "Affected databases: {}",
248
            affected_dbs.into_iter().map(|db| db.to_string()).join(", ")
249
        );
250
    }
251

            
252
    let affected_users = request.affected_users();
253
    if !affected_users.is_empty() {
254
        tracing::trace!(
255
            "Affected users: {}",
256
            affected_users.into_iter().map(|u| u.to_string()).join(", "),
257
        );
258
    }
259

            
260
    let response = match request {
261
        Request::CheckAuthorization(ref dbs_or_users) => {
262
            let result = check_authorization(dbs_or_users, unix_user, group_denylist).await;
263
            Response::CheckAuthorization(result)
264
        }
265
        Request::ListValidNamePrefixes => {
266
            let mut result = Vec::with_capacity(unix_user.groups.len() + 1);
267
            result.push(unix_user.username.clone());
268

            
269
            for group in get_user_filtered_groups(unix_user, group_denylist) {
270
                result.push(group.clone());
271
            }
272

            
273
            Response::ListValidNamePrefixes(result)
274
        }
275
        Request::CompleteDatabaseName(ref partial_database_name) => {
276
            // TODO: more correct validation here
277
            if partial_database_name
278
                .chars()
279
                .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
280
            {
281
                let result = complete_database_name(
282
                    partial_database_name,
283
                    unix_user,
284
                    db_connection,
285
                    db_is_mariadb,
286
                    group_denylist,
287
                )
288
                .await;
289
                Response::CompleteDatabaseName(result)
290
            } else {
291
                Response::CompleteDatabaseName(vec![])
292
            }
293
        }
294
        Request::CompleteUserName(ref partial_user_name) => {
295
            // TODO: more correct validation here
296
            if partial_user_name
297
                .chars()
298
                .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
299
            {
300
                let result = complete_user_name(
301
                    partial_user_name,
302
                    unix_user,
303
                    db_connection,
304
                    db_is_mariadb,
305
                    group_denylist,
306
                )
307
                .await;
308
                Response::CompleteUserName(result)
309
            } else {
310
                Response::CompleteUserName(vec![])
311
            }
312
        }
313
        Request::CreateDatabases(ref databases_names) => {
314
            let result = create_databases(
315
                databases_names,
316
                unix_user,
317
                db_connection,
318
                db_is_mariadb,
319
                group_denylist,
320
            )
321
            .await;
322
            Response::CreateDatabases(result)
323
        }
324
        Request::DropDatabases(ref databases_names) => {
325
            let result = drop_databases(
326
                databases_names,
327
                unix_user,
328
                db_connection,
329
                db_is_mariadb,
330
                group_denylist,
331
            )
332
            .await;
333
            Response::DropDatabases(result)
334
        }
335
        Request::ListDatabases(ref database_names) => {
336
            if let Some(database_names) = database_names {
337
                let result = list_databases(
338
                    database_names,
339
                    unix_user,
340
                    db_connection,
341
                    db_is_mariadb,
342
                    group_denylist,
343
                )
344
                .await;
345
                Response::ListDatabases(result)
346
            } else {
347
                let result = list_all_databases_for_user(
348
                    unix_user,
349
                    db_connection,
350
                    db_is_mariadb,
351
                    group_denylist,
352
                )
353
                .await;
354
                Response::ListAllDatabases(result)
355
            }
356
        }
357
        Request::ListPrivileges(ref database_names) => {
358
            if let Some(database_names) = database_names {
359
                let privilege_data = get_databases_privilege_data(
360
                    database_names,
361
                    unix_user,
362
                    db_connection,
363
                    db_is_mariadb,
364
                    group_denylist,
365
                )
366
                .await;
367
                Response::ListPrivileges(privilege_data)
368
            } else {
369
                let privilege_data = get_all_database_privileges(
370
                    unix_user,
371
                    db_connection,
372
                    db_is_mariadb,
373
                    group_denylist,
374
                )
375
                .await;
376
                Response::ListAllPrivileges(privilege_data)
377
            }
378
        }
379
        Request::ModifyPrivileges(ref database_privilege_diffs) => {
380
            let result = apply_privilege_diffs(
381
                database_privilege_diffs,
382
                unix_user,
383
                db_connection,
384
                db_is_mariadb,
385
                group_denylist,
386
            )
387
            .await;
388
            Response::ModifyPrivileges(result)
389
        }
390
        Request::CreateUsers(ref db_users) => {
391
            let result = create_database_users(
392
                db_users,
393
                unix_user,
394
                db_connection,
395
                db_is_mariadb,
396
                group_denylist,
397
            )
398
            .await;
399
            Response::CreateUsers(result)
400
        }
401
        Request::DropUsers(ref db_users) => {
402
            let result = drop_database_users(
403
                db_users,
404
                unix_user,
405
                db_connection,
406
                db_is_mariadb,
407
                group_denylist,
408
            )
409
            .await;
410
            Response::DropUsers(result)
411
        }
412
        Request::PasswdUser((ref db_user, ref password)) => {
413
            let result = set_password_for_database_user(
414
                db_user,
415
                password,
416
                unix_user,
417
                db_connection,
418
                db_is_mariadb,
419
                group_denylist,
420
            )
421
            .await;
422
            Response::SetUserPassword(result)
423
        }
424
        Request::ListUsers(ref db_users) => {
425
            if let Some(db_users) = db_users {
426
                let result = list_database_users(
427
                    db_users,
428
                    unix_user,
429
                    db_connection,
430
                    db_is_mariadb,
431
                    group_denylist,
432
                )
433
                .await;
434
                Response::ListUsers(result)
435
            } else {
436
                let result = list_all_database_users_for_unix_user(
437
                    unix_user,
438
                    db_connection,
439
                    db_is_mariadb,
440
                    group_denylist,
441
                )
442
                .await;
443
                Response::ListAllUsers(result)
444
            }
445
        }
446
        Request::LockUsers(ref db_users) => {
447
            let result = lock_database_users(
448
                db_users,
449
                unix_user,
450
                db_connection,
451
                db_is_mariadb,
452
                group_denylist,
453
            )
454
            .await;
455
            Response::LockUsers(result)
456
        }
457
        Request::UnlockUsers(ref db_users) => {
458
            let result = unlock_database_users(
459
                db_users,
460
                unix_user,
461
                db_connection,
462
                db_is_mariadb,
463
                group_denylist,
464
            )
465
            .await;
466
            Response::UnlockUsers(result)
467
        }
468
        Request::Exit => {
469
            return Ok(false);
470
        }
471
    };
472

            
473
    let response_to_display = match &response {
474
        Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
475
            &Response::SetUserPassword(Err(SetPasswordError::MySqlError("<REDACTED>".to_string())))
476
        }
477
        response => response,
478
    };
479
    tracing::debug!(
480
        "Response:\n{}",
481
        serde_json::to_string_pretty(&response_to_display)?
482
    );
483

            
484
    log_request(session_id, unix_user, &request, &response);
485

            
486
    stream.send(response).await?;
487
    stream.flush().await?;
488
    tracing::trace!("Successfully processed request");
489

            
490
    Ok(true)
491
}
492

            
493
/// Log a summary of the request and its result.
494
fn log_request(
495
    session_id: SessionId,
496
    unix_user: &UnixUser,
497
    request: &Request,
498
    response: &Response,
499
) {
500
    tracing::info!(
501
        "[{}|session:{}|user:{unix_user}] {}",
502
        response.ok_status(),
503
        session_id.inner(),
504
        request.log_summary(),
505
    );
506
}