mysqladm/cli/mysql_admutils_compatibility/
mysql_useradm.rs1use clap::Parser;
2use futures_util::{SinkExt, StreamExt};
3use std::path::PathBuf;
4
5use std::os::unix::net::UnixStream as StdUnixStream;
6use tokio::net::UnixStream as TokioUnixStream;
7
8use crate::{
9 cli::{
10 common::erroneous_server_response,
11 mysql_admutils_compatibility::{
12 common::trim_user_name_to_32_chars,
13 error_messages::{
14 handle_create_user_error, handle_drop_user_error, handle_list_users_error,
15 },
16 },
17 user_command::read_password_from_stdin_with_double_check,
18 },
19 core::{
20 bootstrap::bootstrap_server_connection_and_drop_privileges,
21 protocol::{
22 ClientToServerMessageStream, MySQLUser, Request, Response,
23 create_client_to_server_message_stream,
24 },
25 },
26 server::sql::user_operations::DatabaseUser,
27};
28
29#[derive(Parser)]
35#[command(
36 bin_name = "mysql-useradm",
37 version,
38 about,
39 disable_help_subcommand = true,
40 verbatim_doc_comment
41)]
42pub struct Args {
43 #[command(subcommand)]
44 pub command: Option<Command>,
45
46 #[arg(
48 short,
49 long,
50 value_name = "PATH",
51 global = true,
52 hide_short_help = true
53 )]
54 server_socket_path: Option<PathBuf>,
55
56 #[arg(
58 short,
59 long,
60 value_name = "PATH",
61 global = true,
62 hide_short_help = true
63 )]
64 config: Option<PathBuf>,
65}
66
67#[derive(Parser)]
68pub enum Command {
69 Create(CreateArgs),
71
72 Delete(DeleteArgs),
74
75 Passwd(PasswdArgs),
77
78 Show(ShowArgs),
81}
82
83#[derive(Parser)]
84pub struct CreateArgs {
85 #[arg(num_args = 1..)]
87 name: Vec<MySQLUser>,
88}
89
90#[derive(Parser)]
91pub struct DeleteArgs {
92 #[arg(num_args = 1..)]
94 name: Vec<MySQLUser>,
95}
96
97#[derive(Parser)]
98pub struct PasswdArgs {
99 #[arg(num_args = 1..)]
101 name: Vec<MySQLUser>,
102}
103
104#[derive(Parser)]
105pub struct ShowArgs {
106 #[arg(num_args = 0..)]
108 name: Vec<MySQLUser>,
109}
110
111pub fn main() -> anyhow::Result<()> {
113 let args: Args = Args::parse();
114
115 let command = match args.command {
116 Some(command) => command,
117 None => {
118 println!(
119 "Try `{} --help' for more information.",
120 std::env::args()
121 .next()
122 .unwrap_or("mysql-useradm".to_string())
123 );
124 return Ok(());
125 }
126 };
127
128 let server_connection = bootstrap_server_connection_and_drop_privileges(
129 args.server_socket_path,
130 args.config,
131 Default::default(),
132 )?;
133
134 tokio_run_command(command, server_connection)?;
135
136 Ok(())
137}
138
139fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
140 tokio::runtime::Builder::new_current_thread()
141 .enable_all()
142 .build()
143 .unwrap()
144 .block_on(async {
145 let tokio_socket = TokioUnixStream::from_std(server_connection)?;
146 let message_stream = create_client_to_server_message_stream(tokio_socket);
147 match command {
148 Command::Create(args) => create_user(args, message_stream).await,
149 Command::Delete(args) => drop_users(args, message_stream).await,
150 Command::Passwd(args) => passwd_users(args, message_stream).await,
151 Command::Show(args) => show_users(args, message_stream).await,
152 }
153 })
154}
155
156async fn create_user(
157 args: CreateArgs,
158 mut server_connection: ClientToServerMessageStream,
159) -> anyhow::Result<()> {
160 let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
161
162 let message = Request::CreateUsers(db_users);
163 server_connection.send(message).await?;
164
165 let result = match server_connection.next().await {
166 Some(Ok(Response::CreateUsers(result))) => result,
167 response => return erroneous_server_response(response),
168 };
169
170 server_connection.send(Request::Exit).await?;
171
172 for (name, result) in result {
173 match result {
174 Ok(()) => println!("User '{}' created.", name),
175 Err(err) => handle_create_user_error(err, &name),
176 }
177 }
178
179 Ok(())
180}
181
182async fn drop_users(
183 args: DeleteArgs,
184 mut server_connection: ClientToServerMessageStream,
185) -> anyhow::Result<()> {
186 let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
187
188 let message = Request::DropUsers(db_users);
189 server_connection.send(message).await?;
190
191 let result = match server_connection.next().await {
192 Some(Ok(Response::DropUsers(result))) => result,
193 response => return erroneous_server_response(response),
194 };
195
196 server_connection.send(Request::Exit).await?;
197
198 for (name, result) in result {
199 match result {
200 Ok(()) => println!("User '{}' deleted.", name),
201 Err(err) => handle_drop_user_error(err, &name),
202 }
203 }
204
205 Ok(())
206}
207
208async fn passwd_users(
209 args: PasswdArgs,
210 mut server_connection: ClientToServerMessageStream,
211) -> anyhow::Result<()> {
212 let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
213
214 let message = Request::ListUsers(Some(db_users));
215 server_connection.send(message).await?;
216
217 let response = match server_connection.next().await {
218 Some(Ok(Response::ListUsers(result))) => result,
219 response => return erroneous_server_response(response),
220 };
221
222 let argv0 = std::env::args()
223 .next()
224 .unwrap_or("mysql-useradm".to_string());
225
226 let users = response
227 .into_iter()
228 .filter_map(|(name, result)| match result {
229 Ok(user) => Some(user),
230 Err(err) => {
231 handle_list_users_error(err, &name);
232 None
233 }
234 })
235 .collect::<Vec<_>>();
236
237 for user in users {
238 let password = read_password_from_stdin_with_double_check(&user.user)?;
239 let message = Request::PasswdUser(user.user.to_owned(), password);
240 server_connection.send(message).await?;
241 match server_connection.next().await {
242 Some(Ok(Response::PasswdUser(result))) => match result {
243 Ok(()) => println!("Password updated for user '{}'.", &user.user),
244 Err(_) => eprintln!(
245 "{}: Failed to update password for user '{}'.",
246 argv0, user.user,
247 ),
248 },
249 response => return erroneous_server_response(response),
250 }
251 }
252
253 server_connection.send(Request::Exit).await?;
254
255 Ok(())
256}
257
258async fn show_users(
259 args: ShowArgs,
260 mut server_connection: ClientToServerMessageStream,
261) -> anyhow::Result<()> {
262 let db_users: Vec<_> = args.name.iter().map(trim_user_name_to_32_chars).collect();
263
264 let message = if db_users.is_empty() {
265 Request::ListUsers(None)
266 } else {
267 Request::ListUsers(Some(db_users))
268 };
269 server_connection.send(message).await?;
270
271 let users: Vec<DatabaseUser> = match server_connection.next().await {
272 Some(Ok(Response::ListAllUsers(result))) => match result {
273 Ok(users) => users,
274 Err(err) => {
275 println!("Failed to list users: {:?}", err);
276 return Ok(());
277 }
278 },
279 Some(Ok(Response::ListUsers(result))) => result
280 .into_iter()
281 .filter_map(|(name, result)| match result {
282 Ok(user) => Some(user),
283 Err(err) => {
284 handle_list_users_error(err, &name);
285 None
286 }
287 })
288 .collect(),
289 response => return erroneous_server_response(response),
290 };
291
292 server_connection.send(Request::Exit).await?;
293
294 for user in users {
295 if user.has_password {
296 println!("User '{}': password set.", user.user);
297 } else {
298 println!("User '{}': no password set.", user.user);
299 }
300 }
301
302 Ok(())
303}