mysqladm/cli/
user_command.rs

1use anyhow::Context;
2use clap::Parser;
3use dialoguer::{Confirm, Password};
4use futures_util::{SinkExt, StreamExt};
5
6use crate::core::protocol::{
7    ClientToServerMessageStream, ListUsersError, MySQLUser, Request, Response,
8    print_create_users_output_status, print_create_users_output_status_json,
9    print_drop_users_output_status, print_drop_users_output_status_json,
10    print_lock_users_output_status, print_lock_users_output_status_json,
11    print_set_password_output_status, print_unlock_users_output_status,
12    print_unlock_users_output_status_json,
13};
14
15use super::common::erroneous_server_response;
16
17#[derive(Parser, Debug, Clone)]
18pub struct UserArgs {
19    #[clap(subcommand)]
20    subcmd: UserCommand,
21}
22
23#[allow(clippy::enum_variant_names)]
24#[derive(Parser, Debug, Clone)]
25pub enum UserCommand {
26    /// Create one or more users
27    #[command()]
28    CreateUser(UserCreateArgs),
29
30    /// Delete one or more users
31    #[command()]
32    DropUser(UserDeleteArgs),
33
34    /// Change the MySQL password for a user
35    #[command()]
36    PasswdUser(UserPasswdArgs),
37
38    /// Print information about one or more users
39    ///
40    /// If no username is provided, all users you have access will be shown.
41    #[command()]
42    ShowUser(UserShowArgs),
43
44    /// Lock account for one or more users
45    #[command()]
46    LockUser(UserLockArgs),
47
48    /// Unlock account for one or more users
49    #[command()]
50    UnlockUser(UserUnlockArgs),
51}
52
53#[derive(Parser, Debug, Clone)]
54pub struct UserCreateArgs {
55    #[arg(num_args = 1..)]
56    username: Vec<MySQLUser>,
57
58    /// Do not ask for a password, leave it unset
59    #[clap(long)]
60    no_password: bool,
61
62    /// Print the information as JSON
63    ///
64    /// Note that this implies `--no-password`, since the command will become non-interactive.
65    #[arg(short, long)]
66    json: bool,
67}
68
69#[derive(Parser, Debug, Clone)]
70pub struct UserDeleteArgs {
71    #[arg(num_args = 1..)]
72    username: Vec<MySQLUser>,
73
74    /// Print the information as JSON
75    #[arg(short, long)]
76    json: bool,
77}
78
79#[derive(Parser, Debug, Clone)]
80pub struct UserPasswdArgs {
81    username: MySQLUser,
82
83    #[clap(short, long)]
84    password_file: Option<String>,
85
86    /// Print the information as JSON
87    #[arg(short, long)]
88    json: bool,
89}
90
91#[derive(Parser, Debug, Clone)]
92pub struct UserShowArgs {
93    #[arg(num_args = 0..)]
94    username: Vec<MySQLUser>,
95
96    /// Print the information as JSON
97    #[arg(short, long)]
98    json: bool,
99}
100
101#[derive(Parser, Debug, Clone)]
102pub struct UserLockArgs {
103    #[arg(num_args = 1..)]
104    username: Vec<MySQLUser>,
105
106    /// Print the information as JSON
107    #[arg(short, long)]
108    json: bool,
109}
110
111#[derive(Parser, Debug, Clone)]
112pub struct UserUnlockArgs {
113    #[arg(num_args = 1..)]
114    username: Vec<MySQLUser>,
115
116    /// Print the information as JSON
117    #[arg(short, long)]
118    json: bool,
119}
120
121pub async fn handle_command(
122    command: UserCommand,
123    server_connection: ClientToServerMessageStream,
124) -> anyhow::Result<()> {
125    match command {
126        UserCommand::CreateUser(args) => create_users(args, server_connection).await,
127        UserCommand::DropUser(args) => drop_users(args, server_connection).await,
128        UserCommand::PasswdUser(args) => passwd_user(args, server_connection).await,
129        UserCommand::ShowUser(args) => show_users(args, server_connection).await,
130        UserCommand::LockUser(args) => lock_users(args, server_connection).await,
131        UserCommand::UnlockUser(args) => unlock_users(args, server_connection).await,
132    }
133}
134
135async fn create_users(
136    args: UserCreateArgs,
137    mut server_connection: ClientToServerMessageStream,
138) -> anyhow::Result<()> {
139    if args.username.is_empty() {
140        anyhow::bail!("No usernames provided");
141    }
142
143    let message = Request::CreateUsers(args.username.to_owned());
144    if let Err(err) = server_connection.send(message).await {
145        server_connection.close().await.ok();
146        anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server"));
147    }
148
149    let result = match server_connection.next().await {
150        Some(Ok(Response::CreateUsers(result))) => result,
151        response => return erroneous_server_response(response),
152    };
153
154    if args.json {
155        print_create_users_output_status_json(&result);
156    } else {
157        print_create_users_output_status(&result);
158
159        let successfully_created_users = result
160            .iter()
161            .filter_map(|(username, result)| result.as_ref().ok().map(|_| username))
162            .collect::<Vec<_>>();
163
164        for username in successfully_created_users {
165            if !args.no_password
166                && Confirm::new()
167                    .with_prompt(format!(
168                        "Do you want to set a password for user '{}'?",
169                        username
170                    ))
171                    .default(false)
172                    .interact()?
173            {
174                let password = read_password_from_stdin_with_double_check(username)?;
175                let message = Request::PasswdUser(username.to_owned(), password);
176
177                if let Err(err) = server_connection.send(message).await {
178                    server_connection.close().await.ok();
179                    anyhow::bail!(err);
180                }
181
182                match server_connection.next().await {
183                    Some(Ok(Response::PasswdUser(result))) => {
184                        print_set_password_output_status(&result, username)
185                    }
186                    response => return erroneous_server_response(response),
187                }
188
189                println!();
190            }
191        }
192    }
193
194    server_connection.send(Request::Exit).await?;
195
196    Ok(())
197}
198
199async fn drop_users(
200    args: UserDeleteArgs,
201    mut server_connection: ClientToServerMessageStream,
202) -> anyhow::Result<()> {
203    if args.username.is_empty() {
204        anyhow::bail!("No usernames provided");
205    }
206
207    let message = Request::DropUsers(args.username.to_owned());
208
209    if let Err(err) = server_connection.send(message).await {
210        server_connection.close().await.ok();
211        anyhow::bail!(err);
212    }
213
214    let result = match server_connection.next().await {
215        Some(Ok(Response::DropUsers(result))) => result,
216        response => return erroneous_server_response(response),
217    };
218
219    server_connection.send(Request::Exit).await?;
220
221    if args.json {
222        print_drop_users_output_status_json(&result);
223    } else {
224        print_drop_users_output_status(&result);
225    }
226
227    Ok(())
228}
229
230pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
231    Password::new()
232        .with_prompt(format!("New MySQL password for user '{}'", username))
233        .with_confirmation(
234            format!("Retype new MySQL password for user '{}'", username),
235            "Passwords do not match",
236        )
237        .interact()
238        .map_err(Into::into)
239}
240
241async fn passwd_user(
242    args: UserPasswdArgs,
243    mut server_connection: ClientToServerMessageStream,
244) -> anyhow::Result<()> {
245    // TODO: create a "user" exists check" command
246    let message = Request::ListUsers(Some(vec![args.username.to_owned()]));
247    if let Err(err) = server_connection.send(message).await {
248        server_connection.close().await.ok();
249        anyhow::bail!(err);
250    }
251    let response = match server_connection.next().await {
252        Some(Ok(Response::ListUsers(users))) => users,
253        response => return erroneous_server_response(response),
254    };
255    match response
256        .get(&args.username)
257        .unwrap_or(&Err(ListUsersError::UserDoesNotExist))
258    {
259        Ok(_) => {}
260        Err(err) => {
261            server_connection.send(Request::Exit).await?;
262            server_connection.close().await.ok();
263            anyhow::bail!("{}", err.to_error_message(&args.username));
264        }
265    }
266
267    let password = if let Some(password_file) = args.password_file {
268        std::fs::read_to_string(password_file)
269            .context("Failed to read password file")?
270            .trim()
271            .to_string()
272    } else {
273        read_password_from_stdin_with_double_check(&args.username)?
274    };
275
276    let message = Request::PasswdUser(args.username.to_owned(), password);
277
278    if let Err(err) = server_connection.send(message).await {
279        server_connection.close().await.ok();
280        anyhow::bail!(err);
281    }
282
283    let result = match server_connection.next().await {
284        Some(Ok(Response::PasswdUser(result))) => result,
285        response => return erroneous_server_response(response),
286    };
287
288    server_connection.send(Request::Exit).await?;
289
290    print_set_password_output_status(&result, &args.username);
291
292    Ok(())
293}
294
295async fn show_users(
296    args: UserShowArgs,
297    mut server_connection: ClientToServerMessageStream,
298) -> anyhow::Result<()> {
299    let message = if args.username.is_empty() {
300        Request::ListUsers(None)
301    } else {
302        Request::ListUsers(Some(args.username.to_owned()))
303    };
304
305    if let Err(err) = server_connection.send(message).await {
306        server_connection.close().await.ok();
307        anyhow::bail!(err);
308    }
309
310    let users = match server_connection.next().await {
311        Some(Ok(Response::ListUsers(users))) => users
312            .into_iter()
313            .filter_map(|(username, result)| match result {
314                Ok(user) => Some(user),
315                Err(err) => {
316                    eprintln!("{}", err.to_error_message(&username));
317                    eprintln!("Skipping...");
318                    None
319                }
320            })
321            .collect::<Vec<_>>(),
322        Some(Ok(Response::ListAllUsers(users))) => match users {
323            Ok(users) => users,
324            Err(err) => {
325                server_connection.send(Request::Exit).await?;
326                return Err(
327                    anyhow::anyhow!(err.to_error_message()).context("Failed to list all users")
328                );
329            }
330        },
331        response => return erroneous_server_response(response),
332    };
333
334    server_connection.send(Request::Exit).await?;
335
336    if args.json {
337        println!(
338            "{}",
339            serde_json::to_string_pretty(&users).context("Failed to serialize users to JSON")?
340        );
341    } else if users.is_empty() {
342        println!("No users to show.");
343    } else {
344        let mut table = prettytable::Table::new();
345        table.add_row(row![
346            "User",
347            "Password is set",
348            "Locked",
349            "Databases where user has privileges"
350        ]);
351        for user in users {
352            table.add_row(row![
353                user.user,
354                user.has_password,
355                user.is_locked,
356                user.databases.join("\n")
357            ]);
358        }
359        table.printstd();
360    }
361
362    Ok(())
363}
364
365async fn lock_users(
366    args: UserLockArgs,
367    mut server_connection: ClientToServerMessageStream,
368) -> anyhow::Result<()> {
369    if args.username.is_empty() {
370        anyhow::bail!("No usernames provided");
371    }
372
373    let message = Request::LockUsers(args.username.to_owned());
374
375    if let Err(err) = server_connection.send(message).await {
376        server_connection.close().await.ok();
377        anyhow::bail!(err);
378    }
379
380    let result = match server_connection.next().await {
381        Some(Ok(Response::LockUsers(result))) => result,
382        response => return erroneous_server_response(response),
383    };
384
385    server_connection.send(Request::Exit).await?;
386
387    if args.json {
388        print_lock_users_output_status_json(&result);
389    } else {
390        print_lock_users_output_status(&result);
391    }
392
393    Ok(())
394}
395
396async fn unlock_users(
397    args: UserUnlockArgs,
398    mut server_connection: ClientToServerMessageStream,
399) -> anyhow::Result<()> {
400    if args.username.is_empty() {
401        anyhow::bail!("No usernames provided");
402    }
403
404    let message = Request::UnlockUsers(args.username.to_owned());
405
406    if let Err(err) = server_connection.send(message).await {
407        server_connection.close().await.ok();
408        anyhow::bail!(err);
409    }
410
411    let result = match server_connection.next().await {
412        Some(Ok(Response::UnlockUsers(result))) => result,
413        response => return erroneous_server_response(response),
414    };
415
416    server_connection.send(Request::Exit).await?;
417
418    if args.json {
419        print_unlock_users_output_status_json(&result);
420    } else {
421        print_unlock_users_output_status(&result);
422    }
423
424    Ok(())
425}