1
#[macro_use]
2
extern crate prettytable;
3

            
4
use anyhow::Context;
5
use clap::{CommandFactory, Parser, ValueEnum};
6
use clap_complete::{Shell, generate};
7
use clap_verbosity_flag::Verbosity;
8

            
9
use std::path::PathBuf;
10

            
11
use std::os::unix::net::UnixStream as StdUnixStream;
12
use tokio::net::UnixStream as TokioUnixStream;
13

            
14
use futures_util::StreamExt;
15

            
16
use crate::{
17
    core::{
18
        bootstrap::bootstrap_server_connection_and_drop_privileges,
19
        common::executable_is_suid_or_sgid,
20
        protocol::{Response, create_client_to_server_message_stream},
21
    },
22
    server::command::ServerArgs,
23
};
24

            
25
#[cfg(feature = "mysql-admutils-compatibility")]
26
use crate::client::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm};
27

            
28
mod server;
29

            
30
mod client;
31
mod core;
32

            
33
/// Database administration tool for non-admin users to manage their own MySQL databases and users.
34
///
35
/// This tool allows you to manage users and databases in MySQL.
36
///
37
/// You are only allowed to manage databases and users that are prefixed with
38
/// either your username, or a group that you are a member of.
39
#[derive(Parser, Debug)]
40
#[command(bin_name = "muscle", version, about, disable_help_subcommand = true)]
41
struct Args {
42
    #[command(subcommand)]
43
    command: Command,
44

            
45
    /// Path to the socket of the server, if it already exists.
46
    #[arg(
47
        short,
48
        long,
49
        value_name = "PATH",
50
        global = true,
51
        hide_short_help = true
52
    )]
53
    server_socket_path: Option<PathBuf>,
54

            
55
    /// Config file to use for the server.
56
    #[arg(
57
        short,
58
        long,
59
        value_name = "PATH",
60
        global = true,
61
        hide_short_help = true
62
    )]
63
    config: Option<PathBuf>,
64

            
65
    #[command(flatten)]
66
    verbose: Verbosity,
67
}
68

            
69
#[derive(Parser, Debug, Clone)]
70
enum Command {
71
    #[command(flatten)]
72
    Client(client::commands::ClientCommand),
73

            
74
    #[command(hide = true)]
75
    Server(server::command::ServerArgs),
76

            
77
    #[command(hide = true)]
78
    GenerateCompletions(GenerateCompletionArgs),
79
}
80

            
81
#[derive(Parser, Debug, Clone)]
82
struct GenerateCompletionArgs {
83
    #[arg(long, default_value = "bash")]
84
    shell: Shell,
85

            
86
    #[arg(long, default_value = "muscle")]
87
    command: ToplevelCommands,
88
}
89

            
90
#[cfg(feature = "mysql-admutils-compatibility")]
91
#[derive(ValueEnum, Debug, Clone)]
92
enum ToplevelCommands {
93
    Muscle,
94
    MysqlDbadm,
95
    MysqlUseradm,
96
}
97

            
98
/// **WARNING:** This function may be run with elevated privileges.
99
fn main() -> anyhow::Result<()> {
100
    #[cfg(feature = "mysql-admutils-compatibility")]
101
    if handle_mysql_admutils_command()?.is_some() {
102
        return Ok(());
103
    }
104

            
105
    let args: Args = Args::parse();
106

            
107
    if handle_server_command(&args)?.is_some() {
108
        return Ok(());
109
    }
110

            
111
    if handle_generate_completions_command(&args)?.is_some() {
112
        return Ok(());
113
    }
114

            
115
    let connection = bootstrap_server_connection_and_drop_privileges(
116
        args.server_socket_path,
117
        args.config,
118
        args.verbose,
119
    )?;
120

            
121
    tokio_run_command(args.command, connection)?;
122

            
123
    Ok(())
124
}
125

            
126
/// Dispatch to the `mysql-admutils` compatibility entry points based on `argv[0]`.
///
/// The file name of the executable path is compared against `mysql-dbadm` and
/// `mysql-useradm`. On a match, the corresponding compatibility `main` is run
/// and its result is returned wrapped in `Some`. Any other name (or a missing
/// `argv[0]`) yields `Ok(None)`, so the caller continues with normal parsing.
///
/// Only compiled with the `mysql-admutils-compatibility` feature, since the
/// modules it calls into are only imported when that feature is enabled.
///
/// **WARNING:** This function may be run with elevated privileges.
#[cfg(feature = "mysql-admutils-compatibility")]
fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> {
    // `args_os` is used instead of `args` so that a non-UTF-8 `argv[0]` cannot
    // cause a panic; such a name simply fails to match below.
    let argv0 = std::env::args_os().next().and_then(|arg| {
        std::path::Path::new(&arg)
            .file_name()
            .map(|name| name.to_string_lossy().into_owned())
    });

    match argv0.as_deref() {
        Some("mysql-dbadm") => mysql_dbadm::main().map(Some),
        Some("mysql-useradm") => mysql_useradm::main().map(Some),
        _ => Ok(None),
    }
}
140

            
141
/// **WARNING:** This function may be run with elevated privileges.
142
fn handle_server_command(args: &Args) -> anyhow::Result<Option<()>> {
143
    match args.command {
144
        Command::Server(ref command) => {
145
            assert!(
146
                !executable_is_suid_or_sgid()?,
147
                "The executable should not be SUID or SGID when running the server manually"
148
            );
149
            tokio_start_server(
150
                args.server_socket_path.to_owned(),
151
                args.config.to_owned(),
152
                args.verbose.to_owned(),
153
                command.to_owned(),
154
            )?;
155
            Ok(Some(()))
156
        }
157
        _ => Ok(None),
158
    }
159
}
160

            
161
/// **WARNING:** This function may be run with elevated privileges.
162
fn handle_generate_completions_command(args: &Args) -> anyhow::Result<Option<()>> {
163
    match args.command {
164
        Command::GenerateCompletions(ref completion_args) => {
165
            assert!(
166
                !executable_is_suid_or_sgid()?,
167
                "The executable should not be SUID or SGID when generating completions"
168
            );
169
            let mut cmd = match completion_args.command {
170
                ToplevelCommands::Muscle => Args::command(),
171
                #[cfg(feature = "mysql-admutils-compatibility")]
172
                ToplevelCommands::MysqlDbadm => mysql_dbadm::Args::command(),
173
                #[cfg(feature = "mysql-admutils-compatibility")]
174
                ToplevelCommands::MysqlUseradm => mysql_useradm::Args::command(),
175
            };
176

            
177
            let binary_name = cmd.get_bin_name().unwrap().to_owned();
178

            
179
            generate(
180
                completion_args.shell,
181
                &mut cmd,
182
                binary_name,
183
                &mut std::io::stdout(),
184
            );
185

            
186
            Ok(Some(()))
187
        }
188
        _ => Ok(None),
189
    }
190
}
191

            
192
/// Start a long-lived server using Tokio.
193
fn tokio_start_server(
194
    server_socket_path: Option<PathBuf>,
195
    config_path: Option<PathBuf>,
196
    verbosity: Verbosity,
197
    args: ServerArgs,
198
) -> anyhow::Result<()> {
199
    tokio::runtime::Builder::new_multi_thread()
200
        .enable_all()
201
        .build()
202
        .context("Failed to start Tokio runtime")?
203
        .block_on(async {
204
            server::command::handle_command(server_socket_path, config_path, verbosity, args).await
205
        })
206
}
207

            
208
/// Run the given commmand (from the client side) using Tokio.
209
///
210
/// **WARNING:** This function may be run with elevated privileges.
211
fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
212
    tokio::runtime::Builder::new_current_thread()
213
        .enable_all()
214
        .build()
215
        .context("Failed to start Tokio runtime")?
216
        .block_on(async {
217
            let tokio_socket = TokioUnixStream::from_std(server_connection)?;
218
            let mut message_stream = create_client_to_server_message_stream(tokio_socket);
219

            
220
            while let Some(Ok(message)) = message_stream.next().await {
221
                match message {
222
                    Response::Error(err) => {
223
                        anyhow::bail!("{}", err);
224
                    }
225
                    Response::Ready => break,
226
                    message => {
227
                        eprintln!("Unexpected message from server: {:?}", message);
228
                    }
229
                }
230
            }
231

            
232
            match command {
233
                Command::Client(client_args) => {
234
                    client::commands::handle_command(client_args, message_stream).await
235
                }
236
                Command::Server(_) => unreachable!(),
237
                Command::GenerateCompletions(_) => unreachable!(),
238
            }
239
        })
240
}