1
use std::{fs, path::PathBuf};
2

            
3
use anyhow::Context;
4
use clap::Parser;
5
use serde::{Deserialize, Serialize};
6
use sqlx::{ConnectOptions, mysql::MySqlConnectOptions};
7

            
8
use crate::core::common::DEFAULT_CONFIG_PATH;
9

            
10
/// MySQL port used when neither the config file nor the command line sets one.
pub const DEFAULT_PORT: u16 = 3306;

/// Serde default for `MysqlConfig::port`; always yields [`DEFAULT_PORT`].
const fn default_mysql_port() -> u16 {
    DEFAULT_PORT
}
14

            
15
/// Connection timeout (seconds) used when neither the config file nor the
/// command line sets one.
pub const DEFAULT_TIMEOUT: u64 = 2;

/// Serde default for `MysqlConfig::timeout`; always yields [`DEFAULT_TIMEOUT`].
const fn default_mysql_timeout() -> u64 {
    DEFAULT_TIMEOUT
}
19

            
20
#[derive(Debug, Clone, Deserialize, Serialize)]
21
pub struct ServerConfig {
22
    pub socket_path: Option<PathBuf>,
23
    pub mysql: MysqlConfig,
24
}
25

            
26
#[derive(Debug, Clone, Deserialize, Serialize)]
27
#[serde(rename = "mysql")]
28
pub struct MysqlConfig {
29
    pub socket_path: Option<PathBuf>,
30
    pub host: Option<String>,
31
    #[serde(default = "default_mysql_port")]
32
    pub port: u16,
33
    pub username: Option<String>,
34
    pub password: Option<String>,
35
    pub password_file: Option<PathBuf>,
36
    #[serde(default = "default_mysql_timeout")]
37
    pub timeout: u64,
38
}
39

            
40
impl MysqlConfig {
41
    pub fn as_mysql_connect_options(&self) -> anyhow::Result<MySqlConnectOptions> {
42
        let mut options = MySqlConnectOptions::new()
43
            .database("mysql")
44
            .log_statements(log::LevelFilter::Trace);
45

            
46
        if let Some(username) = &self.username {
47
            options = options.username(username);
48
        }
49

            
50
        if let Some(password) = &self.password {
51
            options = options.password(password);
52
        }
53

            
54
        if let Some(socket_path) = &self.socket_path {
55
            options = options.socket(socket_path);
56
        } else if let Some(host) = &self.host {
57
            options = options.host(host);
58
            options = options.port(self.port);
59
        } else {
60
            anyhow::bail!("No MySQL host or socket path provided");
61
        }
62

            
63
        Ok(options)
64
    }
65

            
66
    pub fn log_connection_notice(&self) {
67
        let mut display_config = self.to_owned();
68
        display_config.password = display_config
69
            .password
70
            .as_ref()
71
            .map(|_| "<REDACTED>".to_owned());
72
        log::debug!(
73
            "Connecting to MySQL server with parameters: {:#?}",
74
            display_config
75
        );
76
    }
77
}
78

            
79
#[derive(Parser, Debug, Clone)]
80
pub struct ServerConfigArgs {
81
    /// Path where the server socket should be created.
82
    #[arg(long, value_name = "PATH", global = true)]
83
    socket_path: Option<PathBuf>,
84

            
85
    /// Path to the socket of the MySQL server.
86
    #[arg(long, value_name = "PATH", global = true)]
87
    mysql_socket_path: Option<PathBuf>,
88

            
89
    /// Hostname of the MySQL server.
90
    #[arg(
91
        long,
92
        value_name = "HOST",
93
        global = true,
94
        conflicts_with = "socket_path"
95
    )]
96
    mysql_host: Option<String>,
97

            
98
    /// Port of the MySQL server.
99
    #[arg(
100
        long,
101
        value_name = "PORT",
102
        global = true,
103
        conflicts_with = "socket_path"
104
    )]
105
    mysql_port: Option<u16>,
106

            
107
    /// Username to use for the MySQL connection.
108
    #[arg(long, value_name = "USER", global = true)]
109
    mysql_user: Option<String>,
110

            
111
    /// Path to a file containing the MySQL password.
112
    #[arg(long, value_name = "PATH", global = true)]
113
    mysql_password_file: Option<PathBuf>,
114

            
115
    /// Seconds to wait for the MySQL connection to be established.
116
    #[arg(long, value_name = "SECONDS", global = true)]
117
    mysql_connect_timeout: Option<u64>,
118
}
119

            
120
/// Use the arguments and whichever configuration file which might or might not
121
/// be found and default values to determine the configuration for the program.
122
pub fn read_config_from_path_with_arg_overrides(
123
    config_path: Option<PathBuf>,
124
    args: ServerConfigArgs,
125
) -> anyhow::Result<ServerConfig> {
126
    let config = read_config_from_path(config_path)?;
127

            
128
    let mysql = config.mysql;
129

            
130
    let password = if let Some(path) = &args.mysql_password_file {
131
        Some(
132
            fs::read_to_string(path)
133
                .context("Failed to read MySQL password file")
134
                .map(|s| s.trim().to_owned())?,
135
        )
136
    } else if let Some(path) = &mysql.password_file {
137
        Some(
138
            fs::read_to_string(path)
139
                .context("Failed to read MySQL password file")
140
                .map(|s| s.trim().to_owned())?,
141
        )
142
    } else {
143
        mysql.password.to_owned()
144
    };
145

            
146
    Ok(ServerConfig {
147
        socket_path: args.socket_path.or(config.socket_path),
148
        mysql: MysqlConfig {
149
            socket_path: args.mysql_socket_path.or(mysql.socket_path),
150
            host: args.mysql_host.or(mysql.host),
151
            port: args.mysql_port.unwrap_or(mysql.port),
152
            username: args.mysql_user.or(mysql.username.to_owned()),
153
            password,
154
            password_file: args.mysql_password_file.or(mysql.password_file),
155
            timeout: args.mysql_connect_timeout.unwrap_or(mysql.timeout),
156
        },
157
    })
158
}
159

            
160
pub fn read_config_from_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
161
    let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
162

            
163
    log::debug!("Reading config file at {:?}", &config_path);
164

            
165
    fs::read_to_string(&config_path)
166
        .context(format!("Failed to read config file at {:?}", &config_path))
167
        .and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
168
        .context(format!("Failed to parse config file at {:?}", &config_path))
169
}