1
use std::{
2
    fs,
3
    os::{fd::FromRawFd, unix::net::UnixListener as StdUnixListener},
4
    path::PathBuf,
5
    sync::Arc,
6
    time::Duration,
7
};
8

            
9
use anyhow::{Context, anyhow};
10
use futures_util::SinkExt;
11
use indoc::concatdoc;
12
use sqlx::MySqlPool;
13
use tokio::{net::UnixListener as TokioUnixListener, task::JoinHandle, time::interval};
14
use tokio_util::task::TaskTracker;
15
// use tokio_util::sync::CancellationToken;
16

            
17
use crate::{
18
    core::{
19
        common::UnixUser,
20
        protocol::{Response, create_server_to_client_message_stream},
21
    },
22
    server::{
23
        config::{MysqlConfig, ServerConfig},
24
        session_handler::session_handler,
25
    },
26
};
27

            
28
// TODO: implement graceful shutdown and graceful restarts
29
#[allow(dead_code)]
30
pub struct Supervisor {
31
    config: ServerConfig,
32
    systemd_mode: bool,
33

            
34
    // sighup_cancel_token: CancellationToken,
35
    // sigterm_cancel_token: CancellationToken,
36
    // signal_handler_task: JoinHandle<()>,
37
    db_connection_pool: MySqlPool,
38
    // listener: TokioUnixListener,
39
    listener_task: JoinHandle<anyhow::Result<()>>,
40
    handler_task_tracker: TaskTracker,
41

            
42
    watchdog_timeout: Option<Duration>,
43
    systemd_watchdog_task: Option<JoinHandle<()>>,
44

            
45
    connection_counter: std::sync::Arc<()>,
46
    status_notifier_task: Option<JoinHandle<()>>,
47
}
48

            
49
impl Supervisor {
50
    pub async fn new(config: ServerConfig, systemd_mode: bool) -> anyhow::Result<Self> {
51
        let mut watchdog_duration = None;
52
        let mut watchdog_micro_seconds = 0;
53
        let watchdog_task =
54
            if systemd_mode && sd_notify::watchdog_enabled(true, &mut watchdog_micro_seconds) {
55
                watchdog_duration = Some(Duration::from_micros(watchdog_micro_seconds));
56
                log::debug!(
57
                    "Systemd watchdog enabled with {} millisecond interval",
58
                    watchdog_micro_seconds.div_ceil(1000),
59
                );
60
                Some(spawn_watchdog_task(watchdog_duration.unwrap()))
61
            } else {
62
                log::debug!("Systemd watchdog not enabled, skipping watchdog thread");
63
                None
64
            };
65

            
66
        let db_connection_pool = create_db_connection_pool(&config.mysql).await?;
67

            
68
        let connection_counter = Arc::new(());
69
        let status_notifier_task = if systemd_mode {
70
            Some(spawn_status_notifier_task(connection_counter.clone()))
71
        } else {
72
            None
73
        };
74

            
75
        // TODO: try to detech systemd socket before using the provided socket path
76
        let listener = match config.socket_path {
77
            Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
78
            None => create_unix_listener_with_systemd_socket().await?,
79
        };
80

            
81
        let listener_task = {
82
            let connection_counter = connection_counter.clone();
83
            tokio::spawn(spawn_listener_task(
84
                listener,
85
                connection_counter,
86
                db_connection_pool.clone(),
87
            ))
88
        };
89

            
90
        // let sighup_cancel_token = CancellationToken::new();
91
        // let sigterm_cancel_token = CancellationToken::new();
92

            
93
        Ok(Self {
94
            config,
95
            systemd_mode,
96
            // sighup_cancel_token,
97
            // sigterm_cancel_token,
98
            // signal_handler_task,
99
            db_connection_pool,
100
            // listener,
101
            listener_task,
102
            handler_task_tracker: TaskTracker::new(),
103
            watchdog_timeout: watchdog_duration,
104
            systemd_watchdog_task: watchdog_task,
105
            connection_counter,
106
            status_notifier_task,
107
        })
108
    }
109

            
110
    pub async fn run(self) -> anyhow::Result<()> {
111
        self.listener_task.await?
112
    }
113
}
114

            
115
fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
116
    tokio::spawn(async move {
117
        let mut interval = interval(duration.div_f32(2.0));
118
        log::debug!(
119
            "Starting systemd watchdog task, pinging every {} milliseconds",
120
            duration.div_f32(2.0).as_millis()
121
        );
122
        loop {
123
            interval.tick().await;
124
            if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
125
                log::warn!("Failed to notify systemd watchdog: {}", err);
126
            } else {
127
                log::trace!("Ping sent to systemd watchdog");
128
            }
129
        }
130
    })
131
}
132

            
133
fn spawn_status_notifier_task(connection_counter: std::sync::Arc<()>) -> JoinHandle<()> {
134
    const NON_CONNECTION_ARC_COUNT: usize = 4;
135
    const STATUS_UPDATE_INTERVAL_SECS: Duration = Duration::from_secs(1);
136

            
137
    tokio::spawn(async move {
138
        let mut interval = interval(STATUS_UPDATE_INTERVAL_SECS);
139
        loop {
140
            interval.tick().await;
141
            log::trace!("Updating systemd status notification");
142
            let count = Arc::strong_count(&connection_counter) - NON_CONNECTION_ARC_COUNT;
143
            let message = if count > 0 {
144
                format!("Handling {} connections", count)
145
            } else {
146
                "Waiting for connections".to_string()
147
            };
148
            sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())]).ok();
149
        }
150
    })
151
}
152

            
153
async fn create_unix_listener_with_socket_path(
154
    socket_path: PathBuf,
155
) -> anyhow::Result<TokioUnixListener> {
156
    let parent_directory = socket_path.parent().unwrap();
157
    if !parent_directory.exists() {
158
        log::debug!("Creating directory {:?}", parent_directory);
159
        fs::create_dir_all(parent_directory)?;
160
    }
161

            
162
    log::info!("Listening on socket {:?}", socket_path);
163

            
164
    match fs::remove_file(socket_path.as_path()) {
165
        Ok(_) => {}
166
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
167
        Err(e) => return Err(e.into()),
168
    }
169

            
170
    let listener = TokioUnixListener::bind(socket_path)?;
171

            
172
    Ok(listener)
173
}
174

            
175
async fn create_unix_listener_with_systemd_socket() -> anyhow::Result<TokioUnixListener> {
176
    let fd = sd_notify::listen_fds()
177
        .context("Failed to get file descriptors from systemd")?
178
        .next()
179
        .context("No file descriptors received from systemd")?;
180

            
181
    debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
182

            
183
    log::debug!(
184
        "Received file descriptor from systemd with id: '{}', assuming socket",
185
        fd
186
    );
187

            
188
    let std_unix_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
189
    let listener = TokioUnixListener::from_std(std_unix_listener)?;
190
    Ok(listener)
191
}
192

            
193
async fn create_db_connection_pool(config: &MysqlConfig) -> anyhow::Result<MySqlPool> {
194
    let mysql_config = config.as_mysql_connect_options()?;
195

            
196
    config.log_connection_notice();
197

            
198
    match tokio::time::timeout(
199
        Duration::from_secs(config.timeout),
200
        MySqlPool::connect_with(mysql_config),
201
    )
202
    .await
203
    {
204
        Ok(connection) => connection.context("Failed to connect to the database"),
205
        Err(_) => Err(anyhow!("Timed out after {} seconds", config.timeout))
206
            .context("Failed to connect to the database"),
207
    }
208
}
209

            
210
// fn spawn_signal_handler_task(
211
//     sighup_token: CancellationToken,
212
//     sigterm_token: CancellationToken,
213
// ) -> JoinHandle<()> {
214
//     tokio::spawn(async move {
215
//         let mut sighup_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
216
//             .expect("Failed to set up SIGHUP handler");
217
//         let mut sigterm_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
218
//             .expect("Failed to set up SIGTERM handler");
219

            
220
//         loop {
221
//             tokio::select! {
222
//                 _ = sighup_stream.recv() => {
223
//                     log::info!("Received SIGHUP signal");
224
//                     sighup_token.cancel();
225
//                 }
226
//                 _ = sigterm_stream.recv() => {
227
//                     log::info!("Received SIGTERM signal");
228
//                     sigterm_token.cancel();
229
//                     break;
230
//                 }
231
//             }
232
//         }
233
//     })
234
// }
235

            
236
async fn spawn_listener_task(
237
    listener: TokioUnixListener,
238
    connection_counter: Arc<()>,
239
    db_pool: MySqlPool,
240
) -> anyhow::Result<()> {
241
    sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
242

            
243
    while let Ok((conn, _addr)) = listener.accept().await {
244
        log::debug!("Got new connection");
245

            
246
        let uid = match conn.peer_cred() {
247
            Ok(cred) => cred.uid(),
248
            Err(e) => {
249
                log::error!("Failed to get peer credentials from socket: {}", e);
250
                let mut message_stream = create_server_to_client_message_stream(conn);
251
                message_stream
252
                    .send(Response::Error(
253
                        (concatdoc! {
254
                            "Server failed to get peer credentials from socket\n",
255
                            "Please check the server logs or contact the system administrators"
256
                        })
257
                        .to_string(),
258
                    ))
259
                    .await
260
                    .ok();
261
                continue;
262
            }
263
        };
264

            
265
        log::debug!("Validated peer UID: {}", uid);
266

            
267
        let _connection_counter_guard = Arc::clone(&connection_counter);
268

            
269
        let unix_user = match UnixUser::from_uid(uid) {
270
            Ok(user) => user,
271
            Err(e) => {
272
                log::error!("Failed to get username from uid: {}", e);
273
                let mut message_stream = create_server_to_client_message_stream(conn);
274
                message_stream
275
                    .send(Response::Error(
276
                        (concatdoc! {
277
                            "Server failed to get user data from the system\n",
278
                            "Please check the server logs or contact the system administrators"
279
                        })
280
                        .to_string(),
281
                    ))
282
                    .await
283
                    .ok();
284
                continue;
285
            }
286
        };
287

            
288
        log::info!("Accepted connection from UNIX user: {}", unix_user.username);
289

            
290
        match session_handler(conn, &unix_user, db_pool.clone()).await {
291
            Ok(()) => {}
292
            Err(e) => {
293
                log::error!("Failed to run server: {}", e);
294
            }
295
        }
296
    }
297

            
298
    Ok(())
299
}