1
use std::io::IsTerminal;
2

            
3
use clap::Parser;
4
use clap_complete::ArgValueCompleter;
5
use dialoguer::Confirm;
6
use futures_util::SinkExt;
7
use tokio_stream::StreamExt;
8

            
9
use crate::{
10
    client::commands::{
11
        erroneous_server_response, print_authorization_owner_hint,
12
        read_password_from_stdin_with_double_check,
13
    },
14
    core::{
15
        completion::prefix_completer,
16
        protocol::{
17
            ClientToServerMessageStream, CreateUserError, Request, Response,
18
            print_create_users_output_status, print_create_users_output_status_json,
19
            print_set_password_output_status, request_validation::ValidationError,
20
        },
21
        types::MySQLUser,
22
    },
23
};
24

            
25
#[derive(Parser, Debug, Clone)]
26
pub struct CreateUserArgs {
27
    /// The `MySQL` user(s) to create
28
    #[arg(num_args = 1.., value_name = "USER_NAME")]
29
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(prefix_completer)))]
30
    username: Vec<MySQLUser>,
31

            
32
    /// Do not ask for a password, leave it unset
33
    #[clap(long)]
34
    no_password: bool,
35

            
36
    /// Print the information as JSON
37
    ///
38
    /// Note that this implies `--no-password`, since the command will become non-interactive.
39
    #[arg(short, long)]
40
    json: bool,
41
}
42

            
43
pub async fn create_users(
44
    args: CreateUserArgs,
45
    mut server_connection: ClientToServerMessageStream,
46
) -> anyhow::Result<()> {
47
    if args.username.is_empty() {
48
        anyhow::bail!("No usernames provided");
49
    }
50

            
51
    let message = Request::CreateUsers(args.username.clone());
52
    if let Err(err) = server_connection.send(message).await {
53
        server_connection.close().await.ok();
54
        anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server"));
55
    }
56

            
57
    let result = match server_connection.next().await {
58
        Some(Ok(Response::CreateUsers(result))) => result,
59
        response => return erroneous_server_response(response),
60
    };
61

            
62
    if args.json {
63
        print_create_users_output_status_json(&result);
64
    } else {
65
        print_create_users_output_status(&result);
66

            
67
        if result.iter().any(|(_, res)| {
68
            matches!(
69
                res,
70
                Err(CreateUserError::ValidationError(
71
                    ValidationError::AuthorizationError(_)
72
                ))
73
            )
74
        }) {
75
            print_authorization_owner_hint(&mut server_connection).await?;
76
        }
77

            
78
        let successfully_created_users = result
79
            .iter()
80
            .filter_map(|(username, result)| result.as_ref().ok().map(|()| username))
81
            .collect::<Vec<_>>();
82

            
83
        if !std::io::stdin().is_terminal()
84
            && !args.no_password
85
            && !successfully_created_users.is_empty()
86
        {
87
            anyhow::bail!(
88
                "Cannot prompt for passwords in non-interactive mode. Use --no-password to skip setting passwords."
89
            );
90
        }
91

            
92
        for username in successfully_created_users {
93
            if !args.no_password
94
                && Confirm::new()
95
                    .with_prompt(format!(
96
                        "Do you want to set a password for user '{username}'?"
97
                    ))
98
                    .default(false)
99
                    .interact()?
100
            {
101
                let password = read_password_from_stdin_with_double_check(username)?;
102
                let message = Request::PasswdUser((username.to_owned(), password));
103

            
104
                if let Err(err) = server_connection.send(message).await {
105
                    server_connection.close().await.ok();
106
                    anyhow::bail!(err);
107
                }
108

            
109
                match server_connection.next().await {
110
                    Some(Ok(Response::SetUserPassword(result))) => {
111
                        print_set_password_output_status(&result, username);
112
                    }
113
                    response => return erroneous_server_response(response),
114
                }
115

            
116
                println!();
117
            }
118
        }
119
    }
120

            
121
    server_connection.send(Request::Exit).await?;
122

            
123
    if result.values().any(std::result::Result::is_err) {
124
        std::process::exit(1);
125
    }
126

            
127
    Ok(())
128
}