1
use std::{
2
    fs,
3
    path::{Path, PathBuf},
4
};
5

            
6
use anyhow::Context;
7
use serde::{Deserialize, Serialize};
8
use sqlx::{ConnectOptions, mysql::MySqlConnectOptions};
9

            
10
/// TCP port used for the MySQL connection when the configuration omits `port`.
pub const DEFAULT_PORT: u16 = 3_306;

/// Serde default provider for `MysqlConfig::port`; always yields [`DEFAULT_PORT`].
fn default_mysql_port() -> u16 {
    DEFAULT_PORT
}
14

            
15
/// Timeout value applied when the configuration omits `timeout`.
/// NOTE(review): units are not visible in this file — presumably seconds, confirm at use site.
pub const DEFAULT_TIMEOUT: u64 = 2;

/// Serde default provider for `MysqlConfig::timeout`; always yields [`DEFAULT_TIMEOUT`].
fn default_mysql_timeout() -> u64 {
    DEFAULT_TIMEOUT
}
19

            
20
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
21
#[serde(rename = "mysql")]
22
pub struct MysqlConfig {
23
    pub socket_path: Option<PathBuf>,
24
    pub host: Option<String>,
25
    #[serde(default = "default_mysql_port")]
26
    pub port: u16,
27
    pub username: Option<String>,
28
    pub password: Option<String>,
29
    pub password_file: Option<PathBuf>,
30
    #[serde(default = "default_mysql_timeout")]
31
    pub timeout: u64,
32
}
33

            
34
impl MysqlConfig {
35
    pub fn as_mysql_connect_options(&self) -> anyhow::Result<MySqlConnectOptions> {
36
        let mut options = MySqlConnectOptions::new()
37
            .database("mysql")
38
            .log_statements(tracing::log::LevelFilter::Trace);
39

            
40
        if let Some(username) = &self.username {
41
            options = options.username(username);
42
        }
43

            
44
        if let Some(password_file) = &self.password_file {
45
            let password = fs::read_to_string(password_file)
46
                .with_context(|| {
47
                    format!("Failed to read MySQL password file at {password_file:?}")
48
                })?
49
                .trim()
50
                .to_owned();
51
            options = options.password(&password);
52
        } else if let Some(password) = &self.password {
53
            options = options.password(password);
54
        }
55

            
56
        if let Some(socket_path) = &self.socket_path {
57
            options = options.socket(socket_path);
58
        } else if let Some(host) = &self.host {
59
            options = options.host(host);
60
            options = options.port(self.port);
61
        } else {
62
            anyhow::bail!("No MySQL host or socket path provided");
63
        }
64

            
65
        Ok(options)
66
    }
67

            
68
    pub fn log_connection_notice(&self) {
69
        let mut display_config = self.to_owned();
70
        display_config.password = display_config
71
            .password
72
            .as_ref()
73
            .map(|_| "<REDACTED>".to_owned());
74
        tracing::debug!(
75
            "Connecting to MySQL server with parameters: {:#?}",
76
            display_config
77
        );
78
    }
79
}
80

            
81
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
82
pub struct AuthorizationConfig {
83
    pub group_denylist_file: Option<PathBuf>,
84
}
85

            
86
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
87
pub struct ServerConfig {
88
    pub socket_path: Option<PathBuf>,
89
    pub authorization: AuthorizationConfig,
90
    pub mysql: MysqlConfig,
91
}
92

            
93
impl ServerConfig {
94
    /// Reads the server configuration from the specified path, or the default path if none is provided.
95
    pub fn read_config_from_path(config_path: &Path) -> anyhow::Result<Self> {
96
        tracing::debug!("Reading config file at {:?}", config_path);
97

            
98
        fs::read_to_string(config_path)
99
            .context(format!("Failed to read config file at {config_path:?}"))
100
            .and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
101
            .context(format!("Failed to parse config file at {config_path:?}"))
102
    }
103
}