1
use std::collections::BTreeSet;
2

            
3
use futures_util::{SinkExt, StreamExt};
4
use indoc::concatdoc;
5
use sqlx::{MySql, MySqlConnection, MySqlPool, pool::PoolConnection};
6
use tokio::net::UnixStream;
7

            
8
use crate::{
9
    core::{
10
        common::UnixUser,
11
        protocol::{
12
            Request, Response, ServerToClientMessageStream, SetPasswordError,
13
            create_server_to_client_message_stream,
14
        },
15
    },
16
    server::sql::{
17
        database_operations::{
18
            create_databases, drop_databases, list_all_databases_for_user, list_databases,
19
        },
20
        database_privilege_operations::{
21
            apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
22
        },
23
        user_operations::{
24
            create_database_users, drop_database_users, list_all_database_users_for_unix_user,
25
            list_database_users, lock_database_users, set_password_for_database_user,
26
            unlock_database_users,
27
        },
28
    },
29
};
30

            
31
// TODO: don't use database connection unless necessary.
32

            
33
pub async fn session_handler(
34
    socket: UnixStream,
35
    unix_user: &UnixUser,
36
    db_pool: MySqlPool,
37
) -> anyhow::Result<()> {
38
    let mut message_stream = create_server_to_client_message_stream(socket);
39

            
40
    log::debug!("Opening connection to database");
41

            
42
    let mut db_connection = match db_pool.acquire().await {
43
        Ok(connection) => connection,
44
        Err(err) => {
45
            message_stream
46
                .send(Response::Error(
47
                    (concatdoc! {
48
                        "Server failed to connect to database\n",
49
                        "Please check the server logs or contact the system administrators"
50
                    })
51
                    .to_string(),
52
                ))
53
                .await?;
54
            message_stream.flush().await?;
55
            return Err(err.into());
56
        }
57
    };
58

            
59
    log::debug!("Successfully connected to database");
60

            
61
    let result =
62
        session_handler_with_db_connection(message_stream, unix_user, &mut db_connection).await;
63

            
64
    close_or_ignore_db_connection(db_connection).await;
65

            
66
    result
67
}
68

            
69
// TODO: ensure proper db_connection hygiene for functions that invoke
70
//       this function
71

            
72
async fn session_handler_with_db_connection(
73
    mut stream: ServerToClientMessageStream,
74
    unix_user: &UnixUser,
75
    db_connection: &mut MySqlConnection,
76
) -> anyhow::Result<()> {
77
    stream.send(Response::Ready).await?;
78
    loop {
79
        // TODO: better error handling
80
        let request = match stream.next().await {
81
            Some(Ok(request)) => request,
82
            Some(Err(e)) => return Err(e.into()),
83
            None => {
84
                log::warn!("Client disconnected without sending an exit message");
85
                break;
86
            }
87
        };
88

            
89
        // TODO: don't clone the request
90
        let request_to_display = match &request {
91
            Request::PasswdUser((db_user, _)) => {
92
                Request::PasswdUser((db_user.to_owned(), "<REDACTED>".to_string()))
93
            }
94
            request => request.to_owned(),
95
        };
96
        log::info!("Received request: {:#?}", request_to_display);
97

            
98
        let response = match request {
99
            Request::CreateDatabases(databases_names) => {
100
                let result = create_databases(databases_names, unix_user, db_connection).await;
101
                Response::CreateDatabases(result)
102
            }
103
            Request::DropDatabases(databases_names) => {
104
                let result = drop_databases(databases_names, unix_user, db_connection).await;
105
                Response::DropDatabases(result)
106
            }
107
            Request::ListDatabases(database_names) => match database_names {
108
                Some(database_names) => {
109
                    let result = list_databases(database_names, unix_user, db_connection).await;
110
                    Response::ListDatabases(result)
111
                }
112
                None => {
113
                    let result = list_all_databases_for_user(unix_user, db_connection).await;
114
                    Response::ListAllDatabases(result)
115
                }
116
            },
117
            Request::ListPrivileges(database_names) => match database_names {
118
                Some(database_names) => {
119
                    let privilege_data =
120
                        get_databases_privilege_data(database_names, unix_user, db_connection)
121
                            .await;
122
                    Response::ListPrivileges(privilege_data)
123
                }
124
                None => {
125
                    let privilege_data =
126
                        get_all_database_privileges(unix_user, db_connection).await;
127
                    Response::ListAllPrivileges(privilege_data)
128
                }
129
            },
130
            Request::ModifyPrivileges(database_privilege_diffs) => {
131
                let result = apply_privilege_diffs(
132
                    BTreeSet::from_iter(database_privilege_diffs),
133
                    unix_user,
134
                    db_connection,
135
                )
136
                .await;
137
                Response::ModifyPrivileges(result)
138
            }
139
            Request::CreateUsers(db_users) => {
140
                let result = create_database_users(db_users, unix_user, db_connection).await;
141
                Response::CreateUsers(result)
142
            }
143
            Request::DropUsers(db_users) => {
144
                let result = drop_database_users(db_users, unix_user, db_connection).await;
145
                Response::DropUsers(result)
146
            }
147
            Request::PasswdUser((db_user, password)) => {
148
                let result =
149
                    set_password_for_database_user(&db_user, &password, unix_user, db_connection)
150
                        .await;
151
                Response::SetUserPassword(result)
152
            }
153
            Request::ListUsers(db_users) => match db_users {
154
                Some(db_users) => {
155
                    let result = list_database_users(db_users, unix_user, db_connection).await;
156
                    Response::ListUsers(result)
157
                }
158
                None => {
159
                    let result =
160
                        list_all_database_users_for_unix_user(unix_user, db_connection).await;
161
                    Response::ListAllUsers(result)
162
                }
163
            },
164
            Request::LockUsers(db_users) => {
165
                let result = lock_database_users(db_users, unix_user, db_connection).await;
166
                Response::LockUsers(result)
167
            }
168
            Request::UnlockUsers(db_users) => {
169
                let result = unlock_database_users(db_users, unix_user, db_connection).await;
170
                Response::UnlockUsers(result)
171
            }
172
            Request::Exit => {
173
                break;
174
            }
175
        };
176

            
177
        // TODO: don't clone the response
178
        let response_to_display = match &response {
179
            Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
180
                Response::SetUserPassword(Err(SetPasswordError::MySqlError(
181
                    "<REDACTED>".to_string(),
182
                )))
183
            }
184
            response => response.to_owned(),
185
        };
186
        log::info!("Response: {:#?}", response_to_display);
187

            
188
        stream.send(response).await?;
189
        stream.flush().await?;
190
        log::debug!("Successfully processed request");
191
    }
192

            
193
    Ok(())
194
}
195

            
196
async fn close_or_ignore_db_connection(db_connection: PoolConnection<MySql>) {
197
    if let Err(e) = db_connection.close().await {
198
        log::error!("Failed to close database connection: {}", e);
199
        log::error!("{}", e);
200
        log::error!("Ignoring...");
201
    }
202
}