1
mod check_authorization;
2
mod complete_database_name;
3
mod complete_user_name;
4
mod create_databases;
5
mod create_users;
6
mod drop_databases;
7
mod drop_users;
8
mod list_all_databases;
9
mod list_all_privileges;
10
mod list_all_users;
11
mod list_databases;
12
mod list_privileges;
13
mod list_users;
14
mod list_valid_name_prefixes;
15
mod lock_users;
16
mod modify_privileges;
17
mod passwd_user;
18
mod unlock_users;
19

            
20
pub use check_authorization::*;
21
pub use complete_database_name::*;
22
pub use complete_user_name::*;
23
pub use create_databases::*;
24
pub use create_users::*;
25
pub use drop_databases::*;
26
pub use drop_users::*;
27
pub use list_all_databases::*;
28
pub use list_all_privileges::*;
29
pub use list_all_users::*;
30
pub use list_databases::*;
31
pub use list_privileges::*;
32
pub use list_users::*;
33
pub use list_valid_name_prefixes::*;
34
pub use lock_users::*;
35
pub use modify_privileges::*;
36
pub use passwd_user::*;
37
pub use unlock_users::*;
38

            
39
use std::collections::BTreeSet;
40
use std::fmt;
41

            
42
use serde::{Deserialize, Serialize};
43
use tokio::net::UnixStream;
44
use tokio_serde::{Framed as SerdeFramed, formats::Bincode};
45
use tokio_util::codec::{Framed, LengthDelimitedCodec};
46

            
47
use crate::core::types::{MySQLDatabase, MySQLUser};
48

            
49
pub type ServerToClientMessageStream = SerdeFramed<
50
    Framed<UnixStream, LengthDelimitedCodec>,
51
    Request,
52
    Response,
53
    Bincode<Request, Response>,
54
>;
55

            
56
pub type ClientToServerMessageStream = SerdeFramed<
57
    Framed<UnixStream, LengthDelimitedCodec>,
58
    Response,
59
    Request,
60
    Bincode<Response, Request>,
61
>;
62

            
63
/// Intended upper bound, in bytes, for one encoded [`Request`] frame (100 KiB).
const MAX_REQUEST_FRAME_LENGTH: usize = 100 << 10;
/// Intended upper bound, in bytes, for one encoded [`Response`] frame (1 MiB).
const MAX_RESPONSE_FRAME_LENGTH: usize = 1 << 20;
65

            
66
pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToServerMessageStream {
67
    let codec = {
68
        let mut codec = LengthDelimitedCodec::new();
69
        codec.set_max_frame_length(MAX_REQUEST_FRAME_LENGTH);
70
        codec
71
    };
72
    let length_delimited = Framed::new(socket, codec);
73
    tokio_serde::Framed::new(length_delimited, Bincode::default())
74
}
75

            
76
pub fn create_server_to_client_message_stream(socket: UnixStream) -> ServerToClientMessageStream {
77
    let codec = {
78
        let mut codec = LengthDelimitedCodec::new();
79
        codec.set_max_frame_length(MAX_RESPONSE_FRAME_LENGTH);
80
        codec
81
    };
82
    let length_delimited = Framed::new(socket, codec);
83
    tokio_serde::Framed::new(length_delimited, Bincode::default())
84
}
85

            
86
#[non_exhaustive]
87
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
88
pub enum Request {
89
    CheckAuthorization(CheckAuthorizationRequest),
90

            
91
    ListValidNamePrefixes,
92
    CompleteDatabaseName(CompleteDatabaseNameRequest),
93
    CompleteUserName(CompleteUserNameRequest),
94

            
95
    CreateDatabases(CreateDatabasesRequest),
96
    DropDatabases(DropDatabasesRequest),
97
    ListDatabases(ListDatabasesRequest),
98
    ListPrivileges(ListPrivilegesRequest),
99
    ModifyPrivileges(ModifyPrivilegesRequest),
100

            
101
    CreateUsers(CreateUsersRequest),
102
    DropUsers(DropUsersRequest),
103
    PasswdUser(SetUserPasswordRequest),
104
    ListUsers(ListUsersRequest),
105
    LockUsers(LockUsersRequest),
106
    UnlockUsers(UnlockUsersRequest),
107

            
108
    // Commit,
109
    Exit,
110
}
111

            
112
impl Request {
113
    /// Get the command name associated with this request.
114
    pub fn command_name(&self) -> &str {
115
        match self {
116
            Request::CheckAuthorization(_) => "check-authorization",
117
            Request::ListValidNamePrefixes => "list-valid-name-prefixes",
118
            Request::CompleteDatabaseName(_) => "complete-database-name",
119
            Request::CompleteUserName(_) => "complete-user-name",
120
            Request::CreateDatabases(_) => "create-databases",
121
            Request::DropDatabases(_) => "drop-databases",
122
            Request::ListDatabases(_) => "list-databases",
123
            Request::ListPrivileges(_) => "list-privileges",
124
            Request::ModifyPrivileges(_) => "modify-privileges",
125
            Request::CreateUsers(_) => "create-users",
126
            Request::DropUsers(_) => "drop-users",
127
            Request::PasswdUser(_) => "passwd-user",
128
            Request::ListUsers(_) => "list-users",
129
            Request::LockUsers(_) => "lock-users",
130
            Request::UnlockUsers(_) => "unlock-users",
131
            Request::Exit => "exit",
132
        }
133
    }
134

            
135
    /// Generate a short summary string representing this request for logging purposes.
136
    pub fn log_summary(&self) -> String {
137
        match self {
138
            Request::CheckAuthorization(req) => format!("{}({})", self.command_name(), req.len()),
139

            
140
            Request::CreateDatabases(req) => format!("{}({})", self.command_name(), req.len()),
141
            Request::DropDatabases(req) => format!("{}({})", self.command_name(), req.len()),
142
            Request::ListDatabases(req) => format!(
143
                "{}{}",
144
                self.command_name(),
145
                req.as_ref()
146
                    .map_or("".to_string(), |r| format!("({})", r.len()))
147
            ),
148
            Request::ListPrivileges(req) => format!(
149
                "{}{}",
150
                self.command_name(),
151
                req.as_ref()
152
                    .map_or("".to_string(), |r| format!("({})", r.len()))
153
            ),
154
            Request::ModifyPrivileges(req) => format!("{}({})", self.command_name(), req.len()),
155

            
156
            Request::CreateUsers(req) => format!("{}({})", self.command_name(), req.len()),
157
            Request::DropUsers(req) => format!("{}({})", self.command_name(), req.len()),
158
            Request::ListUsers(req) => format!(
159
                "{}{}",
160
                self.command_name(),
161
                req.as_ref()
162
                    .map_or("".to_string(), |r| format!("({})", r.len()))
163
            ),
164
            Request::LockUsers(req) => format!("{}({})", self.command_name(), req.len()),
165
            Request::UnlockUsers(req) => format!("{}({})", self.command_name(), req.len()),
166

            
167
            _ => self.command_name().to_string(),
168
        }
169
    }
170

            
171
    /// Get the set of users affected by this request.
172
    pub fn affected_users(&self) -> BTreeSet<MySQLUser> {
173
        match self {
174
            Request::CheckAuthorization(_) => Default::default(),
175
            Request::ListValidNamePrefixes => Default::default(),
176
            Request::CompleteDatabaseName(_) => Default::default(),
177
            Request::CompleteUserName(_) => Default::default(),
178
            Request::CreateDatabases(_) => Default::default(),
179
            Request::DropDatabases(_) => Default::default(),
180
            Request::ListDatabases(_) => Default::default(),
181
            Request::ListPrivileges(_) => Default::default(),
182
            Request::ModifyPrivileges(priv_diffs) => priv_diffs
183
                .iter()
184
                .map(|priv_diff| priv_diff.get_user_name().clone())
185
                .collect(),
186
            Request::CreateUsers(users) => users.iter().cloned().collect(),
187
            Request::DropUsers(users) => users.iter().cloned().collect(),
188
            Request::PasswdUser(user_passwd_req) => {
189
                let mut result = BTreeSet::new();
190
                result.insert(user_passwd_req.0.clone());
191
                result
192
            }
193
            Request::ListUsers(users) => users.clone().unwrap_or_default().into_iter().collect(),
194
            Request::LockUsers(users) => users.iter().cloned().collect(),
195
            Request::UnlockUsers(users) => users.iter().cloned().collect(),
196
            Request::Exit => Default::default(),
197
        }
198
    }
199

            
200
    /// Get the set of databases affected by this request.
201
    pub fn affected_databases(&self) -> BTreeSet<MySQLDatabase> {
202
        match self {
203
            Request::CheckAuthorization(_) => Default::default(),
204
            Request::ListValidNamePrefixes => Default::default(),
205
            Request::CompleteDatabaseName(_) => Default::default(),
206
            Request::CompleteUserName(_) => Default::default(),
207
            Request::CreateDatabases(databases) => databases.iter().cloned().collect(),
208
            Request::DropDatabases(databases) => databases.iter().cloned().collect(),
209
            Request::ListDatabases(databases) => {
210
                databases.clone().unwrap_or_default().into_iter().collect()
211
            }
212
            Request::ListPrivileges(databases) => {
213
                databases.clone().unwrap_or_default().into_iter().collect()
214
            }
215
            Request::ModifyPrivileges(priv_diffs) => priv_diffs
216
                .iter()
217
                .map(|priv_diff| priv_diff.get_database_name().clone())
218
                .collect(),
219
            Request::CreateUsers(_) => Default::default(),
220
            Request::DropUsers(_) => Default::default(),
221
            Request::PasswdUser(_) => Default::default(),
222
            Request::ListUsers(_) => Default::default(),
223
            Request::LockUsers(_) => Default::default(),
224
            Request::UnlockUsers(_) => Default::default(),
225
            Request::Exit => Default::default(),
226
        }
227
    }
228
}
229

            
230
// TODO: include a generic "message" that will display a message to the user?
231

            
232
#[non_exhaustive]
233
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
234
pub enum Response {
235
    CheckAuthorization(CheckAuthorizationResponse),
236

            
237
    ListValidNamePrefixes(ListValidNamePrefixesResponse),
238
    CompleteDatabaseName(CompleteDatabaseNameResponse),
239
    CompleteUserName(CompleteUserNameResponse),
240

            
241
    // Specific data for specific commands
242
    CreateDatabases(CreateDatabasesResponse),
243
    DropDatabases(DropDatabasesResponse),
244
    ListDatabases(ListDatabasesResponse),
245
    ListAllDatabases(ListAllDatabasesResponse),
246
    ListPrivileges(ListPrivilegesResponse),
247
    ListAllPrivileges(ListAllPrivilegesResponse),
248
    ModifyPrivileges(ModifyPrivilegesResponse),
249

            
250
    CreateUsers(CreateUsersResponse),
251
    DropUsers(DropUsersResponse),
252
    SetUserPassword(SetUserPasswordResponse),
253
    ListUsers(ListUsersResponse),
254
    ListAllUsers(ListAllUsersResponse),
255
    LockUsers(LockUsersResponse),
256
    UnlockUsers(UnlockUsersResponse),
257

            
258
    // Generic responses
259
    Ready,
260
    Error(String),
261
}
262

            
263
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
264
pub enum ResponseOkStatus {
265
    Success,
266
    PartialSuccess(usize, usize), // succeeded, total
267
    Error,
268
}
269

            
270
impl ResponseOkStatus {
271
    pub fn from_counts(total: usize, succeeded: usize) -> Self {
272
        if succeeded == total {
273
            ResponseOkStatus::Success
274
        } else if succeeded == 0 {
275
            ResponseOkStatus::Error
276
        } else {
277
            ResponseOkStatus::PartialSuccess(succeeded, total)
278
        }
279
    }
280

            
281
    pub fn from_bool(is_ok: bool) -> Self {
282
        if is_ok {
283
            ResponseOkStatus::Success
284
        } else {
285
            ResponseOkStatus::Error
286
        }
287
    }
288
}
289

            
290
impl fmt::Display for ResponseOkStatus {
291
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292
        match self {
293
            ResponseOkStatus::Success => write!(f, "OK"),
294
            ResponseOkStatus::PartialSuccess(succeeded, total) => {
295
                write!(f, "PARTIAL_OK({}/{})", succeeded, total)
296
            }
297
            ResponseOkStatus::Error => write!(f, "ERR"),
298
        }
299
    }
300
}
301

            
302
impl Response {
303
    pub fn ok_status(&self) -> ResponseOkStatus {
304
        match self {
305
            Response::CheckAuthorization(res) => {
306
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
307
            }
308

            
309
            Response::ListValidNamePrefixes(_) => ResponseOkStatus::Success,
310
            Response::CompleteDatabaseName(_) => ResponseOkStatus::Success,
311
            Response::CompleteUserName(_) => ResponseOkStatus::Success,
312

            
313
            Response::CreateDatabases(res) => {
314
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
315
            }
316
            Response::DropDatabases(res) => {
317
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
318
            }
319
            Response::ListDatabases(res) => {
320
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
321
            }
322
            Response::ListAllDatabases(res) => ResponseOkStatus::from_bool(res.is_ok()),
323
            Response::ListPrivileges(res) => {
324
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
325
            }
326
            Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()),
327
            Response::ModifyPrivileges(res) => {
328
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
329
            }
330

            
331
            Response::CreateUsers(res) => {
332
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
333
            }
334
            Response::DropUsers(res) => {
335
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
336
            }
337
            Response::SetUserPassword(res) => ResponseOkStatus::from_bool(res.is_ok()),
338
            Response::ListUsers(res) => {
339
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
340
            }
341
            Response::ListAllUsers(res) => ResponseOkStatus::from_bool(res.is_ok()),
342
            Response::LockUsers(res) => {
343
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
344
            }
345
            Response::UnlockUsers(res) => {
346
                ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
347
            }
348

            
349
            Response::Ready => ResponseOkStatus::Success,
350
            Response::Error(_) => ResponseOkStatus::Error,
351
        }
352
    }
353
}