1
use clap::Parser;
2
use clap_complete::ArgValueCompleter;
3
use futures_util::SinkExt;
4
use itertools::Itertools;
5
use tokio_stream::StreamExt;
6

            
7
use crate::{
8
    client::commands::{erroneous_server_response, print_authorization_owner_hint},
9
    core::{
10
        completion::mysql_database_completer,
11
        protocol::{
12
            ClientToServerMessageStream, ListPrivilegesError, Request, Response,
13
            print_list_privileges_output_status, print_list_privileges_output_status_json,
14
            request_validation::ValidationError,
15
        },
16
        types::MySQLDatabase,
17
    },
18
};
19

            
20
// Command-line arguments for showing database privileges.
//
// NOTE: plain `//` comments are used on purpose here. clap renders `///` doc
// comments as `--help` text (a `///` on the struct itself would become the
// command's `about`), so the `///` lines below are user-facing strings.
#[derive(Parser, Debug, Clone)]
pub struct ShowPrivsArgs {
    // Zero or more positional database names (`num_args = 0..`). Shell
    // completion via `mysql_database_completer` is only attached when the
    // `suid-sgid-mode` feature is disabled.
    /// The `MySQL` database(s) to show privileges for
    #[arg(num_args = 0.., value_name = "DB_NAME")]
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_database_completer)))]
    name: Vec<MySQLDatabase>,

    // `-j` / `--json`: switch the output to JSON.
    /// Print the information as JSON
    #[arg(short, long)]
    json: bool,

    // `-l` / `--long`: only consulted for the human-readable (non-JSON) output.
    /// Show single-character privilege names in addition to human-readable names
    ///
    /// This flag has no effect when used with --json
    #[arg(short, long)]
    long: bool,
}
37

            
38
pub async fn show_database_privileges(
39
    args: ShowPrivsArgs,
40
    mut server_connection: ClientToServerMessageStream,
41
) -> anyhow::Result<()> {
42
    let message = if args.name.is_empty() {
43
        Request::ListPrivileges(None)
44
    } else {
45
        Request::ListPrivileges(Some(args.name.clone()))
46
    };
47
    server_connection.send(message).await?;
48

            
49
    let privilege_data = match server_connection.next().await {
50
        Some(Ok(Response::ListPrivileges(databases))) => databases,
51
        Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
52
            Ok(list) => list
53
                .into_iter()
54
                .map(|row| (row.db.clone(), row))
55
                .into_group_map()
56
                .into_iter()
57
                .map(|(db, rows)| (db, Ok(rows)))
58
                .collect(),
59
            Err(err) => {
60
                server_connection.send(Request::Exit).await?;
61
                return Err(anyhow::anyhow!(err.to_error_message())
62
                    .context("Failed to list database privileges"));
63
            }
64
        },
65
        response => return erroneous_server_response(response),
66
    };
67

            
68
    if args.json {
69
        print_list_privileges_output_status_json(&privilege_data);
70
    } else {
71
        print_list_privileges_output_status(&privilege_data, args.long);
72

            
73
        if privilege_data.iter().any(|(_, res)| {
74
            matches!(
75
                res,
76
                Err(ListPrivilegesError::ValidationError(
77
                    ValidationError::AuthorizationError(_)
78
                ))
79
            )
80
        }) {
81
            print_authorization_owner_hint(&mut server_connection).await?;
82
        }
83
    }
84

            
85
    server_connection.send(Request::Exit).await?;
86

            
87
    if privilege_data.values().any(std::result::Result::is_err) {
88
        std::process::exit(1);
89
    }
90

            
91
    Ok(())
92
}