1
use anyhow::Context;
2
use clap::Parser;
3
use dialoguer::Password;
4
use futures_util::SinkExt;
5
use tokio_stream::StreamExt;
6

            
7
use crate::{
8
    client::commands::erroneous_server_response,
9
    core::{
10
        protocol::{
11
            ClientToServerMessageStream, ListUsersError, Request, Response,
12
            print_set_password_output_status,
13
        },
14
        types::MySQLUser,
15
    },
16
};
17

            
18
#[derive(Parser, Debug, Clone)]
19
pub struct PasswdUserArgs {
20
    username: MySQLUser,
21

            
22
    #[clap(short, long)]
23
    password_file: Option<String>,
24

            
25
    /// Print the information as JSON
26
    #[arg(short, long)]
27
    json: bool,
28
}
29

            
30
pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
31
    Password::new()
32
        .with_prompt(format!("New MySQL password for user '{}'", username))
33
        .with_confirmation(
34
            format!("Retype new MySQL password for user '{}'", username),
35
            "Passwords do not match",
36
        )
37
        .interact()
38
        .map_err(Into::into)
39
}
40

            
41
pub async fn passwd_user(
42
    args: PasswdUserArgs,
43
    mut server_connection: ClientToServerMessageStream,
44
) -> anyhow::Result<()> {
45
    // TODO: create a "user" exists check" command
46
    let message = Request::ListUsers(Some(vec![args.username.to_owned()]));
47
    if let Err(err) = server_connection.send(message).await {
48
        server_connection.close().await.ok();
49
        anyhow::bail!(err);
50
    }
51
    let response = match server_connection.next().await {
52
        Some(Ok(Response::ListUsers(users))) => users,
53
        response => return erroneous_server_response(response),
54
    };
55
    match response
56
        .get(&args.username)
57
        .unwrap_or(&Err(ListUsersError::UserDoesNotExist))
58
    {
59
        Ok(_) => {}
60
        Err(err) => {
61
            server_connection.send(Request::Exit).await?;
62
            server_connection.close().await.ok();
63
            anyhow::bail!("{}", err.to_error_message(&args.username));
64
        }
65
    }
66

            
67
    let password = if let Some(password_file) = args.password_file {
68
        std::fs::read_to_string(password_file)
69
            .context("Failed to read password file")?
70
            .trim()
71
            .to_string()
72
    } else {
73
        read_password_from_stdin_with_double_check(&args.username)?
74
    };
75

            
76
    let message = Request::PasswdUser((args.username.to_owned(), password));
77

            
78
    if let Err(err) = server_connection.send(message).await {
79
        server_connection.close().await.ok();
80
        anyhow::bail!(err);
81
    }
82

            
83
    let result = match server_connection.next().await {
84
        Some(Ok(Response::SetUserPassword(result))) => result,
85
        response => return erroneous_server_response(response),
86
    };
87

            
88
    server_connection.send(Request::Exit).await?;
89

            
90
    print_set_password_output_status(&result, &args.username);
91

            
92
    Ok(())
93
}