1
use clap::{Parser, Subcommand};
2
use clap_complete::ArgValueCompleter;
3
use futures_util::{SinkExt, StreamExt};
4
use std::path::PathBuf;
5

            
6
use std::os::unix::net::UnixStream as StdUnixStream;
7
use tokio::net::UnixStream as TokioUnixStream;
8

            
9
use crate::{
10
    client::{
11
        commands::{erroneous_server_response, read_password_from_stdin_with_double_check},
12
        mysql_admutils_compatibility::{
13
            common::trim_user_name_to_32_chars,
14
            error_messages::{
15
                handle_create_user_error, handle_drop_user_error, handle_list_users_error,
16
            },
17
        },
18
    },
19
    core::{
20
        bootstrap::bootstrap_server_connection_and_drop_privileges,
21
        completion::{mysql_user_completer, prefix_completer},
22
        protocol::{
23
            ClientToServerMessageStream, Request, Response, create_client_to_server_message_stream,
24
        },
25
        types::MySQLUser,
26
    },
27
    server::sql::user_operations::DatabaseUser,
28
};
29

            
30
/// Create, delete or change password for the USER(s),
/// as determined by the COMMAND.
///
/// This is a compatibility layer for the 'mysql-useradm' command.
/// Please consider using the newer 'muscl' command instead.
// Command-line arguments of the `mysql-useradm` compatibility binary, parsed in
// `main` via `Args::parse()`. clap builds help output from the `///` comments in
// this struct, so that text is user-visible; `verbatim_doc_comment` keeps its line
// breaks as written. Edit the `///` text with that in mind.
#[derive(Parser)]
#[command(
    bin_name = "mysql-useradm",
    version,
    about,
    disable_help_subcommand = true,
    verbatim_doc_comment
)]
pub struct Args {
    // `None` when no subcommand was given; `main` then prints a `--help` hint and
    // returns without contacting the server.
    #[command(subcommand)]
    pub command: Option<Command>,

    /// Path to the socket of the server, if it already exists.
    // `global` lets the flag appear after the subcommand as well; `hide_short_help`
    // keeps it out of `-h` output. `main` forwards it to
    // `bootstrap_server_connection_and_drop_privileges`.
    #[arg(
        short,
        long,
        value_name = "PATH",
        value_hint = clap::ValueHint::FilePath,
        global = true,
        hide_short_help = true
    )]
    server_socket_path: Option<PathBuf>,

    /// Config file to use for the server.
    // Same visibility rules as `server_socket_path`; also forwarded by `main` to
    // `bootstrap_server_connection_and_drop_privileges`.
    #[arg(
        short,
        long,
        value_name = "PATH",
        value_hint = clap::ValueHint::FilePath,
        global = true,
        hide_short_help = true
    )]
    config: Option<PathBuf>,
}
69

            
70
// Subcommands of `mysql-useradm`. `tokio_run_command` dispatches each variant to
// its handler: `create_user`, `drop_users`, `passwd_users` and `show_users`.
// The `///` lines below are user-visible subcommand help. Their wording (including
// "USERS(s)") presumably mirrors the legacy tool's usage text.
// NOTE(review): confirm against the original mysql-useradm before "fixing" it.
#[derive(Subcommand)]
pub enum Command {
    /// create the USER(s).
    Create(CreateArgs),

    /// delete the USER(s).
    Delete(DeleteArgs),

    /// change the `MySQL` password for the USER(s).
    Passwd(PasswdArgs),

    /// give information about the USERS(s), or, if
    /// none are given, all the users you have.
    Show(ShowArgs),
}
85

            
86
// Arguments of `mysql-useradm create`.
#[derive(Parser)]
pub struct CreateArgs {
    /// The name of the USER(s) to create.
    // At least one name is required (`num_args = 1..`). Shell completion is only
    // attached when the `suid-sgid-mode` feature is disabled. It uses
    // `prefix_completer` rather than `mysql_user_completer`, presumably because the
    // user being created does not exist yet. `create_user` passes every name
    // through `trim_user_name_to_32_chars` before sending it.
    #[arg(num_args = 1..)]
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(prefix_completer)))]
    name: Vec<MySQLUser>,
}
93

            
94
// Arguments of `mysql-useradm delete`.
#[derive(Parser)]
pub struct DeleteArgs {
    /// The name of the USER(s) to delete.
    // At least one name is required (`num_args = 1..`). Completion of existing
    // users via `mysql_user_completer` is only attached when the `suid-sgid-mode`
    // feature is disabled. `drop_users` passes every name through
    // `trim_user_name_to_32_chars` before sending it.
    #[arg(num_args = 1..)]
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
    name: Vec<MySQLUser>,
}
101

            
102
// Arguments of `mysql-useradm passwd`.
#[derive(Parser)]
pub struct PasswdArgs {
    /// The name of the USER(s) to change the password for.
    // At least one name is required (`num_args = 1..`). Completion via
    // `mysql_user_completer` is only attached when the `suid-sgid-mode` feature is
    // disabled. `passwd_users` passes every name through
    // `trim_user_name_to_32_chars` and prompts for a password per existing user.
    #[arg(num_args = 1..)]
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
    name: Vec<MySQLUser>,
}
109

            
110
// Arguments of `mysql-useradm show`.
#[derive(Parser)]
pub struct ShowArgs {
    /// The name of the USER(s) to show.
    // Zero names is allowed (`num_args = 0..`). In that case `show_users` sends
    // `ListUsers(None)` instead of a name list. Completion via
    // `mysql_user_completer` is only attached when the `suid-sgid-mode` feature is
    // disabled.
    #[arg(num_args = 0..)]
    #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
    name: Vec<MySQLUser>,
}
117

            
118
/// **WARNING:** This function may be run with elevated privileges.
119
pub fn main() -> anyhow::Result<()> {
120
    let args: Args = Args::parse();
121

            
122
    let Some(command) = args.command else {
123
        println!(
124
            "Try `{} --help' for more information.",
125
            std::env::args()
126
                .next()
127
                .unwrap_or("mysql-useradm".to_string())
128
        );
129
        return Ok(());
130
    };
131

            
132
    let server_connection = bootstrap_server_connection_and_drop_privileges(
133
        args.server_socket_path,
134
        args.config,
135
        Default::default(),
136
    )?;
137

            
138
    tokio_run_command(command, server_connection)?;
139

            
140
    Ok(())
141
}
142

            
143
fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
144
    tokio::runtime::Builder::new_current_thread()
145
        .enable_all()
146
        .build()
147
        .unwrap()
148
        .block_on(async {
149
            let tokio_socket = TokioUnixStream::from_std(server_connection)?;
150
            let mut message_stream = create_client_to_server_message_stream(tokio_socket);
151

            
152
            while let Some(Ok(message)) = message_stream.next().await {
153
                match message {
154
                    Response::Error(err) => {
155
                        anyhow::bail!("{err}");
156
                    }
157
                    Response::Ready => break,
158
                    message => {
159
                        eprintln!("Unexpected message from server: {message:?}");
160
                    }
161
                }
162
            }
163

            
164
            match command {
165
                Command::Create(args) => create_user(args, message_stream).await,
166
                Command::Delete(args) => drop_users(args, message_stream).await,
167
                Command::Passwd(args) => passwd_users(args, message_stream).await,
168
                Command::Show(args) => show_users(args, message_stream).await,
169
            }
170
        })
171
}
172

            
173
async fn create_user(
174
    args: CreateArgs,
175
    mut server_connection: ClientToServerMessageStream,
176
) -> anyhow::Result<()> {
177
    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
178

            
179
    let message = Request::CreateUsers(db_users);
180
    server_connection.send(message).await?;
181

            
182
    let result = match server_connection.next().await {
183
        Some(Ok(Response::CreateUsers(result))) => result,
184
        response => return erroneous_server_response(response),
185
    };
186

            
187
    server_connection.send(Request::Exit).await?;
188

            
189
    for (name, result) in result {
190
        match result {
191
            Ok(()) => println!("User '{name}' created."),
192
            Err(err) => handle_create_user_error(&err, &name),
193
        }
194
    }
195

            
196
    Ok(())
197
}
198

            
199
async fn drop_users(
200
    args: DeleteArgs,
201
    mut server_connection: ClientToServerMessageStream,
202
) -> anyhow::Result<()> {
203
    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
204

            
205
    let message = Request::DropUsers(db_users);
206
    server_connection.send(message).await?;
207

            
208
    let result = match server_connection.next().await {
209
        Some(Ok(Response::DropUsers(result))) => result,
210
        response => return erroneous_server_response(response),
211
    };
212

            
213
    server_connection.send(Request::Exit).await?;
214

            
215
    for (name, result) in result {
216
        match result {
217
            Ok(()) => println!("User '{name}' deleted."),
218
            Err(err) => handle_drop_user_error(&err, &name),
219
        }
220
    }
221

            
222
    Ok(())
223
}
224

            
225
async fn passwd_users(
226
    args: PasswdArgs,
227
    mut server_connection: ClientToServerMessageStream,
228
) -> anyhow::Result<()> {
229
    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
230

            
231
    let message = Request::ListUsers(Some(db_users));
232
    server_connection.send(message).await?;
233

            
234
    let response = match server_connection.next().await {
235
        Some(Ok(Response::ListUsers(result))) => result,
236
        response => return erroneous_server_response(response),
237
    };
238

            
239
    let argv0 = std::env::args()
240
        .next()
241
        .unwrap_or("mysql-useradm".to_string());
242

            
243
    let users = response
244
        .into_iter()
245
        .filter_map(|(name, result)| match result {
246
            Ok(user) => Some(user),
247
            Err(err) => {
248
                handle_list_users_error(&err, &name);
249
                None
250
            }
251
        })
252
        .collect::<Vec<_>>();
253

            
254
    for user in users {
255
        let password = read_password_from_stdin_with_double_check(&user.user)?;
256
        let message = Request::PasswdUser((user.user.clone(), password));
257
        server_connection.send(message).await?;
258
        match server_connection.next().await {
259
            Some(Ok(Response::SetUserPassword(result))) => match result {
260
                Ok(()) => println!("Password updated for user '{}'.", &user.user),
261
                Err(_) => eprintln!(
262
                    "{}: Failed to update password for user '{}'.",
263
                    argv0, user.user,
264
                ),
265
            },
266
            response => return erroneous_server_response(response),
267
        }
268
    }
269

            
270
    server_connection.send(Request::Exit).await?;
271

            
272
    Ok(())
273
}
274

            
275
async fn show_users(
276
    args: ShowArgs,
277
    mut server_connection: ClientToServerMessageStream,
278
) -> anyhow::Result<()> {
279
    let db_users: Vec<_> = args.name.iter().map(trim_user_name_to_32_chars).collect();
280

            
281
    let message = if db_users.is_empty() {
282
        Request::ListUsers(None)
283
    } else {
284
        Request::ListUsers(Some(db_users))
285
    };
286
    server_connection.send(message).await?;
287

            
288
    let users: Vec<DatabaseUser> = match server_connection.next().await {
289
        Some(Ok(Response::ListAllUsers(result))) => match result {
290
            Ok(users) => users,
291
            Err(err) => {
292
                eprintln!("Failed to list users: {err:?}");
293
                return Ok(());
294
            }
295
        },
296
        Some(Ok(Response::ListUsers(result))) => result
297
            .into_iter()
298
            .filter_map(|(name, result)| match result {
299
                Ok(user) => Some(user),
300
                Err(err) => {
301
                    handle_list_users_error(&err, &name);
302
                    None
303
                }
304
            })
305
            .collect(),
306
        response => return erroneous_server_response(response),
307
    };
308

            
309
    server_connection.send(Request::Exit).await?;
310

            
311
    for user in users {
312
        if user.has_password {
313
            println!("User '{}': password set.", user.user);
314
        } else {
315
            println!("User '{}': no password set.", user.user);
316
        }
317
    }
318

            
319
    Ok(())
320
}