1
use std::{
2
    fs,
3
    os::{fd::FromRawFd, unix::net::UnixListener as StdUnixListener},
4
    path::PathBuf,
5
    sync::{
6
        Arc,
7
        atomic::{AtomicU64, Ordering},
8
    },
9
    time::Duration,
10
};
11

            
12
use anyhow::{Context, anyhow};
13
use sqlx::MySqlPool;
14
use tokio::{
15
    net::UnixListener as TokioUnixListener,
16
    select,
17
    sync::{Mutex, RwLock, broadcast},
18
    task::JoinHandle,
19
    time::interval,
20
};
21
use tokio_util::{sync::CancellationToken, task::TaskTracker};
22

            
23
use crate::{
24
    core::protocol::request_validation::GroupDenylist,
25
    server::{
26
        authorization::read_and_parse_group_denylist,
27
        config::{MysqlConfig, ServerConfig},
28
        session_handler::{SessionId, session_handler},
29
    },
30
};
31

            
32
#[derive(Clone, Debug)]
33
pub enum SupervisorMessage {
34
    StopAcceptingNewConnections,
35
    ResumeAcceptingNewConnections,
36
    Shutdown,
37
}
38

            
39
#[derive(Clone, Debug)]
40
pub struct ReloadEvent;
41

            
42
#[allow(dead_code)]
43
pub struct Supervisor {
44
    config_path: PathBuf,
45
    config: Arc<Mutex<ServerConfig>>,
46
    group_deny_list: Arc<RwLock<GroupDenylist>>,
47
    systemd_mode: bool,
48

            
49
    shutdown_cancel_token: CancellationToken,
50
    reload_message_receiver: broadcast::Receiver<ReloadEvent>,
51
    signal_handler_task: JoinHandle<()>,
52

            
53
    db_connection_pool: Arc<RwLock<MySqlPool>>,
54
    db_is_mariadb: Arc<RwLock<bool>>,
55
    listener: Arc<RwLock<TokioUnixListener>>,
56
    listener_task: JoinHandle<anyhow::Result<()>>,
57
    handler_task_tracker: TaskTracker,
58
    supervisor_message_sender: broadcast::Sender<SupervisorMessage>,
59

            
60
    watchdog_timeout: Option<Duration>,
61
    systemd_watchdog_task: Option<JoinHandle<()>>,
62

            
63
    status_notifier_task: Option<JoinHandle<()>>,
64
}
65

            
66
impl Supervisor {
67
    pub async fn new(config_path: PathBuf, systemd_mode: bool) -> anyhow::Result<Self> {
68
        tracing::debug!("Starting server supervisor");
69
        tracing::debug!(
70
            "Running in tokio with {} worker threads",
71
            tokio::runtime::Handle::current().metrics().num_workers()
72
        );
73

            
74
        let config = ServerConfig::read_config_from_path(&config_path)
75
            .context("Failed to read server configuration")?;
76

            
77
        let group_deny_list = if let Some(denylist_path) = &config.authorization.group_denylist_file
78
        {
79
            let denylist = read_and_parse_group_denylist(denylist_path)
80
                .context("Failed to read group denylist file")?;
81
            tracing::debug!(
82
                "Loaded group denylist with {} entries from {:?}",
83
                denylist.len(),
84
                denylist_path
85
            );
86
            Arc::new(RwLock::new(denylist))
87
        } else {
88
            tracing::debug!("No group denylist file specified, proceeding without a denylist");
89
            Arc::new(RwLock::new(GroupDenylist::new()))
90
        };
91

            
92
        let mut watchdog_duration = None;
93
        #[cfg(target_os = "linux")]
94
        let watchdog_task =
95
            if systemd_mode && let Some(watchdog_duration_) = sd_notify::watchdog_enabled() {
96
                tracing::debug!(
97
                    "Systemd watchdog enabled with {} millisecond interval",
98
                    watchdog_duration_.as_millis()
99
                );
100
                watchdog_duration = Some(watchdog_duration_);
101
                Some(spawn_watchdog_task(watchdog_duration_))
102
            } else {
103
                tracing::debug!("Systemd watchdog not enabled, skipping watchdog thread");
104
                None
105
            };
106
        #[cfg(not(target_os = "linux"))]
107
        let watchdog_task = None;
108

            
109
        let db_connection_pool =
110
            Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?));
111

            
112
        let db_is_mariadb = {
113
            let connection = db_connection_pool.read().await;
114
            let version: String = sqlx::query_scalar("SELECT VERSION()")
115
                .fetch_one(&*connection)
116
                .await
117
                .context("Failed to query database version")?;
118

            
119
            let result = version.to_lowercase().contains("mariadb");
120
            tracing::debug!(
121
                "Connected to {} database server",
122
                if result { "MariaDB" } else { "MySQL" }
123
            );
124

            
125
            Arc::new(RwLock::new(result))
126
        };
127

            
128
        let task_tracker = TaskTracker::new();
129

            
130
        #[cfg(target_os = "linux")]
131
        let status_notifier_task = if systemd_mode {
132
            Some(spawn_status_notifier_task(task_tracker.clone()))
133
        } else {
134
            None
135
        };
136
        #[cfg(not(target_os = "linux"))]
137
        let status_notifier_task = None;
138

            
139
        let (tx, rx) = broadcast::channel(1);
140

            
141
        // TODO: try to detect systemd socket before using the provided socket path
142
        #[cfg(target_os = "linux")]
143
        let listener = Arc::new(RwLock::new(match config.socket_path {
144
            Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
145
            None => create_unix_listener_with_systemd_socket().await?,
146
        }));
147
        #[cfg(not(target_os = "linux"))]
148
        let listener = Arc::new(RwLock::new(
149
            create_unix_listener_with_socket_path(
150
                config
151
                    .socket_path
152
                    .as_ref()
153
                    .ok_or(anyhow!("Socket path must be set"))?
154
                    .clone(),
155
            )
156
            .await?,
157
        ));
158

            
159
        let (reload_tx, reload_rx) = broadcast::channel(1);
160
        let shutdown_cancel_token = CancellationToken::new();
161
        let signal_handler_task =
162
            spawn_signal_handler_task(reload_tx, shutdown_cancel_token.clone());
163

            
164
        let listener_clone = listener.clone();
165
        let task_tracker_clone = task_tracker.clone();
166
        let listener_task = {
167
            tokio::spawn(listener_task(
168
                listener_clone,
169
                task_tracker_clone,
170
                db_connection_pool.clone(),
171
                rx,
172
                db_is_mariadb.clone(),
173
                group_deny_list.clone(),
174
            ))
175
        };
176

            
177
        Ok(Self {
178
            config_path,
179
            config: Arc::new(Mutex::new(config)),
180
            group_deny_list,
181
            systemd_mode,
182
            reload_message_receiver: reload_rx,
183
            shutdown_cancel_token,
184
            signal_handler_task,
185
            db_connection_pool,
186
            db_is_mariadb,
187
            listener,
188
            listener_task,
189
            handler_task_tracker: task_tracker,
190
            supervisor_message_sender: tx,
191
            watchdog_timeout: watchdog_duration,
192
            systemd_watchdog_task: watchdog_task,
193
            status_notifier_task,
194
        })
195
    }
196

            
197
    fn stop_receiving_new_connections(&self) -> anyhow::Result<()> {
198
        self.handler_task_tracker.close();
199
        self.supervisor_message_sender
200
            .send(SupervisorMessage::StopAcceptingNewConnections)
201
            .context("Failed to send stop accepting new connections message to listener task")?;
202
        Ok(())
203
    }
204

            
205
    fn resume_receiving_new_connections(&self) -> anyhow::Result<()> {
206
        self.handler_task_tracker.reopen();
207
        self.supervisor_message_sender
208
            .send(SupervisorMessage::ResumeAcceptingNewConnections)
209
            .context("Failed to send resume accepting new connections message to listener task")?;
210
        Ok(())
211
    }
212

            
213
    async fn wait_for_existing_connections_to_finish(&self) -> anyhow::Result<()> {
214
        self.handler_task_tracker.wait().await;
215
        Ok(())
216
    }
217

            
218
    async fn reload_config(&self) -> anyhow::Result<()> {
219
        let new_config = ServerConfig::read_config_from_path(&self.config_path)
220
            .context("Failed to read server configuration")?;
221
        let mut config = self.config.clone().lock_owned().await;
222
        *config = new_config;
223

            
224
        let group_deny_list = if let Some(denylist_path) = &config.authorization.group_denylist_file
225
        {
226
            let denylist = read_and_parse_group_denylist(denylist_path)
227
                .context("Failed to read group denylist file")?;
228

            
229
            tracing::debug!(
230
                "Loaded group denylist with {} entries from {:?}",
231
                denylist.len(),
232
                denylist_path
233
            );
234
            denylist
235
        } else {
236
            tracing::debug!("No group denylist file specified, proceeding without a denylist");
237
            GroupDenylist::new()
238
        };
239
        let mut group_deny_list_lock = self.group_deny_list.write().await;
240
        *group_deny_list_lock = group_deny_list;
241
        Ok(())
242
    }
243

            
244
    async fn restart_db_connection_pool(&self) -> anyhow::Result<()> {
245
        let config = self.config.lock().await;
246
        let mut connection_pool = self.db_connection_pool.clone().write_owned().await;
247
        let mut db_is_mariadb_lock = self.db_is_mariadb.write().await;
248

            
249
        let new_db_pool = create_db_connection_pool(&config.mysql).await?;
250
        let db_is_mariadb = {
251
            let version: String = sqlx::query_scalar("SELECT VERSION()")
252
                .fetch_one(&new_db_pool)
253
                .await
254
                .context("Failed to query database version")?;
255

            
256
            let result = version.to_lowercase().contains("mariadb");
257
            tracing::debug!(
258
                "Connected to {} database server",
259
                if result { "MariaDB" } else { "MySQL" }
260
            );
261

            
262
            result
263
        };
264

            
265
        *connection_pool = new_db_pool;
266
        *db_is_mariadb_lock = db_is_mariadb;
267
        Ok(())
268
    }
269

            
270
    // NOTE: the listener task will block the write lock unless the task is cancelled
271
    //       first. Make sure to handle that appropriately to avoid a deadlock.
272
    async fn reload_listener(&self) -> anyhow::Result<()> {
273
        let config = self.config.lock().await;
274
        #[cfg(target_os = "linux")]
275
        let new_listener = match config.socket_path {
276
            Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
277
            None => create_unix_listener_with_systemd_socket().await?,
278
        };
279
        #[cfg(not(target_os = "linux"))]
280
        let new_listener = create_unix_listener_with_socket_path(
281
            config
282
                .socket_path
283
                .as_ref()
284
                .ok_or(anyhow!("Socket path must be set"))?
285
                .clone(),
286
        )
287
        .await?;
288

            
289
        let mut listener = self.listener.write().await;
290
        *listener = new_listener;
291
        Ok(())
292
    }
293

            
294
    pub async fn reload(&self) -> anyhow::Result<()> {
295
        #[cfg(target_os = "linux")]
296
        sd_notify::notify(&[
297
            sd_notify::NotifyState::Reloading,
298
            sd_notify::NotifyState::monotonic_usec_now()
299
                .expect("Failed to get monotonic time to send to systemd while reloading"),
300
            sd_notify::NotifyState::Status("Reloading configuration"),
301
        ])?;
302

            
303
        let previous_config = self.config.lock().await.clone();
304
        self.reload_config().await?;
305

            
306
        let mut listener_task_was_stopped = false;
307

            
308
        // NOTE: despite closing the existing db pool, any already acquired connections will remain valid until dropped,
309
        //       so we don't need to close existing connections here.
310
        if self.config.lock().await.mysql != previous_config.mysql {
311
            tracing::debug!("MySQL configuration has changed");
312

            
313
            tracing::debug!("Restarting database connection pool with new configuration");
314
            self.restart_db_connection_pool().await?;
315
        }
316

            
317
        if self.config.lock().await.socket_path != previous_config.socket_path {
318
            tracing::debug!("Socket path configuration has changed, reloading listener");
319
            if !listener_task_was_stopped {
320
                listener_task_was_stopped = true;
321
                tracing::debug!("Stop accepting new connections");
322
                self.stop_receiving_new_connections()?;
323

            
324
                tracing::debug!("Waiting for existing connections to finish");
325
                self.wait_for_existing_connections_to_finish().await?;
326
            }
327

            
328
            tracing::debug!("Reloading listener with new socket path");
329
            self.reload_listener().await?;
330
        }
331

            
332
        if listener_task_was_stopped {
333
            tracing::debug!("Resuming listener task");
334
            self.resume_receiving_new_connections()?;
335
        }
336

            
337
        #[cfg(target_os = "linux")]
338
        sd_notify::notify(&[sd_notify::NotifyState::Ready])?;
339

            
340
        Ok(())
341
    }
342

            
343
    pub async fn shutdown(&self) -> anyhow::Result<()> {
344
        #[cfg(target_os = "linux")]
345
        sd_notify::notify(&[sd_notify::NotifyState::Stopping])?;
346

            
347
        tracing::debug!("Stop accepting new connections");
348
        self.stop_receiving_new_connections()?;
349

            
350
        let connection_count = self.handler_task_tracker.len();
351
        tracing::debug!(
352
            "Waiting for {} existing connections to finish",
353
            connection_count
354
        );
355
        self.wait_for_existing_connections_to_finish().await?;
356

            
357
        tracing::debug!("Shutting down listener task");
358
        self.supervisor_message_sender
359
            .send(SupervisorMessage::Shutdown)
360
            .unwrap_or_else(|e| {
361
                tracing::warn!("Failed to send shutdown message to listener task: {}", e);
362
                0
363
            });
364

            
365
        tracing::debug!("Shutting down database connection pool");
366
        self.db_connection_pool.read().await.close().await;
367

            
368
        tracing::debug!("Server shutdown complete");
369

            
370
        std::process::exit(0);
371
    }
372

            
373
    pub async fn run(&self) -> anyhow::Result<()> {
374
        loop {
375
            select! {
376
                biased;
377

            
378
                _ = async {
379
                  let mut rx = self.reload_message_receiver.resubscribe();
380
                  rx.recv().await
381
                } => {
382
                    tracing::info!("Reloading configuration");
383
                    match self.reload().await {
384
                        Ok(()) => {
385
                            tracing::info!("Configuration reloaded successfully");
386
                        }
387
                        Err(e) => {
388
                            tracing::error!("Failed to reload configuration: {}", e);
389
                        }
390
                    }
391
                }
392

            
393
                () = self.shutdown_cancel_token.cancelled() => {
394
                    tracing::info!("Shutting down server");
395
                    self.shutdown().await?;
396
                    break;
397
                }
398
            }
399
        }
400

            
401
        Ok(())
402
    }
403
}
404

            
405
#[cfg(target_os = "linux")]
406
fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
407
    tokio::spawn(async move {
408
        let mut interval = interval(duration.div_f32(2.0));
409
        tracing::debug!(
410
            "Starting systemd watchdog task, pinging every {} milliseconds",
411
            duration.div_f32(2.0).as_millis()
412
        );
413
        loop {
414
            interval.tick().await;
415
            if let Err(err) = sd_notify::notify(&[sd_notify::NotifyState::Watchdog]) {
416
                tracing::warn!("Failed to notify systemd watchdog: {}", err);
417
            }
418
        }
419
    })
420
}
421

            
422
#[cfg(target_os = "linux")]
423
fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> {
424
    const STATUS_UPDATE_INTERVAL_SECS: Duration = Duration::from_secs(1);
425

            
426
    tokio::spawn(async move {
427
        let mut interval = interval(STATUS_UPDATE_INTERVAL_SECS);
428
        loop {
429
            interval.tick().await;
430
            let count = task_tracker.len();
431

            
432
            let message = if count > 0 {
433
                format!("Handling {count} connections")
434
            } else {
435
                "Waiting for connections".to_string()
436
            };
437

            
438
            if let Err(e) = sd_notify::notify(&[sd_notify::NotifyState::Status(message.as_str())]) {
439
                tracing::warn!("Failed to send systemd status notification: {}", e);
440
            }
441
        }
442
    })
443
}
444

            
445
async fn create_unix_listener_with_socket_path(
446
    socket_path: PathBuf,
447
) -> anyhow::Result<TokioUnixListener> {
448
    let parent_directory = socket_path.parent().unwrap();
449
    if !parent_directory.exists() {
450
        tracing::debug!("Creating directory {:?}", parent_directory);
451
        fs::create_dir_all(parent_directory)?;
452
    }
453

            
454
    tracing::info!("Listening on socket {:?}", socket_path);
455

            
456
    match fs::remove_file(socket_path.as_path()) {
457
        Ok(()) => {}
458
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
459
        Err(e) => return Err(e.into()),
460
    }
461

            
462
    let listener = TokioUnixListener::bind(socket_path)?;
463

            
464
    Ok(listener)
465
}
466

            
467
#[cfg(target_os = "linux")]
468
async fn create_unix_listener_with_systemd_socket() -> anyhow::Result<TokioUnixListener> {
469
    let fd = sd_notify::listen_fds()
470
        .context("Failed to get file descriptors from systemd")?
471
        .next()
472
        .context("No file descriptors received from systemd")?;
473

            
474
    debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {fd}");
475

            
476
    tracing::debug!(
477
        "Received file descriptor from systemd with id: '{}', assuming socket",
478
        fd
479
    );
480

            
481
    let std_unix_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
482
    std_unix_listener
483
        .set_nonblocking(true)
484
        .context("Failed to set non-blocking mode on systemd socket")?;
485
    let listener = TokioUnixListener::from_std(std_unix_listener)?;
486

            
487
    Ok(listener)
488
}
489

            
490
async fn create_db_connection_pool(config: &MysqlConfig) -> anyhow::Result<MySqlPool> {
491
    let mysql_config = config.as_mysql_connect_options()?;
492

            
493
    config.log_connection_notice();
494

            
495
    let pool = match tokio::time::timeout(
496
        Duration::from_secs(config.timeout),
497
        MySqlPool::connect_with(mysql_config),
498
    )
499
    .await
500
    {
501
        Ok(connection) => connection.context("Failed to connect to the database"),
502
        Err(_) => Err(anyhow!("Timed out after {} seconds", config.timeout))
503
            .context("Failed to connect to the database"),
504
    }?;
505

            
506
    let pool_opts = pool.options();
507
    tracing::debug!(
508
        "Successfully opened database connection pool with options (max_connections: {}, min_connections: {})",
509
        pool_opts.get_max_connections(),
510
        pool_opts.get_min_connections(),
511
    );
512

            
513
    Ok(pool)
514
}
515

            
516
fn spawn_signal_handler_task(
517
    reload_sender: broadcast::Sender<ReloadEvent>,
518
    shutdown_token: CancellationToken,
519
) -> JoinHandle<()> {
520
    tokio::spawn(async move {
521
        let mut sighup_stream =
522
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
523
                .expect("Failed to set up SIGHUP handler");
524
        let mut sigterm_stream =
525
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
526
                .expect("Failed to set up SIGTERM handler");
527

            
528
        loop {
529
            tokio::select! {
530
                _ = sighup_stream.recv() => {
531
                    tracing::info!("Received SIGHUP signal");
532
                    reload_sender.send(ReloadEvent).ok();
533
                }
534
                _ = sigterm_stream.recv() => {
535
                    tracing::info!("Received SIGTERM signal");
536
                    shutdown_token.cancel();
537
                    break;
538
                }
539
            }
540
        }
541
    })
542
}
543

            
544
async fn listener_task(
545
    listener: Arc<RwLock<TokioUnixListener>>,
546
    task_tracker: TaskTracker,
547
    db_pool: Arc<RwLock<MySqlPool>>,
548
    mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>,
549
    db_is_mariadb: Arc<RwLock<bool>>,
550
    group_denylist: Arc<RwLock<GroupDenylist>>,
551
) -> anyhow::Result<()> {
552
    #[cfg(target_os = "linux")]
553
    sd_notify::notify(&[sd_notify::NotifyState::Ready])?;
554

            
555
    let connection_counter = AtomicU64::new(0);
556

            
557
    loop {
558
        tokio::select! {
559
            biased;
560

            
561
            Ok(message) = supervisor_message_receiver.recv() => {
562
                match message {
563
                    SupervisorMessage::StopAcceptingNewConnections => {
564
                        tracing::info!("Listener task received stop accepting new connections message, stopping listener");
565
                        while let Ok(msg) = supervisor_message_receiver.try_recv() {
566
                            if let SupervisorMessage::ResumeAcceptingNewConnections = msg {
567
                                tracing::info!("Listener task received resume accepting new connections message, resuming listener");
568
                                break;
569
                            }
570
                        }
571
                    }
572
                    SupervisorMessage::Shutdown => {
573
                        tracing::info!("Listener task received shutdown message, exiting listener task");
574
                        break;
575
                    }
576
                    _ => {}
577
                }
578
            }
579

            
580
            accept_result = async {
581
                let listener = listener.read().await;
582
                listener.accept().await
583
            } => {
584
                match accept_result {
585
                    Ok((conn, _addr)) => {
586
                        connection_counter.fetch_add(1, Ordering::Relaxed);
587
                        let conn_id = connection_counter.load(Ordering::Relaxed);
588

            
589
                        tracing::debug!("Got new connection, assigned session ID {}", conn_id);
590

            
591
                        let session_id = SessionId::new(conn_id);
592
                        let db_pool_clone = db_pool.clone();
593
                        let db_is_mariadb_clone = *db_is_mariadb.read().await;
594
                        let group_denylist_arc_clone = group_denylist.clone();
595
                        task_tracker.spawn(async move {
596
                            match session_handler(
597
                                conn,
598
                                session_id,
599
                                db_pool_clone,
600
                                db_is_mariadb_clone,
601
                                &*group_denylist_arc_clone.read().await,
602
                            ).await {
603
                                Ok(()) => {},
604
                                Err(e) => tracing::error!("Session {} failed: {}", conn_id, e),
605
                            }
606
                        });
607
                    },
608
                    Err(e) => tracing::error!("Failed to accept new connection: {}", e),
609
                }
610
            }
611
        }
612
    }
613

            
614
    Ok(())
615
}