mysqladm/cli/mysql_admutils_compatibility/mysql_useradm.rs

use anyhow::Context;
use clap::Parser;
use futures_util::{SinkExt, StreamExt};
use std::path::PathBuf;

use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream;

use crate::{
    cli::{
        common::erroneous_server_response,
        mysql_admutils_compatibility::{
            common::trim_user_name_to_32_chars,
            error_messages::{
                handle_create_user_error, handle_drop_user_error, handle_list_users_error,
            },
        },
        user_command::read_password_from_stdin_with_double_check,
    },
    core::{
        bootstrap::bootstrap_server_connection_and_drop_privileges,
        protocol::{
            ClientToServerMessageStream, MySQLUser, Request, Response,
            create_client_to_server_message_stream,
        },
    },
    server::sql::user_operations::DatabaseUser,
};
28
29/// Create, delete or change password for the USER(s),
30/// as determined by the COMMAND.
31///
32/// This is a compatibility layer for the mysql-useradm command.
33/// Please consider using the newer mysqladm command instead.
34#[derive(Parser)]
35#[command(
36    bin_name = "mysql-useradm",
37    version,
38    about,
39    disable_help_subcommand = true,
40    verbatim_doc_comment
41)]
42pub struct Args {
43    #[command(subcommand)]
44    pub command: Option<Command>,
45
46    /// Path to the socket of the server, if it already exists.
47    #[arg(
48        short,
49        long,
50        value_name = "PATH",
51        global = true,
52        hide_short_help = true
53    )]
54    server_socket_path: Option<PathBuf>,
55
56    /// Config file to use for the server.
57    #[arg(
58        short,
59        long,
60        value_name = "PATH",
61        global = true,
62        hide_short_help = true
63    )]
64    config: Option<PathBuf>,
65}
66
67#[derive(Parser)]
68pub enum Command {
69    /// create the USER(s).
70    Create(CreateArgs),
71
72    /// delete the USER(s).
73    Delete(DeleteArgs),
74
75    /// change the MySQL password for the USER(s).
76    Passwd(PasswdArgs),
77
78    /// give information about the USERS(s), or, if
79    /// none are given, all the users you have.
80    Show(ShowArgs),
81}
82
83#[derive(Parser)]
84pub struct CreateArgs {
85    /// The name of the USER(s) to create.
86    #[arg(num_args = 1..)]
87    name: Vec<MySQLUser>,
88}
89
90#[derive(Parser)]
91pub struct DeleteArgs {
92    /// The name of the USER(s) to delete.
93    #[arg(num_args = 1..)]
94    name: Vec<MySQLUser>,
95}
96
97#[derive(Parser)]
98pub struct PasswdArgs {
99    /// The name of the USER(s) to change the password for.
100    #[arg(num_args = 1..)]
101    name: Vec<MySQLUser>,
102}
103
104#[derive(Parser)]
105pub struct ShowArgs {
106    /// The name of the USER(s) to show.
107    #[arg(num_args = 0..)]
108    name: Vec<MySQLUser>,
109}
110
111/// **WARNING:** This function may be run with elevated privileges.
112pub fn main() -> anyhow::Result<()> {
113    let args: Args = Args::parse();
114
115    let command = match args.command {
116        Some(command) => command,
117        None => {
118            println!(
119                "Try `{} --help' for more information.",
120                std::env::args()
121                    .next()
122                    .unwrap_or("mysql-useradm".to_string())
123            );
124            return Ok(());
125        }
126    };
127
128    let server_connection = bootstrap_server_connection_and_drop_privileges(
129        args.server_socket_path,
130        args.config,
131        Default::default(),
132    )?;
133
134    tokio_run_command(command, server_connection)?;
135
136    Ok(())
137}
138
139fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
140    tokio::runtime::Builder::new_current_thread()
141        .enable_all()
142        .build()
143        .unwrap()
144        .block_on(async {
145            let tokio_socket = TokioUnixStream::from_std(server_connection)?;
146            let message_stream = create_client_to_server_message_stream(tokio_socket);
147            match command {
148                Command::Create(args) => create_user(args, message_stream).await,
149                Command::Delete(args) => drop_users(args, message_stream).await,
150                Command::Passwd(args) => passwd_users(args, message_stream).await,
151                Command::Show(args) => show_users(args, message_stream).await,
152            }
153        })
154}
155
156async fn create_user(
157    args: CreateArgs,
158    mut server_connection: ClientToServerMessageStream,
159) -> anyhow::Result<()> {
160    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
161
162    let message = Request::CreateUsers(db_users);
163    server_connection.send(message).await?;
164
165    let result = match server_connection.next().await {
166        Some(Ok(Response::CreateUsers(result))) => result,
167        response => return erroneous_server_response(response),
168    };
169
170    server_connection.send(Request::Exit).await?;
171
172    for (name, result) in result {
173        match result {
174            Ok(()) => println!("User '{}' created.", name),
175            Err(err) => handle_create_user_error(err, &name),
176        }
177    }
178
179    Ok(())
180}
181
182async fn drop_users(
183    args: DeleteArgs,
184    mut server_connection: ClientToServerMessageStream,
185) -> anyhow::Result<()> {
186    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
187
188    let message = Request::DropUsers(db_users);
189    server_connection.send(message).await?;
190
191    let result = match server_connection.next().await {
192        Some(Ok(Response::DropUsers(result))) => result,
193        response => return erroneous_server_response(response),
194    };
195
196    server_connection.send(Request::Exit).await?;
197
198    for (name, result) in result {
199        match result {
200            Ok(()) => println!("User '{}' deleted.", name),
201            Err(err) => handle_drop_user_error(err, &name),
202        }
203    }
204
205    Ok(())
206}
207
208async fn passwd_users(
209    args: PasswdArgs,
210    mut server_connection: ClientToServerMessageStream,
211) -> anyhow::Result<()> {
212    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
213
214    let message = Request::ListUsers(Some(db_users));
215    server_connection.send(message).await?;
216
217    let response = match server_connection.next().await {
218        Some(Ok(Response::ListUsers(result))) => result,
219        response => return erroneous_server_response(response),
220    };
221
222    let argv0 = std::env::args()
223        .next()
224        .unwrap_or("mysql-useradm".to_string());
225
226    let users = response
227        .into_iter()
228        .filter_map(|(name, result)| match result {
229            Ok(user) => Some(user),
230            Err(err) => {
231                handle_list_users_error(err, &name);
232                None
233            }
234        })
235        .collect::<Vec<_>>();
236
237    for user in users {
238        let password = read_password_from_stdin_with_double_check(&user.user)?;
239        let message = Request::PasswdUser(user.user.to_owned(), password);
240        server_connection.send(message).await?;
241        match server_connection.next().await {
242            Some(Ok(Response::PasswdUser(result))) => match result {
243                Ok(()) => println!("Password updated for user '{}'.", &user.user),
244                Err(_) => eprintln!(
245                    "{}: Failed to update password for user '{}'.",
246                    argv0, user.user,
247                ),
248            },
249            response => return erroneous_server_response(response),
250        }
251    }
252
253    server_connection.send(Request::Exit).await?;
254
255    Ok(())
256}
257
258async fn show_users(
259    args: ShowArgs,
260    mut server_connection: ClientToServerMessageStream,
261) -> anyhow::Result<()> {
262    let db_users: Vec<_> = args.name.iter().map(trim_user_name_to_32_chars).collect();
263
264    let message = if db_users.is_empty() {
265        Request::ListUsers(None)
266    } else {
267        Request::ListUsers(Some(db_users))
268    };
269    server_connection.send(message).await?;
270
271    let users: Vec<DatabaseUser> = match server_connection.next().await {
272        Some(Ok(Response::ListAllUsers(result))) => match result {
273            Ok(users) => users,
274            Err(err) => {
275                println!("Failed to list users: {:?}", err);
276                return Ok(());
277            }
278        },
279        Some(Ok(Response::ListUsers(result))) => result
280            .into_iter()
281            .filter_map(|(name, result)| match result {
282                Ok(user) => Some(user),
283                Err(err) => {
284                    handle_list_users_error(err, &name);
285                    None
286                }
287            })
288            .collect(),
289        response => return erroneous_server_response(response),
290    };
291
292    server_connection.send(Request::Exit).await?;
293
294    for user in users {
295        if user.has_password {
296            println!("User '{}': password set.", user.user);
297        } else {
298            println!("User '{}': no password set.", user.user);
299        }
300    }
301
302    Ok(())
303}