1
#[macro_use]
2
extern crate prettytable;
3

            
4
use anyhow::Context;
5
use clap::{CommandFactory, Parser, ValueEnum};
6
use clap_complete::{Shell, generate};
7
use clap_verbosity_flag::Verbosity;
8

            
9
use std::path::PathBuf;
10

            
11
use std::os::unix::net::UnixStream as StdUnixStream;
12
use tokio::net::UnixStream as TokioUnixStream;
13

            
14
use futures_util::StreamExt;
15

            
16
use crate::{
17
    core::{
18
        bootstrap::bootstrap_server_connection_and_drop_privileges,
19
        common::executable_is_suid_or_sgid,
20
        protocol::{Response, create_client_to_server_message_stream},
21
    },
22
    server::command::ServerArgs,
23
};
24

            
25
#[cfg(feature = "mysql-admutils-compatibility")]
26
use crate::client::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm};
27

            
28
mod server;
29

            
30
mod client;
31
mod core;
32

            
33
/// Database administration tool for non-admin users to manage their own MySQL databases and users.
34
///
35
/// This tool allows you to manage users and databases in MySQL.
36
///
37
/// You are only allowed to manage databases and users that are prefixed with
38
/// either your username, or a group that you are a member of.
39
#[derive(Parser, Debug)]
40
#[command(bin_name = "muscl", version, about, disable_help_subcommand = true)]
41
struct Args {
42
    #[command(subcommand)]
43
    command: Command,
44

            
45
    /// Path to the socket of the server, if it already exists.
46
    #[arg(
47
        short,
48
        long,
49
        value_name = "PATH",
50
        global = true,
51
        hide_short_help = true
52
    )]
53
    server_socket_path: Option<PathBuf>,
54

            
55
    /// Config file to use for the server.
56
    #[arg(
57
        short,
58
        long,
59
        value_name = "PATH",
60
        global = true,
61
        hide_short_help = true
62
    )]
63
    config: Option<PathBuf>,
64

            
65
    #[command(flatten)]
66
    verbose: Verbosity,
67
}
68

            
69
#[derive(Parser, Debug, Clone)]
70
enum Command {
71
    #[command(flatten)]
72
    Client(client::commands::ClientCommand),
73

            
74
    #[command(hide = true)]
75
    Server(server::command::ServerArgs),
76

            
77
    #[command(hide = true)]
78
    GenerateCompletions(GenerateCompletionArgs),
79
}
80

            
81
#[derive(Parser, Debug, Clone)]
82
struct GenerateCompletionArgs {
83
    #[arg(long, default_value = "bash")]
84
    shell: Shell,
85

            
86
    #[arg(long, default_value = "muscl")]
87
    command: ToplevelCommands,
88
}
89

            
90
#[cfg(feature = "mysql-admutils-compatibility")]
91
#[derive(ValueEnum, Debug, Clone)]
92
enum ToplevelCommands {
93
    Muscl,
94
    MysqlDbadm,
95
    MysqlUseradm,
96
}
97

            
98
/// **WARNING:** This function may be run with elevated privileges.
99
fn main() -> anyhow::Result<()> {
100
    #[cfg(feature = "mysql-admutils-compatibility")]
101
    if handle_mysql_admutils_command()?.is_some() {
102
        return Ok(());
103
    }
104

            
105
    let args: Args = Args::parse();
106

            
107
    if handle_server_command(&args)?.is_some() {
108
        return Ok(());
109
    }
110

            
111
    if handle_generate_completions_command(&args)?.is_some() {
112
        return Ok(());
113
    }
114

            
115
    let connection = bootstrap_server_connection_and_drop_privileges(
116
        args.server_socket_path,
117
        args.config,
118
        args.verbose,
119
    )?;
120

            
121
    tokio_run_command(args.command, connection)?;
122

            
123
    Ok(())
124
}
125

            
126
/// Dispatch to a `mysql-admutils` compatibility front-end based on the executable name.
///
/// The file name of argv[0] is inspected:
///
/// - `mysql-dbadm` runs `mysql_dbadm::main`.
/// - `mysql-useradm` runs `mysql_useradm::main`.
///
/// Returns `Ok(Some(()))` when one of them ran successfully. Returns `Ok(None)` when
/// the executable has any other name (or argv[0] is missing or not valid UTF-8).
///
/// # Errors
///
/// Propagates any error returned by the selected front-end.
///
/// **WARNING:** This function may be run with elevated privileges.
// Gated like its only caller: `mysql_dbadm`/`mysql_useradm` are only imported when the
// feature is enabled, so an ungated definition does not compile without it.
#[cfg(feature = "mysql-admutils-compatibility")]
fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> {
    // `args_os` instead of `args`: the latter panics if argv[0] is not valid Unicode,
    // and argv[0] is fully controlled by whoever executes us.
    let argv0 = std::env::args_os().next().map(PathBuf::from);
    let executable_name = argv0
        .as_deref()
        .and_then(|path| path.file_name())
        .and_then(|name| name.to_str());

    match executable_name {
        Some("mysql-dbadm") => mysql_dbadm::main().map(Some),
        Some("mysql-useradm") => mysql_useradm::main().map(Some),
        _ => Ok(None),
    }
}
140

            
141
/// **WARNING:** This function may be run with elevated privileges.
142
fn handle_server_command(args: &Args) -> anyhow::Result<Option<()>> {
143
    match args.command {
144
        Command::Server(ref command) => {
145
            assert!(
146
                !executable_is_suid_or_sgid()?,
147
                "The executable should not be SUID or SGID when running the server manually"
148
            );
149
            tokio_start_server(
150
                args.config.to_owned(),
151
                args.verbose.to_owned(),
152
                command.to_owned(),
153
            )?;
154
            Ok(Some(()))
155
        }
156
        _ => Ok(None),
157
    }
158
}
159

            
160
/// **WARNING:** This function may be run with elevated privileges.
161
fn handle_generate_completions_command(args: &Args) -> anyhow::Result<Option<()>> {
162
    match args.command {
163
        Command::GenerateCompletions(ref completion_args) => {
164
            assert!(
165
                !executable_is_suid_or_sgid()?,
166
                "The executable should not be SUID or SGID when generating completions"
167
            );
168
            let mut cmd = match completion_args.command {
169
                ToplevelCommands::Muscl => Args::command(),
170
                #[cfg(feature = "mysql-admutils-compatibility")]
171
                ToplevelCommands::MysqlDbadm => mysql_dbadm::Args::command(),
172
                #[cfg(feature = "mysql-admutils-compatibility")]
173
                ToplevelCommands::MysqlUseradm => mysql_useradm::Args::command(),
174
            };
175

            
176
            let binary_name = cmd.get_bin_name().unwrap().to_owned();
177

            
178
            generate(
179
                completion_args.shell,
180
                &mut cmd,
181
                binary_name,
182
                &mut std::io::stdout(),
183
            );
184

            
185
            Ok(Some(()))
186
        }
187
        _ => Ok(None),
188
    }
189
}
190

            
191
/// Start a long-lived server using Tokio.
192
fn tokio_start_server(
193
    config_path: Option<PathBuf>,
194
    verbosity: Verbosity,
195
    args: ServerArgs,
196
) -> anyhow::Result<()> {
197
    tokio::runtime::Builder::new_multi_thread()
198
        .enable_all()
199
        .build()
200
        .context("Failed to start Tokio runtime")?
201
        .block_on(async { server::command::handle_command(config_path, verbosity, args).await })
202
}
203

            
204
/// Run the given commmand (from the client side) using Tokio.
205
///
206
/// **WARNING:** This function may be run with elevated privileges.
207
fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
208
    tokio::runtime::Builder::new_current_thread()
209
        .enable_all()
210
        .build()
211
        .context("Failed to start Tokio runtime")?
212
        .block_on(async {
213
            let tokio_socket = TokioUnixStream::from_std(server_connection)?;
214
            let mut message_stream = create_client_to_server_message_stream(tokio_socket);
215

            
216
            while let Some(Ok(message)) = message_stream.next().await {
217
                match message {
218
                    Response::Error(err) => {
219
                        anyhow::bail!("{}", err);
220
                    }
221
                    Response::Ready => break,
222
                    message => {
223
                        eprintln!("Unexpected message from server: {:?}", message);
224
                    }
225
                }
226
            }
227

            
228
            match command {
229
                Command::Client(client_args) => {
230
                    client::commands::handle_command(client_args, message_stream).await
231
                }
232
                Command::Server(_) => unreachable!(),
233
                Command::GenerateCompletions(_) => unreachable!(),
234
            }
235
        })
236
}