1
use std::{
2
    collections::{BTreeMap, BTreeSet},
3
    io::IsTerminal,
4
};
5

            
6
use anyhow::Context;
7
use clap::{Args, Parser};
8
use clap_complete::ArgValueCompleter;
9
use dialoguer::{Confirm, Editor};
10
use futures_util::SinkExt;
11
use nix::unistd::{User, getuid};
12
use tokio_stream::StreamExt;
13

            
14
use crate::{
15
    client::commands::{erroneous_server_response, print_authorization_owner_hint},
16
    core::{
17
        completion::{mysql_database_completer, mysql_user_completer},
18
        database_privileges::{
19
            DatabasePrivilegeEdit, DatabasePrivilegeEditEntry, DatabasePrivilegeRow,
20
            DatabasePrivilegeRowDiff, DatabasePrivilegesDiff, create_or_modify_privilege_rows,
21
            diff_privileges, display_privilege_diffs, generate_editor_content_from_privilege_data,
22
            parse_privilege_data_from_editor_content, reduce_privilege_diffs,
23
        },
24
        protocol::{
25
            ClientToServerMessageStream, ListDatabasesError, ListUsersError,
26
            ModifyDatabasePrivilegesError, Request, Response,
27
            print_modify_database_privileges_output_status, request_validation::ValidationError,
28
        },
29
        types::{MySQLDatabase, MySQLUser},
30
    },
31
};
32

            
33
#[derive(Parser, Debug, Clone)]
34
pub struct EditPrivsArgs {
35
    /// The privileges to set, grant or revoke, in the format `DATABASE:USER:[+-]PRIVILEGES`
36
    ///
37
    /// This option allows for changing privileges for multiple databases and users in batch.
38
    ///
39
    /// This can not be used together with the positional `DB_NAME`, `USER_NAME` and `PRIVILEGES` arguments.
40
    #[arg(
41
      short,
42
      long,
43
      value_name = "DB_NAME:USER_NAME:[+-]PRIVILEGES",
44
      num_args = 0..,
45
      value_parser = DatabasePrivilegeEditEntry::parse_from_str,
46
      conflicts_with("single_priv"),
47
    )]
48
    pub privs: Vec<DatabasePrivilegeEditEntry>,
49

            
50
    #[command(flatten)]
51
    pub single_priv: Option<SinglePrivilegeEditArgs>,
52

            
53
    /// Print the information as JSON
54
    #[arg(short, long)]
55
    pub json: bool,
56

            
57
    /// Specify the text editor to use for editing privileges
58
    #[arg(
59
      short,
60
      long,
61
      value_name = "COMMAND",
62
      value_hint = clap::ValueHint::CommandString,
63
    )]
64
    pub editor: Option<String>,
65

            
66
    /// Disable interactive confirmation before saving changes
67
    #[arg(short, long)]
68
    pub yes: bool,
69
}
70

            
71
#[derive(Args, Debug, Clone)]
72
pub struct SinglePrivilegeEditArgs {
73
    /// The `MySQL` database to edit privileges for
74
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_database_completer)))]
75
    #[arg(
76
        value_name = "DB_NAME",
77
        requires = "user_name",
78
        requires = "single_priv"
79
    )]
80
    pub db_name: Option<MySQLDatabase>,
81

            
82
    /// The `MySQL` database to edit privileges for
83
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
84
    #[arg(value_name = "USER_NAME")]
85
    pub user_name: Option<MySQLUser>,
86

            
87
    /// The privileges to set, grant or revoke
88
    #[arg(
89
      allow_hyphen_values = true,
90
      value_name = "[+-]PRIVILEGES",
91
      value_parser = DatabasePrivilegeEdit::parse_from_str,
92
    )]
93
    pub single_priv: Option<DatabasePrivilegeEdit>,
94
}
95

            
96
async fn users_exist(
97
    server_connection: &mut ClientToServerMessageStream,
98
    privilege_diff: &BTreeSet<DatabasePrivilegesDiff>,
99
) -> anyhow::Result<BTreeMap<MySQLUser, Result<(), ListUsersError>>> {
100
    let user_list = privilege_diff
101
        .iter()
102
        .map(|diff| diff.get_user_name().clone())
103
        .collect();
104

            
105
    let message = Request::ListUsers(Some(user_list));
106
    server_connection.send(message).await?;
107

            
108
    let result = match server_connection.next().await {
109
        Some(Ok(Response::ListUsers(user_map))) => user_map,
110
        response => {
111
            erroneous_server_response(response)?;
112
            // Unreachable, but needed to satisfy the type checker
113
            BTreeMap::new()
114
        }
115
    };
116

            
117
    let result = result
118
        .into_iter()
119
        .map(|(user, user_result)| (user, user_result.map(|_| ())))
120
        .collect();
121

            
122
    Ok(result)
123
}
124

            
125
async fn databases_exist(
126
    server_connection: &mut ClientToServerMessageStream,
127
    privilege_diff: &BTreeSet<DatabasePrivilegesDiff>,
128
) -> anyhow::Result<BTreeMap<MySQLDatabase, Result<(), ListDatabasesError>>> {
129
    let database_list = privilege_diff
130
        .iter()
131
        .map(|diff| diff.get_database_name().clone())
132
        .collect();
133

            
134
    let message = Request::ListDatabases(Some(database_list));
135
    server_connection.send(message).await?;
136

            
137
    let result = match server_connection.next().await {
138
        Some(Ok(Response::ListDatabases(database_map))) => database_map,
139
        response => {
140
            erroneous_server_response(response)?;
141
            // Unreachable, but needed to satisfy the type checker
142
            BTreeMap::new()
143
        }
144
    };
145

            
146
    let result = result
147
        .into_iter()
148
        .map(|(database, db_result)| (database, db_result.map(|_| ())))
149
        .collect();
150

            
151
    Ok(result)
152
}
153

            
154
// TODO: reduce the complexity of this function
155
pub async fn edit_database_privileges(
156
    args: EditPrivsArgs,
157
    // NOTE: this is only used for backwards compat with mysql-admutils
158
    use_database: Option<MySQLDatabase>,
159
    mut server_connection: ClientToServerMessageStream,
160
) -> anyhow::Result<()> {
161
    let message = Request::ListPrivileges(use_database.clone().map(|db| vec![db]));
162

            
163
    server_connection.send(message).await?;
164

            
165
    debug_assert!(args.privs.is_empty() ^ args.single_priv.is_none());
166

            
167
    let privs = if let Some(single_priv_entry) = &args.single_priv {
168
        let database = single_priv_entry.db_name.clone().ok_or_else(|| {
169
            anyhow::anyhow!(
170
                "DB_NAME must be specified when editing privileges in single privilege mode"
171
            )
172
        })?;
173
        let user = single_priv_entry.user_name.clone().ok_or_else(|| {
174
            anyhow::anyhow!(
175
                "USER_NAME must be specified when DB_NAME is specified in single privilege mode"
176
            )
177
        })?;
178
        let privilege_edit = single_priv_entry.single_priv.clone().ok_or_else(|| {
179
            anyhow::anyhow!(
180
                "PRIVILEGES must be specified when DB_NAME is specified in single privilege mode"
181
            )
182
        })?;
183

            
184
        vec![DatabasePrivilegeEditEntry {
185
            database,
186
            user,
187
            privilege_edit,
188
        }]
189
    } else {
190
        args.privs.clone()
191
    };
192

            
193
    let existing_privilege_rows = match server_connection.next().await {
194
        Some(Ok(Response::ListPrivileges(databases))) => databases
195
            .into_iter()
196
            .filter_map(|(database_name, result)| match result {
197
                Ok(privileges) => Some(privileges),
198
                Err(err) => {
199
                    eprintln!("{}", err.to_error_message(&database_name));
200
                    eprintln!("Skipping...");
201
                    println!();
202
                    None
203
                }
204
            })
205
            .flatten()
206
            .collect::<Vec<_>>(),
207
        Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
208
            Ok(list) => list,
209
            Err(err) => {
210
                server_connection.send(Request::Exit).await?;
211
                return Err(anyhow::anyhow!(err.to_error_message())
212
                    .context("Failed to list database privileges"));
213
            }
214
        },
215
        response => return erroneous_server_response(response),
216
    };
217

            
218
    let diffs: BTreeSet<DatabasePrivilegesDiff> = if privs.is_empty() {
219
        if !std::io::stdin().is_terminal() {
220
            anyhow::bail!(
221
                "Cannot launch editor in non-interactive mode. Please provide privileges via command line arguments."
222
            );
223
        }
224
        let privileges_to_change =
225
            edit_privileges_with_editor(&existing_privilege_rows, use_database.as_ref())?;
226
        diff_privileges(&existing_privilege_rows, &privileges_to_change)
227
    } else {
228
        let privileges_to_change = parse_privilege_tables(&privs)?;
229
        create_or_modify_privilege_rows(&existing_privilege_rows, &privileges_to_change)?
230
    };
231

            
232
    let database_existence_map = databases_exist(&mut server_connection, &diffs).await?;
233
    let user_existence_map = users_exist(&mut server_connection, &diffs).await?;
234

            
235
    let diffs = reduce_privilege_diffs(&existing_privilege_rows, diffs)?
236
        .into_iter()
237
        .filter(|diff| {
238
            let database_name = diff.get_database_name();
239
            let username = diff.get_user_name();
240

            
241
            if let Some(Err(err)) = database_existence_map.get(database_name) {
242
                println!("{}", err.to_error_message(database_name));
243
                println!("Skipping...");
244
                return false;
245
            }
246

            
247
            if let Some(Err(err)) = user_existence_map.get(username) {
248
                println!("{}", err.to_error_message(username));
249
                println!("Skipping...");
250
                return false;
251
            }
252

            
253
            true
254
        })
255
        .collect::<BTreeSet<_>>();
256

            
257
    if database_existence_map.values().any(|res| {
258
        matches!(
259
            res,
260
            Err(ListDatabasesError::ValidationError(
261
                ValidationError::AuthorizationError(_)
262
            ))
263
        )
264
    }) || user_existence_map.values().any(|res| {
265
        matches!(
266
            res,
267
            Err(ListUsersError::ValidationError(
268
                ValidationError::AuthorizationError(_)
269
            ))
270
        )
271
    }) {
272
        println!();
273
        print_authorization_owner_hint(&mut server_connection).await?;
274
        println!();
275
    }
276

            
277
    if diffs.is_empty() {
278
        println!("No changes to make.");
279
        server_connection.send(Request::Exit).await?;
280
        return Ok(());
281
    }
282

            
283
    println!("The following changes will be made:\n");
284
    println!("{}", display_privilege_diffs(&diffs));
285

            
286
    if std::io::stdin().is_terminal()
287
        && !args.yes
288
        && !Confirm::new()
289
            .with_prompt("Do you want to apply these changes?")
290
            .default(false)
291
            .show_default(true)
292
            .interact()?
293
    {
294
        server_connection.send(Request::Exit).await?;
295
        return Ok(());
296
    }
297

            
298
    let message = Request::ModifyPrivileges(diffs);
299
    server_connection.send(message).await?;
300

            
301
    let result = match server_connection.next().await {
302
        Some(Ok(Response::ModifyPrivileges(result))) => result,
303
        response => return erroneous_server_response(response),
304
    };
305

            
306
    print_modify_database_privileges_output_status(&result);
307

            
308
    if result.iter().any(|(_, res)| {
309
        matches!(
310
            res,
311
            Err(ModifyDatabasePrivilegesError::UserValidationError(
312
                ValidationError::AuthorizationError(_)
313
            ) | ModifyDatabasePrivilegesError::DatabaseValidationError(
314
                ValidationError::AuthorizationError(_)
315
            ))
316
        )
317
    }) {
318
        print_authorization_owner_hint(&mut server_connection).await?;
319
    }
320

            
321
    server_connection.send(Request::Exit).await?;
322

            
323
    if result.values().any(std::result::Result::is_err) {
324
        std::process::exit(1);
325
    }
326

            
327
    Ok(())
328
}
329

            
330
fn parse_privilege_tables(
331
    privs: &[DatabasePrivilegeEditEntry],
332
) -> anyhow::Result<BTreeSet<DatabasePrivilegeRowDiff>> {
333
    debug_assert!(!privs.is_empty());
334
    privs
335
        .iter()
336
        .map(|priv_edit_entry| {
337
            priv_edit_entry
338
                .as_database_privileges_diff()
339
                .context(format!(
340
                    "Failed parsing database privileges: `{priv_edit_entry}`"
341
                ))
342
        })
343
        .collect::<anyhow::Result<BTreeSet<DatabasePrivilegeRowDiff>>>()
344
}
345

            
346
fn edit_privileges_with_editor(
347
    privilege_data: &[DatabasePrivilegeRow],
348
    // NOTE: this is only used for backwards compat with mysql-admtools
349
    database_name: Option<&MySQLDatabase>,
350
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
351
    let unix_user = User::from_uid(getuid())
352
        .context("Failed to look up your UNIX username")
353
        .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))?;
354

            
355
    let editor_content =
356
        generate_editor_content_from_privilege_data(privilege_data, &unix_user.name, database_name);
357

            
358
    // TODO: handle errors better here
359
    let result = Editor::new().extension("tsv").edit(&editor_content)?;
360

            
361
    match result {
362
        None => Ok(privilege_data.to_vec()),
363
        Some(result) => parse_privilege_data_from_editor_content(&result)
364
            .context("Could not parse privilege data from editor"),
365
    }
366
}