1
use std::collections::BTreeSet;
2

            
3
use anyhow::Context;
4
use clap::Parser;
5
use dialoguer::{Confirm, Editor};
6
use futures_util::SinkExt;
7
use nix::unistd::{User, getuid};
8
use tokio_stream::StreamExt;
9

            
10
use crate::{
11
    client::commands::erroneous_server_response,
12
    core::{
13
        database_privileges::{
14
            DatabasePrivilegeEditEntry, DatabasePrivilegeRow, DatabasePrivilegeRowDiff,
15
            DatabasePrivilegesDiff, create_or_modify_privilege_rows, diff_privileges,
16
            display_privilege_diffs, generate_editor_content_from_privilege_data,
17
            parse_privilege_data_from_editor_content, reduce_privilege_diffs,
18
        },
19
        protocol::{
20
            ClientToServerMessageStream, Request, Response,
21
            print_modify_database_privileges_output_status,
22
        },
23
        types::MySQLDatabase,
24
    },
25
};
26

            
27
#[derive(Parser, Debug, Clone)]
28
pub struct EditPrivsArgs {
29
    /// The name of the database to edit privileges for
30
    pub name: Option<MySQLDatabase>,
31

            
32
    #[arg(
33
      short,
34
      long,
35
      value_name = "[DATABASE:]USER:[+-]PRIVILEGES",
36
      num_args = 0..,
37
      value_parser = DatabasePrivilegeEditEntry::parse_from_str,
38
    )]
39
    pub privs: Vec<DatabasePrivilegeEditEntry>,
40

            
41
    /// Print the information as JSON
42
    #[arg(short, long)]
43
    pub json: bool,
44

            
45
    /// Specify the text editor to use for editing privileges
46
    #[arg(short, long)]
47
    pub editor: Option<String>,
48

            
49
    /// Disable interactive confirmation before saving changes
50
    #[arg(short, long)]
51
    pub yes: bool,
52
}
53

            
54
pub async fn edit_database_privileges(
55
    args: EditPrivsArgs,
56
    mut server_connection: ClientToServerMessageStream,
57
) -> anyhow::Result<()> {
58
    let message = Request::ListPrivileges(args.name.to_owned().map(|name| vec![name]));
59

            
60
    server_connection.send(message).await?;
61

            
62
    let existing_privilege_rows = match server_connection.next().await {
63
        Some(Ok(Response::ListPrivileges(databases))) => databases
64
            .into_iter()
65
            .filter_map(|(database_name, result)| match result {
66
                Ok(privileges) => Some(privileges),
67
                Err(err) => {
68
                    eprintln!("{}", err.to_error_message(&database_name));
69
                    eprintln!("Skipping...");
70
                    println!();
71
                    None
72
                }
73
            })
74
            .flatten()
75
            .collect::<Vec<_>>(),
76
        Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
77
            Ok(list) => list,
78
            Err(err) => {
79
                server_connection.send(Request::Exit).await?;
80
                return Err(anyhow::anyhow!(err.to_error_message())
81
                    .context("Failed to list database privileges"));
82
            }
83
        },
84
        response => return erroneous_server_response(response),
85
    };
86

            
87
    let diffs: BTreeSet<DatabasePrivilegesDiff> = if !args.privs.is_empty() {
88
        let privileges_to_change = parse_privilege_tables_from_args(&args)?;
89
        create_or_modify_privilege_rows(&existing_privilege_rows, &privileges_to_change)?
90
    } else {
91
        let privileges_to_change =
92
            edit_privileges_with_editor(&existing_privilege_rows, args.name.as_ref())?;
93
        diff_privileges(&existing_privilege_rows, &privileges_to_change)
94
    };
95
    let diffs = reduce_privilege_diffs(&existing_privilege_rows, diffs)?;
96

            
97
    if diffs.is_empty() {
98
        println!("No changes to make.");
99
        server_connection.send(Request::Exit).await?;
100
        return Ok(());
101
    }
102

            
103
    println!("The following changes will be made:\n");
104
    println!("{}", display_privilege_diffs(&diffs));
105

            
106
    if !args.yes
107
        && !Confirm::new()
108
            .with_prompt("Do you want to apply these changes?")
109
            .default(false)
110
            .show_default(true)
111
            .interact()?
112
    {
113
        server_connection.send(Request::Exit).await?;
114
        return Ok(());
115
    }
116

            
117
    let message = Request::ModifyPrivileges(diffs);
118
    server_connection.send(message).await?;
119

            
120
    let result = match server_connection.next().await {
121
        Some(Ok(Response::ModifyPrivileges(result))) => result,
122
        response => return erroneous_server_response(response),
123
    };
124

            
125
    print_modify_database_privileges_output_status(&result);
126

            
127
    server_connection.send(Request::Exit).await?;
128

            
129
    Ok(())
130
}
131

            
132
fn parse_privilege_tables_from_args(
133
    args: &EditPrivsArgs,
134
) -> anyhow::Result<BTreeSet<DatabasePrivilegeRowDiff>> {
135
    debug_assert!(!args.privs.is_empty());
136
    args.privs
137
        .iter()
138
        .map(|priv_edit_entry| {
139
            priv_edit_entry
140
                .as_database_privileges_diff(args.name.as_ref())
141
                .context(format!(
142
                    "Failed parsing database privileges: `{}`",
143
                    priv_edit_entry
144
                ))
145
        })
146
        .collect::<anyhow::Result<BTreeSet<DatabasePrivilegeRowDiff>>>()
147
}
148

            
149
fn edit_privileges_with_editor(
150
    privilege_data: &[DatabasePrivilegeRow],
151
    database_name: Option<&MySQLDatabase>,
152
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
153
    let unix_user = User::from_uid(getuid())
154
        .context("Failed to look up your UNIX username")
155
        .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))?;
156

            
157
    let editor_content =
158
        generate_editor_content_from_privilege_data(privilege_data, &unix_user.name, database_name);
159

            
160
    // TODO: handle errors better here
161
    let result = Editor::new().extension("tsv").edit(&editor_content)?;
162

            
163
    match result {
164
        None => Ok(privilege_data.to_vec()),
165
        Some(result) => parse_privilege_data_from_editor_content(result)
166
            .context("Could not parse privilege data from editor"),
167
    }
168
}