1
use clap::Parser;
2
use clap_complete::ArgValueCompleter;
3
use futures_util::SinkExt;
4
use tokio_stream::StreamExt;
5

            
6
use crate::{
7
    client::commands::{erroneous_server_response, print_authorization_owner_hint},
8
    core::{
9
        completion::mysql_database_completer,
10
        protocol::{
11
            ClientToServerMessageStream, ListDatabasesError, Request, Response,
12
            print_list_databases_output_status, print_list_databases_output_status_json,
13
            request_validation::ValidationError,
14
        },
15
        types::MySQLDatabase,
16
    },
17
};
18

            
19
#[derive(Parser, Debug, Clone)]
20
pub struct ShowDbArgs {
21
    /// The `MySQL` database(s) to show
22
    #[arg(num_args = 0.., value_name = "DB_NAME")]
23
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_database_completer)))]
24
    name: Vec<MySQLDatabase>,
25

            
26
    /// Print the information as JSON
27
    #[arg(short, long)]
28
    json: bool,
29

            
30
    /// Show sizes in bytes instead of human-readable format
31
    #[arg(short, long)]
32
    bytes: bool,
33
}
34

            
35
pub async fn show_databases(
36
    args: ShowDbArgs,
37
    mut server_connection: ClientToServerMessageStream,
38
) -> anyhow::Result<()> {
39
    let message = if args.name.is_empty() {
40
        Request::ListDatabases(None)
41
    } else {
42
        Request::ListDatabases(Some(args.name.clone()))
43
    };
44

            
45
    server_connection.send(message).await?;
46

            
47
    let databases = match server_connection.next().await {
48
        Some(Ok(Response::ListDatabases(databases))) => databases,
49
        Some(Ok(Response::ListAllDatabases(database_list))) => match database_list {
50
            Ok(list) => list
51
                .into_iter()
52
                .map(|db| (db.database.clone(), Ok(db)))
53
                .collect(),
54
            Err(err) => {
55
                server_connection.send(Request::Exit).await?;
56
                return Err(
57
                    anyhow::anyhow!(err.to_error_message()).context("Failed to list databases")
58
                );
59
            }
60
        },
61
        response => return erroneous_server_response(response),
62
    };
63

            
64
    if args.json {
65
        print_list_databases_output_status_json(&databases);
66
    } else {
67
        print_list_databases_output_status(&databases, args.bytes);
68

            
69
        if databases.iter().any(|(_, res)| {
70
            matches!(
71
                res,
72
                Err(ListDatabasesError::ValidationError(
73
                    ValidationError::AuthorizationError(_)
74
                ))
75
            )
76
        }) {
77
            print_authorization_owner_hint(&mut server_connection).await?;
78
        }
79
    }
80

            
81
    server_connection.send(Request::Exit).await?;
82

            
83
    if databases.values().any(std::result::Result::is_err) {
84
        std::process::exit(1);
85
    }
86

            
87
    Ok(())
88
}