1use std::collections::BTreeSet;
2
3use anyhow::Context;
4use clap::Parser;
5use dialoguer::{Confirm, Editor};
6use futures_util::{SinkExt, StreamExt};
7use nix::unistd::{User, getuid};
8use prettytable::{Cell, Row, Table};
9
10use crate::{
11 cli::common::erroneous_server_response,
12 core::{
13 common::yn,
14 database_privileges::{
15 DATABASE_PRIVILEGE_FIELDS, DatabasePrivilegeEditEntry, DatabasePrivilegeRow,
16 DatabasePrivilegeRowDiff, DatabasePrivilegesDiff, create_or_modify_privilege_rows,
17 db_priv_field_human_readable_name, diff_privileges, display_privilege_diffs,
18 generate_editor_content_from_privilege_data, parse_privilege_data_from_editor_content,
19 reduce_privilege_diffs,
20 },
21 protocol::{
22 ClientToServerMessageStream, MySQLDatabase, Request, Response,
23 print_create_databases_output_status, print_create_databases_output_status_json,
24 print_drop_databases_output_status, print_drop_databases_output_status_json,
25 print_modify_database_privileges_output_status,
26 },
27 },
28};
29
30#[derive(Parser, Debug, Clone)]
31pub enum DatabaseCommand {
33 #[command()]
35 CreateDb(DatabaseCreateArgs),
36
37 #[command()]
39 DropDb(DatabaseDropArgs),
40
41 #[command()]
45 ShowDb(DatabaseShowArgs),
46
47 #[command()]
51 ShowDbPrivs(DatabaseShowPrivsArgs),
52
53 #[command(verbatim_doc_comment)]
104 EditDbPrivs(DatabaseEditPrivsArgs),
105}
106
107#[derive(Parser, Debug, Clone)]
108pub struct DatabaseCreateArgs {
109 #[arg(num_args = 1..)]
111 name: Vec<MySQLDatabase>,
112
113 #[arg(short, long)]
115 json: bool,
116}
117
118#[derive(Parser, Debug, Clone)]
119pub struct DatabaseDropArgs {
120 #[arg(num_args = 1..)]
122 name: Vec<MySQLDatabase>,
123
124 #[arg(short, long)]
126 json: bool,
127}
128
129#[derive(Parser, Debug, Clone)]
130pub struct DatabaseShowArgs {
131 #[arg(num_args = 0..)]
133 name: Vec<MySQLDatabase>,
134
135 #[arg(short, long)]
137 json: bool,
138}
139
140#[derive(Parser, Debug, Clone)]
141pub struct DatabaseShowPrivsArgs {
142 #[arg(num_args = 0..)]
144 name: Vec<MySQLDatabase>,
145
146 #[arg(short, long)]
148 json: bool,
149}
150
151#[derive(Parser, Debug, Clone)]
152pub struct DatabaseEditPrivsArgs {
153 pub name: Option<MySQLDatabase>,
155
156 #[arg(short, long, value_name = "[DATABASE:]USER:[+-]PRIVILEGES", num_args = 0.., value_parser = DatabasePrivilegeEditEntry::parse_from_str)]
157 pub privs: Vec<DatabasePrivilegeEditEntry>,
158
159 #[arg(short, long)]
161 pub json: bool,
162
163 #[arg(short, long)]
165 pub editor: Option<String>,
166
167 #[arg(short, long)]
169 pub yes: bool,
170}
171
172pub async fn handle_command(
173 command: DatabaseCommand,
174 server_connection: ClientToServerMessageStream,
175) -> anyhow::Result<()> {
176 match command {
177 DatabaseCommand::CreateDb(args) => create_databases(args, server_connection).await,
178 DatabaseCommand::DropDb(args) => drop_databases(args, server_connection).await,
179 DatabaseCommand::ShowDb(args) => show_databases(args, server_connection).await,
180 DatabaseCommand::ShowDbPrivs(args) => {
181 show_database_privileges(args, server_connection).await
182 }
183 DatabaseCommand::EditDbPrivs(args) => {
184 edit_database_privileges(args, server_connection).await
185 }
186 }
187}
188
189async fn create_databases(
190 args: DatabaseCreateArgs,
191 mut server_connection: ClientToServerMessageStream,
192) -> anyhow::Result<()> {
193 if args.name.is_empty() {
194 anyhow::bail!("No database names provided");
195 }
196
197 let message = Request::CreateDatabases(args.name.to_owned());
198 server_connection.send(message).await?;
199
200 let result = match server_connection.next().await {
201 Some(Ok(Response::CreateDatabases(result))) => result,
202 response => return erroneous_server_response(response),
203 };
204
205 server_connection.send(Request::Exit).await?;
206
207 if args.json {
208 print_create_databases_output_status_json(&result);
209 } else {
210 print_create_databases_output_status(&result);
211 }
212
213 Ok(())
214}
215
216async fn drop_databases(
217 args: DatabaseDropArgs,
218 mut server_connection: ClientToServerMessageStream,
219) -> anyhow::Result<()> {
220 if args.name.is_empty() {
221 anyhow::bail!("No database names provided");
222 }
223
224 let message = Request::DropDatabases(args.name.to_owned());
225 server_connection.send(message).await?;
226
227 let result = match server_connection.next().await {
228 Some(Ok(Response::DropDatabases(result))) => result,
229 response => return erroneous_server_response(response),
230 };
231
232 server_connection.send(Request::Exit).await?;
233
234 if args.json {
235 print_drop_databases_output_status_json(&result);
236 } else {
237 print_drop_databases_output_status(&result);
238 };
239
240 Ok(())
241}
242
243async fn show_databases(
244 args: DatabaseShowArgs,
245 mut server_connection: ClientToServerMessageStream,
246) -> anyhow::Result<()> {
247 let message = if args.name.is_empty() {
248 Request::ListDatabases(None)
249 } else {
250 Request::ListDatabases(Some(args.name.to_owned()))
251 };
252
253 server_connection.send(message).await?;
254
255 let database_list = match server_connection.next().await {
258 Some(Ok(Response::ListDatabases(databases))) => databases
259 .into_iter()
260 .filter_map(|(database_name, result)| match result {
261 Ok(database_row) => Some(database_row),
262 Err(err) => {
263 eprintln!("{}", err.to_error_message(&database_name));
264 eprintln!("Skipping...");
265 println!();
266 None
267 }
268 })
269 .collect::<Vec<_>>(),
270 Some(Ok(Response::ListAllDatabases(database_list))) => match database_list {
271 Ok(list) => list,
272 Err(err) => {
273 server_connection.send(Request::Exit).await?;
274 return Err(
275 anyhow::anyhow!(err.to_error_message()).context("Failed to list databases")
276 );
277 }
278 },
279 response => return erroneous_server_response(response),
280 };
281
282 server_connection.send(Request::Exit).await?;
283
284 if args.json {
285 println!("{}", serde_json::to_string_pretty(&database_list)?);
286 } else if database_list.is_empty() {
287 println!("No databases to show.");
288 } else {
289 let mut table = Table::new();
290 table.add_row(Row::new(vec![Cell::new("Database")]));
291 for db in database_list {
292 table.add_row(row![db.database]);
293 }
294 table.printstd();
295 }
296
297 Ok(())
298}
299
300async fn show_database_privileges(
301 args: DatabaseShowPrivsArgs,
302 mut server_connection: ClientToServerMessageStream,
303) -> anyhow::Result<()> {
304 let message = if args.name.is_empty() {
305 Request::ListPrivileges(None)
306 } else {
307 Request::ListPrivileges(Some(args.name.to_owned()))
308 };
309 server_connection.send(message).await?;
310
311 let privilege_data = match server_connection.next().await {
312 Some(Ok(Response::ListPrivileges(databases))) => databases
313 .into_iter()
314 .filter_map(|(database_name, result)| match result {
315 Ok(privileges) => Some(privileges),
316 Err(err) => {
317 eprintln!("{}", err.to_error_message(&database_name));
318 eprintln!("Skipping...");
319 println!();
320 None
321 }
322 })
323 .flatten()
324 .collect::<Vec<_>>(),
325 Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
326 Ok(list) => list,
327 Err(err) => {
328 server_connection.send(Request::Exit).await?;
329 return Err(anyhow::anyhow!(err.to_error_message())
330 .context("Failed to list database privileges"));
331 }
332 },
333 response => return erroneous_server_response(response),
334 };
335
336 server_connection.send(Request::Exit).await?;
337
338 if args.json {
339 println!("{}", serde_json::to_string_pretty(&privilege_data)?);
340 } else if privilege_data.is_empty() {
341 println!("No database privileges to show.");
342 } else {
343 let mut table = Table::new();
344 table.add_row(Row::new(
345 DATABASE_PRIVILEGE_FIELDS
346 .into_iter()
347 .map(db_priv_field_human_readable_name)
348 .map(|name| Cell::new(&name))
349 .collect(),
350 ));
351
352 for row in privilege_data {
353 table.add_row(row![
354 row.db,
355 row.user,
356 c->yn(row.select_priv),
357 c->yn(row.insert_priv),
358 c->yn(row.update_priv),
359 c->yn(row.delete_priv),
360 c->yn(row.create_priv),
361 c->yn(row.drop_priv),
362 c->yn(row.alter_priv),
363 c->yn(row.index_priv),
364 c->yn(row.create_tmp_table_priv),
365 c->yn(row.lock_tables_priv),
366 c->yn(row.references_priv),
367 ]);
368 }
369 table.printstd();
370 }
371
372 Ok(())
373}
374
375pub async fn edit_database_privileges(
376 args: DatabaseEditPrivsArgs,
377 mut server_connection: ClientToServerMessageStream,
378) -> anyhow::Result<()> {
379 let message = Request::ListPrivileges(args.name.to_owned().map(|name| vec![name]));
380
381 server_connection.send(message).await?;
382
383 let existing_privilege_rows = match server_connection.next().await {
384 Some(Ok(Response::ListPrivileges(databases))) => databases
385 .into_iter()
386 .filter_map(|(database_name, result)| match result {
387 Ok(privileges) => Some(privileges),
388 Err(err) => {
389 eprintln!("{}", err.to_error_message(&database_name));
390 eprintln!("Skipping...");
391 println!();
392 None
393 }
394 })
395 .flatten()
396 .collect::<Vec<_>>(),
397 Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
398 Ok(list) => list,
399 Err(err) => {
400 server_connection.send(Request::Exit).await?;
401 return Err(anyhow::anyhow!(err.to_error_message())
402 .context("Failed to list database privileges"));
403 }
404 },
405 response => return erroneous_server_response(response),
406 };
407
408 let diffs: BTreeSet<DatabasePrivilegesDiff> = if !args.privs.is_empty() {
409 let privileges_to_change = parse_privilege_tables_from_args(&args)?;
410 create_or_modify_privilege_rows(&existing_privilege_rows, &privileges_to_change)?
411 } else {
412 let privileges_to_change =
413 edit_privileges_with_editor(&existing_privilege_rows, args.name.as_ref())?;
414 diff_privileges(&existing_privilege_rows, &privileges_to_change)
415 };
416 let diffs = reduce_privilege_diffs(&existing_privilege_rows, diffs)?;
417
418 if diffs.is_empty() {
419 println!("No changes to make.");
420 server_connection.send(Request::Exit).await?;
421 return Ok(());
422 }
423
424 println!("The following changes will be made:\n");
425 println!("{}", display_privilege_diffs(&diffs));
426
427 if !args.yes
428 && !Confirm::new()
429 .with_prompt("Do you want to apply these changes?")
430 .default(false)
431 .show_default(true)
432 .interact()?
433 {
434 server_connection.send(Request::Exit).await?;
435 return Ok(());
436 }
437
438 let message = Request::ModifyPrivileges(diffs);
439 server_connection.send(message).await?;
440
441 let result = match server_connection.next().await {
442 Some(Ok(Response::ModifyPrivileges(result))) => result,
443 response => return erroneous_server_response(response),
444 };
445
446 print_modify_database_privileges_output_status(&result);
447
448 server_connection.send(Request::Exit).await?;
449
450 Ok(())
451}
452
453fn parse_privilege_tables_from_args(
454 args: &DatabaseEditPrivsArgs,
455) -> anyhow::Result<BTreeSet<DatabasePrivilegeRowDiff>> {
456 debug_assert!(!args.privs.is_empty());
457 args.privs
458 .iter()
459 .map(|priv_edit_entry| {
460 priv_edit_entry
461 .as_database_privileges_diff(args.name.as_ref())
462 .context(format!(
463 "Failed parsing database privileges: `{}`",
464 priv_edit_entry
465 ))
466 })
467 .collect::<anyhow::Result<BTreeSet<DatabasePrivilegeRowDiff>>>()
468}
469
470fn edit_privileges_with_editor(
471 privilege_data: &[DatabasePrivilegeRow],
472 database_name: Option<&MySQLDatabase>,
473) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
474 let unix_user = User::from_uid(getuid())
475 .context("Failed to look up your UNIX username")
476 .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))?;
477
478 let editor_content =
479 generate_editor_content_from_privilege_data(privilege_data, &unix_user.name, database_name);
480
481 let result = Editor::new().extension("tsv").edit(&editor_content)?;
483
484 match result {
485 None => Ok(privilege_data.to_vec()),
486 Some(result) => parse_privilege_data_from_editor_content(result)
487 .context("Could not parse privilege data from editor"),
488 }
489}