1use anyhow::Context;
2use clap::Parser;
3use dialoguer::{Confirm, Password};
4use futures_util::{SinkExt, StreamExt};
5
6use crate::core::protocol::{
7 ClientToServerMessageStream, ListUsersError, MySQLUser, Request, Response,
8 print_create_users_output_status, print_create_users_output_status_json,
9 print_drop_users_output_status, print_drop_users_output_status_json,
10 print_lock_users_output_status, print_lock_users_output_status_json,
11 print_set_password_output_status, print_unlock_users_output_status,
12 print_unlock_users_output_status_json,
13};
14
15use super::common::erroneous_server_response;
16
17#[derive(Parser, Debug, Clone)]
18pub struct UserArgs {
19 #[clap(subcommand)]
20 subcmd: UserCommand,
21}
22
23#[allow(clippy::enum_variant_names)]
24#[derive(Parser, Debug, Clone)]
25pub enum UserCommand {
26 #[command()]
28 CreateUser(UserCreateArgs),
29
30 #[command()]
32 DropUser(UserDeleteArgs),
33
34 #[command()]
36 PasswdUser(UserPasswdArgs),
37
38 #[command()]
42 ShowUser(UserShowArgs),
43
44 #[command()]
46 LockUser(UserLockArgs),
47
48 #[command()]
50 UnlockUser(UserUnlockArgs),
51}
52
53#[derive(Parser, Debug, Clone)]
54pub struct UserCreateArgs {
55 #[arg(num_args = 1..)]
56 username: Vec<MySQLUser>,
57
58 #[clap(long)]
60 no_password: bool,
61
62 #[arg(short, long)]
66 json: bool,
67}
68
69#[derive(Parser, Debug, Clone)]
70pub struct UserDeleteArgs {
71 #[arg(num_args = 1..)]
72 username: Vec<MySQLUser>,
73
74 #[arg(short, long)]
76 json: bool,
77}
78
79#[derive(Parser, Debug, Clone)]
80pub struct UserPasswdArgs {
81 username: MySQLUser,
82
83 #[clap(short, long)]
84 password_file: Option<String>,
85
86 #[arg(short, long)]
88 json: bool,
89}
90
91#[derive(Parser, Debug, Clone)]
92pub struct UserShowArgs {
93 #[arg(num_args = 0..)]
94 username: Vec<MySQLUser>,
95
96 #[arg(short, long)]
98 json: bool,
99}
100
101#[derive(Parser, Debug, Clone)]
102pub struct UserLockArgs {
103 #[arg(num_args = 1..)]
104 username: Vec<MySQLUser>,
105
106 #[arg(short, long)]
108 json: bool,
109}
110
111#[derive(Parser, Debug, Clone)]
112pub struct UserUnlockArgs {
113 #[arg(num_args = 1..)]
114 username: Vec<MySQLUser>,
115
116 #[arg(short, long)]
118 json: bool,
119}
120
121pub async fn handle_command(
122 command: UserCommand,
123 server_connection: ClientToServerMessageStream,
124) -> anyhow::Result<()> {
125 match command {
126 UserCommand::CreateUser(args) => create_users(args, server_connection).await,
127 UserCommand::DropUser(args) => drop_users(args, server_connection).await,
128 UserCommand::PasswdUser(args) => passwd_user(args, server_connection).await,
129 UserCommand::ShowUser(args) => show_users(args, server_connection).await,
130 UserCommand::LockUser(args) => lock_users(args, server_connection).await,
131 UserCommand::UnlockUser(args) => unlock_users(args, server_connection).await,
132 }
133}
134
135async fn create_users(
136 args: UserCreateArgs,
137 mut server_connection: ClientToServerMessageStream,
138) -> anyhow::Result<()> {
139 if args.username.is_empty() {
140 anyhow::bail!("No usernames provided");
141 }
142
143 let message = Request::CreateUsers(args.username.to_owned());
144 if let Err(err) = server_connection.send(message).await {
145 server_connection.close().await.ok();
146 anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server"));
147 }
148
149 let result = match server_connection.next().await {
150 Some(Ok(Response::CreateUsers(result))) => result,
151 response => return erroneous_server_response(response),
152 };
153
154 if args.json {
155 print_create_users_output_status_json(&result);
156 } else {
157 print_create_users_output_status(&result);
158
159 let successfully_created_users = result
160 .iter()
161 .filter_map(|(username, result)| result.as_ref().ok().map(|_| username))
162 .collect::<Vec<_>>();
163
164 for username in successfully_created_users {
165 if !args.no_password
166 && Confirm::new()
167 .with_prompt(format!(
168 "Do you want to set a password for user '{}'?",
169 username
170 ))
171 .default(false)
172 .interact()?
173 {
174 let password = read_password_from_stdin_with_double_check(username)?;
175 let message = Request::PasswdUser(username.to_owned(), password);
176
177 if let Err(err) = server_connection.send(message).await {
178 server_connection.close().await.ok();
179 anyhow::bail!(err);
180 }
181
182 match server_connection.next().await {
183 Some(Ok(Response::PasswdUser(result))) => {
184 print_set_password_output_status(&result, username)
185 }
186 response => return erroneous_server_response(response),
187 }
188
189 println!();
190 }
191 }
192 }
193
194 server_connection.send(Request::Exit).await?;
195
196 Ok(())
197}
198
199async fn drop_users(
200 args: UserDeleteArgs,
201 mut server_connection: ClientToServerMessageStream,
202) -> anyhow::Result<()> {
203 if args.username.is_empty() {
204 anyhow::bail!("No usernames provided");
205 }
206
207 let message = Request::DropUsers(args.username.to_owned());
208
209 if let Err(err) = server_connection.send(message).await {
210 server_connection.close().await.ok();
211 anyhow::bail!(err);
212 }
213
214 let result = match server_connection.next().await {
215 Some(Ok(Response::DropUsers(result))) => result,
216 response => return erroneous_server_response(response),
217 };
218
219 server_connection.send(Request::Exit).await?;
220
221 if args.json {
222 print_drop_users_output_status_json(&result);
223 } else {
224 print_drop_users_output_status(&result);
225 }
226
227 Ok(())
228}
229
230pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
231 Password::new()
232 .with_prompt(format!("New MySQL password for user '{}'", username))
233 .with_confirmation(
234 format!("Retype new MySQL password for user '{}'", username),
235 "Passwords do not match",
236 )
237 .interact()
238 .map_err(Into::into)
239}
240
241async fn passwd_user(
242 args: UserPasswdArgs,
243 mut server_connection: ClientToServerMessageStream,
244) -> anyhow::Result<()> {
245 let message = Request::ListUsers(Some(vec![args.username.to_owned()]));
247 if let Err(err) = server_connection.send(message).await {
248 server_connection.close().await.ok();
249 anyhow::bail!(err);
250 }
251 let response = match server_connection.next().await {
252 Some(Ok(Response::ListUsers(users))) => users,
253 response => return erroneous_server_response(response),
254 };
255 match response
256 .get(&args.username)
257 .unwrap_or(&Err(ListUsersError::UserDoesNotExist))
258 {
259 Ok(_) => {}
260 Err(err) => {
261 server_connection.send(Request::Exit).await?;
262 server_connection.close().await.ok();
263 anyhow::bail!("{}", err.to_error_message(&args.username));
264 }
265 }
266
267 let password = if let Some(password_file) = args.password_file {
268 std::fs::read_to_string(password_file)
269 .context("Failed to read password file")?
270 .trim()
271 .to_string()
272 } else {
273 read_password_from_stdin_with_double_check(&args.username)?
274 };
275
276 let message = Request::PasswdUser(args.username.to_owned(), password);
277
278 if let Err(err) = server_connection.send(message).await {
279 server_connection.close().await.ok();
280 anyhow::bail!(err);
281 }
282
283 let result = match server_connection.next().await {
284 Some(Ok(Response::PasswdUser(result))) => result,
285 response => return erroneous_server_response(response),
286 };
287
288 server_connection.send(Request::Exit).await?;
289
290 print_set_password_output_status(&result, &args.username);
291
292 Ok(())
293}
294
295async fn show_users(
296 args: UserShowArgs,
297 mut server_connection: ClientToServerMessageStream,
298) -> anyhow::Result<()> {
299 let message = if args.username.is_empty() {
300 Request::ListUsers(None)
301 } else {
302 Request::ListUsers(Some(args.username.to_owned()))
303 };
304
305 if let Err(err) = server_connection.send(message).await {
306 server_connection.close().await.ok();
307 anyhow::bail!(err);
308 }
309
310 let users = match server_connection.next().await {
311 Some(Ok(Response::ListUsers(users))) => users
312 .into_iter()
313 .filter_map(|(username, result)| match result {
314 Ok(user) => Some(user),
315 Err(err) => {
316 eprintln!("{}", err.to_error_message(&username));
317 eprintln!("Skipping...");
318 None
319 }
320 })
321 .collect::<Vec<_>>(),
322 Some(Ok(Response::ListAllUsers(users))) => match users {
323 Ok(users) => users,
324 Err(err) => {
325 server_connection.send(Request::Exit).await?;
326 return Err(
327 anyhow::anyhow!(err.to_error_message()).context("Failed to list all users")
328 );
329 }
330 },
331 response => return erroneous_server_response(response),
332 };
333
334 server_connection.send(Request::Exit).await?;
335
336 if args.json {
337 println!(
338 "{}",
339 serde_json::to_string_pretty(&users).context("Failed to serialize users to JSON")?
340 );
341 } else if users.is_empty() {
342 println!("No users to show.");
343 } else {
344 let mut table = prettytable::Table::new();
345 table.add_row(row![
346 "User",
347 "Password is set",
348 "Locked",
349 "Databases where user has privileges"
350 ]);
351 for user in users {
352 table.add_row(row![
353 user.user,
354 user.has_password,
355 user.is_locked,
356 user.databases.join("\n")
357 ]);
358 }
359 table.printstd();
360 }
361
362 Ok(())
363}
364
365async fn lock_users(
366 args: UserLockArgs,
367 mut server_connection: ClientToServerMessageStream,
368) -> anyhow::Result<()> {
369 if args.username.is_empty() {
370 anyhow::bail!("No usernames provided");
371 }
372
373 let message = Request::LockUsers(args.username.to_owned());
374
375 if let Err(err) = server_connection.send(message).await {
376 server_connection.close().await.ok();
377 anyhow::bail!(err);
378 }
379
380 let result = match server_connection.next().await {
381 Some(Ok(Response::LockUsers(result))) => result,
382 response => return erroneous_server_response(response),
383 };
384
385 server_connection.send(Request::Exit).await?;
386
387 if args.json {
388 print_lock_users_output_status_json(&result);
389 } else {
390 print_lock_users_output_status(&result);
391 }
392
393 Ok(())
394}
395
396async fn unlock_users(
397 args: UserUnlockArgs,
398 mut server_connection: ClientToServerMessageStream,
399) -> anyhow::Result<()> {
400 if args.username.is_empty() {
401 anyhow::bail!("No usernames provided");
402 }
403
404 let message = Request::UnlockUsers(args.username.to_owned());
405
406 if let Err(err) = server_connection.send(message).await {
407 server_connection.close().await.ok();
408 anyhow::bail!(err);
409 }
410
411 let result = match server_connection.next().await {
412 Some(Ok(Response::UnlockUsers(result))) => result,
413 response => return erroneous_server_response(response),
414 };
415
416 server_connection.send(Request::Exit).await?;
417
418 if args.json {
419 print_unlock_users_output_status_json(&result);
420 } else {
421 print_unlock_users_output_status(&result);
422 }
423
424 Ok(())
425}