1
use clap_complete::CompletionCandidate;
2
use clap_verbosity_flag::Verbosity;
3
use futures_util::SinkExt;
4
use tokio::net::UnixStream as TokioUnixStream;
5
use tokio_stream::StreamExt;
6

            
7
use crate::{
8
    client::commands::erroneous_server_response,
9
    core::{
10
        bootstrap::bootstrap_server_connection_and_drop_privileges,
11
        protocol::{Request, Response, create_client_to_server_message_stream},
12
    },
13
};
14

            
15
#[must_use]
16
pub fn prefix_completer(current: &std::ffi::OsStr) -> Vec<CompletionCandidate> {
17
    match tokio::runtime::Builder::new_current_thread()
18
        .enable_all()
19
        .build()
20
    {
21
        Ok(runtime) => match runtime.block_on(prefix_completer_(current)) {
22
            Ok(completions) => completions,
23
            Err(err) => {
24
                eprintln!("Error getting prefix completions: {err}");
25
                Vec::new()
26
            }
27
        },
28
        Err(err) => {
29
            eprintln!("Error starting Tokio runtime: {err}");
30
            Vec::new()
31
        }
32
    }
33
}
34

            
35
/// Connect to the server to get `MySQL` user completions.
36
async fn prefix_completer_(_current: &std::ffi::OsStr) -> anyhow::Result<Vec<CompletionCandidate>> {
37
    let server_connection =
38
        bootstrap_server_connection_and_drop_privileges(None, None, Verbosity::new(0, 1))?;
39

            
40
    let tokio_socket = TokioUnixStream::from_std(server_connection)?;
41
    let mut server_connection = create_client_to_server_message_stream(tokio_socket);
42

            
43
    while let Some(Ok(message)) = server_connection.next().await {
44
        match message {
45
            Response::Error(err) => {
46
                anyhow::bail!("{err}");
47
            }
48
            Response::Ready => break,
49
            message => {
50
                eprintln!("Unexpected message from server: {message:?}");
51
            }
52
        }
53
    }
54

            
55
    let message = Request::ListValidNamePrefixes;
56

            
57
    if let Err(err) = server_connection.send(message).await {
58
        server_connection.close().await.ok();
59
        anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server"));
60
    }
61

            
62
    let result = match server_connection.next().await {
63
        Some(Ok(Response::ListValidNamePrefixes(prefixes))) => prefixes,
64
        response => return erroneous_server_response(response).map(|()| vec![]),
65
    };
66

            
67
    server_connection.send(Request::Exit).await?;
68

            
69
    let result = result
70
        .into_iter()
71
        .map(|prefix| prefix + "_")
72
        .map(CompletionCandidate::new)
73
        .collect();
74

            
75
    Ok(result)
76
}