1
use clap::Parser;
2
use futures_util::{SinkExt, StreamExt};
3
use std::path::PathBuf;
4

            
5
use std::os::unix::net::UnixStream as StdUnixStream;
6
use tokio::net::UnixStream as TokioUnixStream;
7

            
8
use crate::{
9
    client::{
10
        commands::{erroneous_server_response, read_password_from_stdin_with_double_check},
11
        mysql_admutils_compatibility::{
12
            common::trim_user_name_to_32_chars,
13
            error_messages::{
14
                handle_create_user_error, handle_drop_user_error, handle_list_users_error,
15
            },
16
        },
17
    },
18
    core::{
19
        bootstrap::bootstrap_server_connection_and_drop_privileges,
20
        protocol::{
21
            ClientToServerMessageStream, Request, Response, create_client_to_server_message_stream,
22
        },
23
        types::MySQLUser,
24
    },
25
    server::sql::user_operations::DatabaseUser,
26
};
27

            
28
/// Create, delete or change password for the USER(s),
29
/// as determined by the COMMAND.
30
///
31
/// This is a compatibility layer for the 'mysql-useradm' command.
32
/// Please consider using the newer 'muscl' command instead.
33
#[derive(Parser)]
34
#[command(
35
    bin_name = "mysql-useradm",
36
    version,
37
    about,
38
    disable_help_subcommand = true,
39
    verbatim_doc_comment
40
)]
41
pub struct Args {
42
    #[command(subcommand)]
43
    pub command: Option<Command>,
44

            
45
    /// Path to the socket of the server, if it already exists.
46
    #[arg(
47
        short,
48
        long,
49
        value_name = "PATH",
50
        global = true,
51
        hide_short_help = true
52
    )]
53
    server_socket_path: Option<PathBuf>,
54

            
55
    /// Config file to use for the server.
56
    #[arg(
57
        short,
58
        long,
59
        value_name = "PATH",
60
        global = true,
61
        hide_short_help = true
62
    )]
63
    config: Option<PathBuf>,
64
}
65

            
66
#[derive(Parser)]
67
pub enum Command {
68
    /// create the USER(s).
69
    Create(CreateArgs),
70

            
71
    /// delete the USER(s).
72
    Delete(DeleteArgs),
73

            
74
    /// change the MySQL password for the USER(s).
75
    Passwd(PasswdArgs),
76

            
77
    /// give information about the USERS(s), or, if
78
    /// none are given, all the users you have.
79
    Show(ShowArgs),
80
}
81

            
82
#[derive(Parser)]
83
pub struct CreateArgs {
84
    /// The name of the USER(s) to create.
85
    #[arg(num_args = 1..)]
86
    name: Vec<MySQLUser>,
87
}
88

            
89
#[derive(Parser)]
90
pub struct DeleteArgs {
91
    /// The name of the USER(s) to delete.
92
    #[arg(num_args = 1..)]
93
    name: Vec<MySQLUser>,
94
}
95

            
96
#[derive(Parser)]
97
pub struct PasswdArgs {
98
    /// The name of the USER(s) to change the password for.
99
    #[arg(num_args = 1..)]
100
    name: Vec<MySQLUser>,
101
}
102

            
103
#[derive(Parser)]
104
pub struct ShowArgs {
105
    /// The name of the USER(s) to show.
106
    #[arg(num_args = 0..)]
107
    name: Vec<MySQLUser>,
108
}
109

            
110
/// **WARNING:** This function may be run with elevated privileges.
111
pub fn main() -> anyhow::Result<()> {
112
    let args: Args = Args::parse();
113

            
114
    let command = match args.command {
115
        Some(command) => command,
116
        None => {
117
            println!(
118
                "Try `{} --help' for more information.",
119
                std::env::args()
120
                    .next()
121
                    .unwrap_or("mysql-useradm".to_string())
122
            );
123
            return Ok(());
124
        }
125
    };
126

            
127
    let server_connection = bootstrap_server_connection_and_drop_privileges(
128
        args.server_socket_path,
129
        args.config,
130
        Default::default(),
131
    )?;
132

            
133
    tokio_run_command(command, server_connection)?;
134

            
135
    Ok(())
136
}
137

            
138
fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
139
    tokio::runtime::Builder::new_current_thread()
140
        .enable_all()
141
        .build()
142
        .unwrap()
143
        .block_on(async {
144
            let tokio_socket = TokioUnixStream::from_std(server_connection)?;
145
            let message_stream = create_client_to_server_message_stream(tokio_socket);
146
            match command {
147
                Command::Create(args) => create_user(args, message_stream).await,
148
                Command::Delete(args) => drop_users(args, message_stream).await,
149
                Command::Passwd(args) => passwd_users(args, message_stream).await,
150
                Command::Show(args) => show_users(args, message_stream).await,
151
            }
152
        })
153
}
154

            
155
async fn create_user(
156
    args: CreateArgs,
157
    mut server_connection: ClientToServerMessageStream,
158
) -> anyhow::Result<()> {
159
    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
160

            
161
    let message = Request::CreateUsers(db_users);
162
    server_connection.send(message).await?;
163

            
164
    let result = match server_connection.next().await {
165
        Some(Ok(Response::CreateUsers(result))) => result,
166
        response => return erroneous_server_response(response),
167
    };
168

            
169
    server_connection.send(Request::Exit).await?;
170

            
171
    for (name, result) in result {
172
        match result {
173
            Ok(()) => println!("User '{}' created.", name),
174
            Err(err) => handle_create_user_error(err, &name),
175
        }
176
    }
177

            
178
    Ok(())
179
}
180

            
181
async fn drop_users(
182
    args: DeleteArgs,
183
    mut server_connection: ClientToServerMessageStream,
184
) -> anyhow::Result<()> {
185
    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
186

            
187
    let message = Request::DropUsers(db_users);
188
    server_connection.send(message).await?;
189

            
190
    let result = match server_connection.next().await {
191
        Some(Ok(Response::DropUsers(result))) => result,
192
        response => return erroneous_server_response(response),
193
    };
194

            
195
    server_connection.send(Request::Exit).await?;
196

            
197
    for (name, result) in result {
198
        match result {
199
            Ok(()) => println!("User '{}' deleted.", name),
200
            Err(err) => handle_drop_user_error(err, &name),
201
        }
202
    }
203

            
204
    Ok(())
205
}
206

            
207
async fn passwd_users(
208
    args: PasswdArgs,
209
    mut server_connection: ClientToServerMessageStream,
210
) -> anyhow::Result<()> {
211
    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
212

            
213
    let message = Request::ListUsers(Some(db_users));
214
    server_connection.send(message).await?;
215

            
216
    let response = match server_connection.next().await {
217
        Some(Ok(Response::ListUsers(result))) => result,
218
        response => return erroneous_server_response(response),
219
    };
220

            
221
    let argv0 = std::env::args()
222
        .next()
223
        .unwrap_or("mysql-useradm".to_string());
224

            
225
    let users = response
226
        .into_iter()
227
        .filter_map(|(name, result)| match result {
228
            Ok(user) => Some(user),
229
            Err(err) => {
230
                handle_list_users_error(err, &name);
231
                None
232
            }
233
        })
234
        .collect::<Vec<_>>();
235

            
236
    for user in users {
237
        let password = read_password_from_stdin_with_double_check(&user.user)?;
238
        let message = Request::PasswdUser((user.user.to_owned(), password));
239
        server_connection.send(message).await?;
240
        match server_connection.next().await {
241
            Some(Ok(Response::SetUserPassword(result))) => match result {
242
                Ok(()) => println!("Password updated for user '{}'.", &user.user),
243
                Err(_) => eprintln!(
244
                    "{}: Failed to update password for user '{}'.",
245
                    argv0, user.user,
246
                ),
247
            },
248
            response => return erroneous_server_response(response),
249
        }
250
    }
251

            
252
    server_connection.send(Request::Exit).await?;
253

            
254
    Ok(())
255
}
256

            
257
async fn show_users(
258
    args: ShowArgs,
259
    mut server_connection: ClientToServerMessageStream,
260
) -> anyhow::Result<()> {
261
    let db_users: Vec<_> = args.name.iter().map(trim_user_name_to_32_chars).collect();
262

            
263
    let message = if db_users.is_empty() {
264
        Request::ListUsers(None)
265
    } else {
266
        Request::ListUsers(Some(db_users))
267
    };
268
    server_connection.send(message).await?;
269

            
270
    let users: Vec<DatabaseUser> = match server_connection.next().await {
271
        Some(Ok(Response::ListAllUsers(result))) => match result {
272
            Ok(users) => users,
273
            Err(err) => {
274
                println!("Failed to list users: {:?}", err);
275
                return Ok(());
276
            }
277
        },
278
        Some(Ok(Response::ListUsers(result))) => result
279
            .into_iter()
280
            .filter_map(|(name, result)| match result {
281
                Ok(user) => Some(user),
282
                Err(err) => {
283
                    handle_list_users_error(err, &name);
284
                    None
285
                }
286
            })
287
            .collect(),
288
        response => return erroneous_server_response(response),
289
    };
290

            
291
    server_connection.send(Request::Exit).await?;
292

            
293
    for user in users {
294
        if user.has_password {
295
            println!("User '{}': password set.", user.user);
296
        } else {
297
            println!("User '{}': no password set.", user.user);
298
        }
299
    }
300

            
301
    Ok(())
302
}