mysqladm/client/mysql_admutils_compatibility/mysql_useradm.rs

1use clap::Parser;
2use futures_util::{SinkExt, StreamExt};
3use std::path::PathBuf;
4
5use std::os::unix::net::UnixStream as StdUnixStream;
6use tokio::net::UnixStream as TokioUnixStream;
7
8use crate::{
9    client::{
10        commands::{erroneous_server_response, read_password_from_stdin_with_double_check},
11        mysql_admutils_compatibility::{
12            common::trim_user_name_to_32_chars,
13            error_messages::{
14                handle_create_user_error, handle_drop_user_error, handle_list_users_error,
15            },
16        },
17    },
18    core::{
19        bootstrap::bootstrap_server_connection_and_drop_privileges,
20        protocol::{
21            ClientToServerMessageStream, Request, Response, create_client_to_server_message_stream,
22        },
23        types::MySQLUser,
24    },
25    server::sql::user_operations::DatabaseUser,
26};
27
// Top-level command line of the `mysql-useradm` compatibility binary.
// NOTE: clap derives help output from the `///` doc comments in this struct
// (and `verbatim_doc_comment` tells it not to reflow the struct-level one), so
// editing those comments changes user-visible `--help` text.
/// Create, delete or change password for the USER(s),
/// as determined by the COMMAND.
///
/// This is a compatibility layer for the mysql-useradm command.
/// Please consider using the newer mysqladm command instead.
#[derive(Parser)]
#[command(
    bin_name = "mysql-useradm",
    version,
    about,
    disable_help_subcommand = true,
    verbatim_doc_comment
)]
pub struct Args {
    // `None` when no subcommand was given; `main` then only prints a
    // "Try `... --help'" hint and never contacts the server.
    #[command(subcommand)]
    pub command: Option<Command>,

    // `global = true` lets this option also be written after the subcommand;
    // `hide_short_help = true` hides it from `-h` (it still shows in `--help`).
    /// Path to the socket of the server, if it already exists.
    #[arg(
        short,
        long,
        value_name = "PATH",
        global = true,
        hide_short_help = true
    )]
    server_socket_path: Option<PathBuf>,

    // Same visibility rules as `server_socket_path` above.
    /// Config file to use for the server.
    #[arg(
        short,
        long,
        value_name = "PATH",
        global = true,
        hide_short_help = true
    )]
    config: Option<PathBuf>,
}
65
// Subcommands of `mysql-useradm`. clap derives the subcommand names from the
// variant names, and each variant's `///` comment is its help text.
// NOTE(review): `Subcommand` is the conventional derive for an enum used via
// `#[command(subcommand)]`; `Parser` is what is used here — confirm it is
// intentional before changing it.
#[derive(Parser)]
pub enum Command {
    /// create the USER(s).
    Create(CreateArgs),

    /// delete the USER(s).
    Delete(DeleteArgs),

    /// change the MySQL password for the USER(s).
    Passwd(PasswdArgs),

    // NOTE(review): "USERS(s)" below looks like a typo for "USER(s)"; it may be
    // kept on purpose to mirror the original mysql-useradm help — confirm.
    /// give information about the USERS(s), or, if
    /// none are given, all the users you have.
    Show(ShowArgs),
}
81
82#[derive(Parser)]
83pub struct CreateArgs {
84    /// The name of the USER(s) to create.
85    #[arg(num_args = 1..)]
86    name: Vec<MySQLUser>,
87}
88
89#[derive(Parser)]
90pub struct DeleteArgs {
91    /// The name of the USER(s) to delete.
92    #[arg(num_args = 1..)]
93    name: Vec<MySQLUser>,
94}
95
96#[derive(Parser)]
97pub struct PasswdArgs {
98    /// The name of the USER(s) to change the password for.
99    #[arg(num_args = 1..)]
100    name: Vec<MySQLUser>,
101}
102
// Arguments for `mysql-useradm show`.
#[derive(Parser)]
pub struct ShowArgs {
    // `num_args = 0..`: USERs are optional here. When none are given,
    // `show_users` sends `Request::ListUsers(None)` (no name filter).
    /// The name of the USER(s) to show.
    #[arg(num_args = 0..)]
    name: Vec<MySQLUser>,
}
109
110/// **WARNING:** This function may be run with elevated privileges.
111pub fn main() -> anyhow::Result<()> {
112    let args: Args = Args::parse();
113
114    let command = match args.command {
115        Some(command) => command,
116        None => {
117            println!(
118                "Try `{} --help' for more information.",
119                std::env::args()
120                    .next()
121                    .unwrap_or("mysql-useradm".to_string())
122            );
123            return Ok(());
124        }
125    };
126
127    let server_connection = bootstrap_server_connection_and_drop_privileges(
128        args.server_socket_path,
129        args.config,
130        Default::default(),
131    )?;
132
133    tokio_run_command(command, server_connection)?;
134
135    Ok(())
136}
137
138fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
139    tokio::runtime::Builder::new_current_thread()
140        .enable_all()
141        .build()
142        .unwrap()
143        .block_on(async {
144            let tokio_socket = TokioUnixStream::from_std(server_connection)?;
145            let message_stream = create_client_to_server_message_stream(tokio_socket);
146            match command {
147                Command::Create(args) => create_user(args, message_stream).await,
148                Command::Delete(args) => drop_users(args, message_stream).await,
149                Command::Passwd(args) => passwd_users(args, message_stream).await,
150                Command::Show(args) => show_users(args, message_stream).await,
151            }
152        })
153}
154
155async fn create_user(
156    args: CreateArgs,
157    mut server_connection: ClientToServerMessageStream,
158) -> anyhow::Result<()> {
159    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
160
161    let message = Request::CreateUsers(db_users);
162    server_connection.send(message).await?;
163
164    let result = match server_connection.next().await {
165        Some(Ok(Response::CreateUsers(result))) => result,
166        response => return erroneous_server_response(response),
167    };
168
169    server_connection.send(Request::Exit).await?;
170
171    for (name, result) in result {
172        match result {
173            Ok(()) => println!("User '{}' created.", name),
174            Err(err) => handle_create_user_error(err, &name),
175        }
176    }
177
178    Ok(())
179}
180
181async fn drop_users(
182    args: DeleteArgs,
183    mut server_connection: ClientToServerMessageStream,
184) -> anyhow::Result<()> {
185    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
186
187    let message = Request::DropUsers(db_users);
188    server_connection.send(message).await?;
189
190    let result = match server_connection.next().await {
191        Some(Ok(Response::DropUsers(result))) => result,
192        response => return erroneous_server_response(response),
193    };
194
195    server_connection.send(Request::Exit).await?;
196
197    for (name, result) in result {
198        match result {
199            Ok(()) => println!("User '{}' deleted.", name),
200            Err(err) => handle_drop_user_error(err, &name),
201        }
202    }
203
204    Ok(())
205}
206
207async fn passwd_users(
208    args: PasswdArgs,
209    mut server_connection: ClientToServerMessageStream,
210) -> anyhow::Result<()> {
211    let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
212
213    let message = Request::ListUsers(Some(db_users));
214    server_connection.send(message).await?;
215
216    let response = match server_connection.next().await {
217        Some(Ok(Response::ListUsers(result))) => result,
218        response => return erroneous_server_response(response),
219    };
220
221    let argv0 = std::env::args()
222        .next()
223        .unwrap_or("mysql-useradm".to_string());
224
225    let users = response
226        .into_iter()
227        .filter_map(|(name, result)| match result {
228            Ok(user) => Some(user),
229            Err(err) => {
230                handle_list_users_error(err, &name);
231                None
232            }
233        })
234        .collect::<Vec<_>>();
235
236    for user in users {
237        let password = read_password_from_stdin_with_double_check(&user.user)?;
238        let message = Request::PasswdUser((user.user.to_owned(), password));
239        server_connection.send(message).await?;
240        match server_connection.next().await {
241            Some(Ok(Response::SetUserPassword(result))) => match result {
242                Ok(()) => println!("Password updated for user '{}'.", &user.user),
243                Err(_) => eprintln!(
244                    "{}: Failed to update password for user '{}'.",
245                    argv0, user.user,
246                ),
247            },
248            response => return erroneous_server_response(response),
249        }
250    }
251
252    server_connection.send(Request::Exit).await?;
253
254    Ok(())
255}
256
257async fn show_users(
258    args: ShowArgs,
259    mut server_connection: ClientToServerMessageStream,
260) -> anyhow::Result<()> {
261    let db_users: Vec<_> = args.name.iter().map(trim_user_name_to_32_chars).collect();
262
263    let message = if db_users.is_empty() {
264        Request::ListUsers(None)
265    } else {
266        Request::ListUsers(Some(db_users))
267    };
268    server_connection.send(message).await?;
269
270    let users: Vec<DatabaseUser> = match server_connection.next().await {
271        Some(Ok(Response::ListAllUsers(result))) => match result {
272            Ok(users) => users,
273            Err(err) => {
274                println!("Failed to list users: {:?}", err);
275                return Ok(());
276            }
277        },
278        Some(Ok(Response::ListUsers(result))) => result
279            .into_iter()
280            .filter_map(|(name, result)| match result {
281                Ok(user) => Some(user),
282                Err(err) => {
283                    handle_list_users_error(err, &name);
284                    None
285                }
286            })
287            .collect(),
288        response => return erroneous_server_response(response),
289    };
290
291    server_connection.send(Request::Exit).await?;
292
293    for user in users {
294        if user.has_password {
295            println!("User '{}': password set.", user.user);
296        } else {
297            println!("User '{}': no password set.", user.user);
298        }
299    }
300
301    Ok(())
302}