// mysqladm/client/mysql_admutils_compatibility/mysql_useradm.rs
use anyhow::Context;
use clap::Parser;
use futures_util::{SinkExt, StreamExt};
use std::path::PathBuf;

use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream;

use crate::{
    client::{
        commands::{erroneous_server_response, read_password_from_stdin_with_double_check},
        mysql_admutils_compatibility::{
            common::trim_user_name_to_32_chars,
            error_messages::{
                handle_create_user_error, handle_drop_user_error, handle_list_users_error,
            },
        },
    },
    core::{
        bootstrap::bootstrap_server_connection_and_drop_privileges,
        protocol::{
            ClientToServerMessageStream, Request, Response, create_client_to_server_message_stream,
        },
        types::MySQLUser,
    },
    server::sql::user_operations::DatabaseUser,
};

28#[derive(Parser)]
34#[command(
35 bin_name = "mysql-useradm",
36 version,
37 about,
38 disable_help_subcommand = true,
39 verbatim_doc_comment
40)]
41pub struct Args {
42 #[command(subcommand)]
43 pub command: Option<Command>,
44
45 #[arg(
47 short,
48 long,
49 value_name = "PATH",
50 global = true,
51 hide_short_help = true
52 )]
53 server_socket_path: Option<PathBuf>,
54
55 #[arg(
57 short,
58 long,
59 value_name = "PATH",
60 global = true,
61 hide_short_help = true
62 )]
63 config: Option<PathBuf>,
64}
65
66#[derive(Parser)]
67pub enum Command {
68 Create(CreateArgs),
70
71 Delete(DeleteArgs),
73
74 Passwd(PasswdArgs),
76
77 Show(ShowArgs),
80}
81
82#[derive(Parser)]
83pub struct CreateArgs {
84 #[arg(num_args = 1..)]
86 name: Vec<MySQLUser>,
87}
88
89#[derive(Parser)]
90pub struct DeleteArgs {
91 #[arg(num_args = 1..)]
93 name: Vec<MySQLUser>,
94}
95
96#[derive(Parser)]
97pub struct PasswdArgs {
98 #[arg(num_args = 1..)]
100 name: Vec<MySQLUser>,
101}
102
103#[derive(Parser)]
104pub struct ShowArgs {
105 #[arg(num_args = 0..)]
107 name: Vec<MySQLUser>,
108}
109
110pub fn main() -> anyhow::Result<()> {
112 let args: Args = Args::parse();
113
114 let command = match args.command {
115 Some(command) => command,
116 None => {
117 println!(
118 "Try `{} --help' for more information.",
119 std::env::args()
120 .next()
121 .unwrap_or("mysql-useradm".to_string())
122 );
123 return Ok(());
124 }
125 };
126
127 let server_connection = bootstrap_server_connection_and_drop_privileges(
128 args.server_socket_path,
129 args.config,
130 Default::default(),
131 )?;
132
133 tokio_run_command(command, server_connection)?;
134
135 Ok(())
136}
137
138fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
139 tokio::runtime::Builder::new_current_thread()
140 .enable_all()
141 .build()
142 .unwrap()
143 .block_on(async {
144 let tokio_socket = TokioUnixStream::from_std(server_connection)?;
145 let message_stream = create_client_to_server_message_stream(tokio_socket);
146 match command {
147 Command::Create(args) => create_user(args, message_stream).await,
148 Command::Delete(args) => drop_users(args, message_stream).await,
149 Command::Passwd(args) => passwd_users(args, message_stream).await,
150 Command::Show(args) => show_users(args, message_stream).await,
151 }
152 })
153}
154
155async fn create_user(
156 args: CreateArgs,
157 mut server_connection: ClientToServerMessageStream,
158) -> anyhow::Result<()> {
159 let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
160
161 let message = Request::CreateUsers(db_users);
162 server_connection.send(message).await?;
163
164 let result = match server_connection.next().await {
165 Some(Ok(Response::CreateUsers(result))) => result,
166 response => return erroneous_server_response(response),
167 };
168
169 server_connection.send(Request::Exit).await?;
170
171 for (name, result) in result {
172 match result {
173 Ok(()) => println!("User '{}' created.", name),
174 Err(err) => handle_create_user_error(err, &name),
175 }
176 }
177
178 Ok(())
179}
180
181async fn drop_users(
182 args: DeleteArgs,
183 mut server_connection: ClientToServerMessageStream,
184) -> anyhow::Result<()> {
185 let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
186
187 let message = Request::DropUsers(db_users);
188 server_connection.send(message).await?;
189
190 let result = match server_connection.next().await {
191 Some(Ok(Response::DropUsers(result))) => result,
192 response => return erroneous_server_response(response),
193 };
194
195 server_connection.send(Request::Exit).await?;
196
197 for (name, result) in result {
198 match result {
199 Ok(()) => println!("User '{}' deleted.", name),
200 Err(err) => handle_drop_user_error(err, &name),
201 }
202 }
203
204 Ok(())
205}
206
207async fn passwd_users(
208 args: PasswdArgs,
209 mut server_connection: ClientToServerMessageStream,
210) -> anyhow::Result<()> {
211 let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
212
213 let message = Request::ListUsers(Some(db_users));
214 server_connection.send(message).await?;
215
216 let response = match server_connection.next().await {
217 Some(Ok(Response::ListUsers(result))) => result,
218 response => return erroneous_server_response(response),
219 };
220
221 let argv0 = std::env::args()
222 .next()
223 .unwrap_or("mysql-useradm".to_string());
224
225 let users = response
226 .into_iter()
227 .filter_map(|(name, result)| match result {
228 Ok(user) => Some(user),
229 Err(err) => {
230 handle_list_users_error(err, &name);
231 None
232 }
233 })
234 .collect::<Vec<_>>();
235
236 for user in users {
237 let password = read_password_from_stdin_with_double_check(&user.user)?;
238 let message = Request::PasswdUser((user.user.to_owned(), password));
239 server_connection.send(message).await?;
240 match server_connection.next().await {
241 Some(Ok(Response::SetUserPassword(result))) => match result {
242 Ok(()) => println!("Password updated for user '{}'.", &user.user),
243 Err(_) => eprintln!(
244 "{}: Failed to update password for user '{}'.",
245 argv0, user.user,
246 ),
247 },
248 response => return erroneous_server_response(response),
249 }
250 }
251
252 server_connection.send(Request::Exit).await?;
253
254 Ok(())
255}
256
257async fn show_users(
258 args: ShowArgs,
259 mut server_connection: ClientToServerMessageStream,
260) -> anyhow::Result<()> {
261 let db_users: Vec<_> = args.name.iter().map(trim_user_name_to_32_chars).collect();
262
263 let message = if db_users.is_empty() {
264 Request::ListUsers(None)
265 } else {
266 Request::ListUsers(Some(db_users))
267 };
268 server_connection.send(message).await?;
269
270 let users: Vec<DatabaseUser> = match server_connection.next().await {
271 Some(Ok(Response::ListAllUsers(result))) => match result {
272 Ok(users) => users,
273 Err(err) => {
274 println!("Failed to list users: {:?}", err);
275 return Ok(());
276 }
277 },
278 Some(Ok(Response::ListUsers(result))) => result
279 .into_iter()
280 .filter_map(|(name, result)| match result {
281 Ok(user) => Some(user),
282 Err(err) => {
283 handle_list_users_error(err, &name);
284 None
285 }
286 })
287 .collect(),
288 response => return erroneous_server_response(response),
289 };
290
291 server_connection.send(Request::Exit).await?;
292
293 for user in users {
294 if user.has_password {
295 println!("User '{}': password set.", user.user);
296 } else {
297 println!("User '{}': no password set.", user.user);
298 }
299 }
300
301 Ok(())
302}