1
use std::{fs, path::PathBuf, time::Duration};
2

            
3
use anyhow::{Context, anyhow};
4
use clap::Parser;
5
use serde::{Deserialize, Serialize};
6
use sqlx::{ConnectOptions, MySqlConnection, mysql::MySqlConnectOptions};
7

            
8
use crate::core::common::DEFAULT_CONFIG_PATH;
9

            
10
/// TCP port used when a MySQL host is configured without an explicit port.
pub const DEFAULT_PORT: u16 = 3306;
/// Connection timeout, in seconds, used when none is configured.
pub const DEFAULT_TIMEOUT: u64 = 2;
12

            
13
// NOTE: this might look empty now, and the extra wrapping for the mysql
14
//       config seems unnecessary, but it will be useful later when we
15
//       add more configuration options.
16
#[derive(Debug, Clone, Deserialize, Serialize)]
17
pub struct ServerConfig {
18
    pub mysql: MysqlConfig,
19
}
20

            
21
#[derive(Debug, Clone, Deserialize, Serialize)]
22
#[serde(rename = "mysql")]
23
pub struct MysqlConfig {
24
    pub socket_path: Option<PathBuf>,
25
    pub host: Option<String>,
26
    pub port: Option<u16>,
27
    pub username: Option<String>,
28
    pub password: Option<String>,
29
    pub password_file: Option<PathBuf>,
30
    pub timeout: Option<u64>,
31
}
32

            
33
#[derive(Parser, Debug, Clone)]
34
pub struct ServerConfigArgs {
35
    /// Path to the socket of the MySQL server.
36
    #[arg(long, value_name = "PATH", global = true)]
37
    socket_path: Option<PathBuf>,
38

            
39
    /// Hostname of the MySQL server.
40
    #[arg(
41
        long,
42
        value_name = "HOST",
43
        global = true,
44
        conflicts_with = "socket_path"
45
    )]
46
    mysql_host: Option<String>,
47

            
48
    /// Port of the MySQL server.
49
    #[arg(
50
        long,
51
        value_name = "PORT",
52
        global = true,
53
        conflicts_with = "socket_path"
54
    )]
55
    mysql_port: Option<u16>,
56

            
57
    /// Username to use for the MySQL connection.
58
    #[arg(long, value_name = "USER", global = true)]
59
    mysql_user: Option<String>,
60

            
61
    /// Path to a file containing the MySQL password.
62
    #[arg(long, value_name = "PATH", global = true)]
63
    mysql_password_file: Option<PathBuf>,
64

            
65
    /// Seconds to wait for the MySQL connection to be established.
66
    #[arg(long, value_name = "SECONDS", global = true)]
67
    mysql_connect_timeout: Option<u64>,
68
}
69

            
70
/// Use the arguments and whichever configuration file which might or might not
71
/// be found and default values to determine the configuration for the program.
72
pub fn read_config_from_path_with_arg_overrides(
73
    config_path: Option<PathBuf>,
74
    args: ServerConfigArgs,
75
) -> anyhow::Result<ServerConfig> {
76
    let config = read_config_from_path(config_path)?;
77

            
78
    let mysql = config.mysql;
79

            
80
    let password = if let Some(path) = &args.mysql_password_file {
81
        Some(
82
            fs::read_to_string(path)
83
                .context("Failed to read MySQL password file")
84
                .map(|s| s.trim().to_owned())?,
85
        )
86
    } else if let Some(path) = &mysql.password_file {
87
        Some(
88
            fs::read_to_string(path)
89
                .context("Failed to read MySQL password file")
90
                .map(|s| s.trim().to_owned())?,
91
        )
92
    } else {
93
        mysql.password.to_owned()
94
    };
95

            
96
    Ok(ServerConfig {
97
        mysql: MysqlConfig {
98
            socket_path: args.socket_path.or(mysql.socket_path),
99
            host: args.mysql_host.or(mysql.host),
100
            port: args.mysql_port.or(mysql.port),
101
            username: args.mysql_user.or(mysql.username.to_owned()),
102
            password,
103
            password_file: args.mysql_password_file.or(mysql.password_file),
104
            timeout: args.mysql_connect_timeout.or(mysql.timeout),
105
        },
106
    })
107
}
108

            
109
pub fn read_config_from_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
110
    let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
111

            
112
    log::debug!("Reading config file at {:?}", &config_path);
113

            
114
    fs::read_to_string(&config_path)
115
        .context(format!("Failed to read config file at {:?}", &config_path))
116
        .and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
117
        .context(format!("Failed to parse config file at {:?}", &config_path))
118
}
119

            
120
fn log_config(config: &MysqlConfig) {
121
    let mut display_config = config.to_owned();
122
    display_config.password = display_config
123
        .password
124
        .as_ref()
125
        .map(|_| "<REDACTED>".to_owned());
126
    log::debug!(
127
        "Connecting to MySQL server with parameters: {:#?}",
128
        display_config
129
    );
130
}
131

            
132
/// Use the provided configuration to establish a connection to a MySQL server.
133
pub async fn create_mysql_connection_from_config(
134
    config: &MysqlConfig,
135
) -> anyhow::Result<MySqlConnection> {
136
    log_config(config);
137

            
138
    let mut mysql_options = MySqlConnectOptions::new()
139
        .database("mysql")
140
        .log_statements(log::LevelFilter::Trace);
141

            
142
    if let Some(username) = &config.username {
143
        mysql_options = mysql_options.username(username);
144
    }
145

            
146
    if let Some(password) = &config.password {
147
        mysql_options = mysql_options.password(password);
148
    }
149

            
150
    if let Some(socket_path) = &config.socket_path {
151
        mysql_options = mysql_options.socket(socket_path);
152
    } else if let Some(host) = &config.host {
153
        mysql_options = mysql_options.host(host);
154
        mysql_options = mysql_options.port(config.port.unwrap_or(DEFAULT_PORT));
155
    } else {
156
        anyhow::bail!("No MySQL host or socket path provided");
157
    }
158

            
159
    match tokio::time::timeout(
160
        Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)),
161
        mysql_options.connect(),
162
    )
163
    .await
164
    {
165
        Ok(connection) => connection.context("Failed to connect to the database"),
166
        Err(_) => {
167
            Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to the database")
168
        }
169
    }
170
}