1
use clap::Parser;
2
use clap_complete::ArgValueCompleter;
3
use futures_util::SinkExt;
4
use tokio_stream::StreamExt;
5

            
6
use crate::{
7
    client::commands::{erroneous_server_response, print_authorization_owner_hint},
8
    core::{
9
        completion::mysql_user_completer,
10
        protocol::{
11
            ClientToServerMessageStream, ListUsersError, Request, Response,
12
            print_list_users_output_status, print_list_users_output_status_json,
13
            request_validation::ValidationError,
14
        },
15
        types::MySQLUser,
16
    },
17
};
18

            
19
#[derive(Parser, Debug, Clone)]
20
pub struct ShowUserArgs {
21
    /// The `MySQL` user(s) to show
22
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
23
    #[arg(num_args = 0.., value_name = "USER_NAME")]
24
    username: Vec<MySQLUser>,
25

            
26
    /// Print the information as JSON
27
    #[arg(short, long)]
28
    json: bool,
29
}
30

            
31
pub async fn show_users(
32
    args: ShowUserArgs,
33
    mut server_connection: ClientToServerMessageStream,
34
) -> anyhow::Result<()> {
35
    let message = if args.username.is_empty() {
36
        Request::ListUsers(None)
37
    } else {
38
        Request::ListUsers(Some(args.username.clone()))
39
    };
40

            
41
    if let Err(err) = server_connection.send(message).await {
42
        server_connection.close().await.ok();
43
        anyhow::bail!(err);
44
    }
45

            
46
    let users = match server_connection.next().await {
47
        Some(Ok(Response::ListUsers(users))) => users,
48
        Some(Ok(Response::ListAllUsers(users))) => match users {
49
            Ok(users) => users
50
                .into_iter()
51
                .map(|user| (user.user.clone(), Ok(user)))
52
                .collect(),
53
            Err(err) => {
54
                server_connection.send(Request::Exit).await?;
55
                return Err(
56
                    anyhow::anyhow!(err.to_error_message()).context("Failed to list all users")
57
                );
58
            }
59
        },
60
        response => return erroneous_server_response(response),
61
    };
62

            
63
    if args.json {
64
        print_list_users_output_status_json(&users);
65
    } else {
66
        print_list_users_output_status(&users);
67

            
68
        if users.iter().any(|(_, res)| {
69
            matches!(
70
                res,
71
                Err(ListUsersError::ValidationError(
72
                    ValidationError::AuthorizationError(_)
73
                ))
74
            )
75
        }) {
76
            print_authorization_owner_hint(&mut server_connection).await?;
77
        }
78
    }
79

            
80
    server_connection.send(Request::Exit).await?;
81

            
82
    if users.values().any(std::result::Result::is_err) {
83
        std::process::exit(1);
84
    }
85

            
86
    Ok(())
87
}