1
use std::{io::IsTerminal, path::PathBuf};
2

            
3
use anyhow::Context;
4
use clap::Parser;
5
use clap_complete::ArgValueCompleter;
6
use dialoguer::Password;
7
use futures_util::SinkExt;
8
use tokio_stream::StreamExt;
9

            
10
use crate::{
11
    client::commands::{erroneous_server_response, print_authorization_owner_hint},
12
    core::{
13
        completion::mysql_user_completer,
14
        protocol::{
15
            ClientToServerMessageStream, ListUsersError, Request, Response, SetPasswordError,
16
            print_set_password_output_status, request_validation::ValidationError,
17
        },
18
        types::MySQLUser,
19
    },
20
};
21

            
22
#[derive(Parser, Debug, Clone)]
23
pub struct PasswdUserArgs {
24
    /// The `MySQL` user whose password is to be changed
25
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
26
    #[arg(value_name = "USER_NAME")]
27
    username: MySQLUser,
28

            
29
    /// Read the new password from a file instead of prompting for it
30
    #[clap(short, long, value_name = "PATH", conflicts_with = "stdin")]
31
    password_file: Option<PathBuf>,
32

            
33
    /// Read the new password from stdin instead of prompting for it
34
    #[clap(short = 'i', long, conflicts_with = "password_file")]
35
    stdin: bool,
36

            
37
    /// Print the information as JSON
38
    #[arg(short, long)]
39
    json: bool,
40
}
41

            
42
pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
43
    Password::new()
44
        .with_prompt(format!("New MySQL password for user '{username}'"))
45
        .with_confirmation(
46
            format!("Retype new MySQL password for user '{username}'"),
47
            "Passwords do not match",
48
        )
49
        .interact()
50
        .map_err(Into::into)
51
}
52

            
53
pub async fn passwd_user(
54
    args: PasswdUserArgs,
55
    mut server_connection: ClientToServerMessageStream,
56
) -> anyhow::Result<()> {
57
    // TODO: create a "user" exists check" command
58
    let message = Request::ListUsers(Some(vec![args.username.clone()]));
59
    if let Err(err) = server_connection.send(message).await {
60
        server_connection.close().await.ok();
61
        anyhow::bail!(err);
62
    }
63
    let response = match server_connection.next().await {
64
        Some(Ok(Response::ListUsers(users))) => users,
65
        response => return erroneous_server_response(response),
66
    };
67
    match response
68
        .get(&args.username)
69
        .unwrap_or(&Err(ListUsersError::UserDoesNotExist))
70
    {
71
        Ok(_) => {}
72
        Err(err) => {
73
            server_connection.send(Request::Exit).await?;
74
            server_connection.close().await.ok();
75
            anyhow::bail!("{}", err.to_error_message(&args.username));
76
        }
77
    }
78

            
79
    let password = if let Some(password_file) = args.password_file {
80
        std::fs::read_to_string(password_file)
81
            .context("Failed to read password file")?
82
            .trim()
83
            .to_string()
84
    } else if args.stdin {
85
        let mut buffer = String::new();
86
        std::io::stdin()
87
            .read_line(&mut buffer)
88
            .context("Failed to read password from stdin")?;
89
        buffer.trim().to_string()
90
    } else {
91
        if !std::io::stdin().is_terminal() {
92
            anyhow::bail!(
93
                "Cannot prompt for password in non-interactive mode. Use --stdin or --password-file to provide the password."
94
            );
95
        }
96
        read_password_from_stdin_with_double_check(&args.username)?
97
    };
98

            
99
    let message = Request::PasswdUser((args.username.clone(), password));
100

            
101
    if let Err(err) = server_connection.send(message).await {
102
        server_connection.close().await.ok();
103
        anyhow::bail!(err);
104
    }
105

            
106
    let result = match server_connection.next().await {
107
        Some(Ok(Response::SetUserPassword(result))) => result,
108
        response => return erroneous_server_response(response),
109
    };
110

            
111
    print_set_password_output_status(&result, &args.username);
112

            
113
    if matches!(
114
        result,
115
        Err(SetPasswordError::ValidationError(
116
            ValidationError::AuthorizationError(_)
117
        ))
118
    ) {
119
        print_authorization_owner_hint(&mut server_connection).await?;
120
    }
121

            
122
    server_connection.send(Request::Exit).await?;
123

            
124
    if result.is_err() {
125
        std::process::exit(1);
126
    }
127

            
128
    Ok(())
129
}