1
use crate::{
2
    client::commands::erroneous_server_response,
3
    core::{
4
        protocol::{
5
            ClientToServerMessageStream, Request, Response,
6
            print_check_authorization_output_status, print_check_authorization_output_status_json,
7
        },
8
        types::DbOrUser,
9
    },
10
};
11
use clap::Parser;
12
use futures_util::SinkExt;
13
use tokio_stream::StreamExt;
14

            
15
#[derive(Parser, Debug, Clone)]
16
pub struct CheckAuthArgs {
17
    /// The `MySQL` database(s) or user(s) to check authorization for
18
    #[arg(num_args = 1.., value_name = "NAME")]
19
    name: Vec<String>,
20

            
21
    /// Treat the provided names as users instead of databases
22
    #[arg(short, long)]
23
    users: bool,
24

            
25
    /// Print the information as JSON
26
    #[arg(short, long)]
27
    json: bool,
28
}
29

            
30
pub async fn check_authorization(
31
    args: CheckAuthArgs,
32
    mut server_connection: ClientToServerMessageStream,
33
) -> anyhow::Result<()> {
34
    if args.name.is_empty() {
35
        anyhow::bail!("No database/user names provided");
36
    }
37

            
38
    let payload = args
39
        .name
40
        .into_iter()
41
        .map(|name| {
42
            if args.users {
43
                DbOrUser::User(name.into())
44
            } else {
45
                DbOrUser::Database(name.into())
46
            }
47
        })
48
        .collect::<Vec<_>>();
49

            
50
    let message = Request::CheckAuthorization(payload);
51
    server_connection.send(message).await?;
52

            
53
    let result = match server_connection.next().await {
54
        Some(Ok(Response::CheckAuthorization(response))) => response,
55
        response => return erroneous_server_response(response),
56
    };
57

            
58
    server_connection.send(Request::Exit).await?;
59

            
60
    if args.json {
61
        print_check_authorization_output_status_json(&result);
62
    } else {
63
        print_check_authorization_output_status(&result);
64
    }
65

            
66
    if result.values().any(std::result::Result::is_err) {
67
        std::process::exit(1);
68
    }
69

            
70
    Ok(())
71
}