1
use std::{
2
    fs,
3
    os::{fd::FromRawFd, unix::net::UnixListener as StdUnixListener},
4
    path::PathBuf,
5
    sync::{
6
        Arc,
7
        atomic::{AtomicU64, Ordering},
8
    },
9
    time::Duration,
10
};
11

            
12
use anyhow::{Context, anyhow};
13
use sqlx::MySqlPool;
14
use tokio::{
15
    net::UnixListener as TokioUnixListener,
16
    select,
17
    sync::{Mutex, RwLock, broadcast},
18
    task::JoinHandle,
19
    time::interval,
20
};
21
use tokio_util::{sync::CancellationToken, task::TaskTracker};
22

            
23
use crate::{
24
    core::protocol::request_validation::GroupDenylist,
25
    server::{
26
        authorization::read_and_parse_group_denylist,
27
        config::{MysqlConfig, ServerConfig},
28
        session_handler::{SessionId, session_handler},
29
    },
30
};
31

            
32
#[derive(Clone, Debug)]
33
pub enum SupervisorMessage {
34
    StopAcceptingNewConnections,
35
    ResumeAcceptingNewConnections,
36
    Shutdown,
37
}
38

            
39
#[derive(Clone, Debug)]
40
pub struct ReloadEvent;
41

            
42
#[allow(dead_code)]
43
pub struct Supervisor {
44
    config_path: PathBuf,
45
    config: Arc<Mutex<ServerConfig>>,
46
    group_deny_list: Arc<RwLock<GroupDenylist>>,
47
    systemd_mode: bool,
48

            
49
    shutdown_cancel_token: CancellationToken,
50
    reload_message_receiver: broadcast::Receiver<ReloadEvent>,
51
    signal_handler_task: JoinHandle<()>,
52

            
53
    db_connection_pool: Arc<RwLock<MySqlPool>>,
54
    db_is_mariadb: Arc<RwLock<bool>>,
55
    listener: Arc<RwLock<TokioUnixListener>>,
56
    listener_task: JoinHandle<anyhow::Result<()>>,
57
    handler_task_tracker: TaskTracker,
58
    supervisor_message_sender: broadcast::Sender<SupervisorMessage>,
59

            
60
    watchdog_timeout: Option<Duration>,
61
    systemd_watchdog_task: Option<JoinHandle<()>>,
62

            
63
    status_notifier_task: Option<JoinHandle<()>>,
64
}
65

            
66
impl Supervisor {
67
    pub async fn new(config_path: PathBuf, systemd_mode: bool) -> anyhow::Result<Self> {
68
        tracing::debug!("Starting server supervisor");
69
        tracing::debug!(
70
            "Running in tokio with {} worker threads",
71
            tokio::runtime::Handle::current().metrics().num_workers()
72
        );
73

            
74
        let config = ServerConfig::read_config_from_path(&config_path)
75
            .context("Failed to read server configuration")?;
76

            
77
        let group_deny_list = if let Some(denylist_path) = &config.authorization.group_denylist_file
78
        {
79
            let denylist = read_and_parse_group_denylist(denylist_path)
80
                .context("Failed to read group denylist file")?;
81
            tracing::debug!(
82
                "Loaded group denylist with {} entries from {:?}",
83
                denylist.len(),
84
                denylist_path
85
            );
86
            Arc::new(RwLock::new(denylist))
87
        } else {
88
            tracing::debug!("No group denylist file specified, proceeding without a denylist");
89
            Arc::new(RwLock::new(GroupDenylist::new()))
90
        };
91

            
92
        let mut watchdog_duration = None;
93
        let mut watchdog_micro_seconds = 0;
94
        #[cfg(target_os = "linux")]
95
        let watchdog_task =
96
            if systemd_mode && sd_notify::watchdog_enabled(true, &mut watchdog_micro_seconds) {
97
                let watchdog_duration_ = Duration::from_micros(watchdog_micro_seconds);
98
                tracing::debug!(
99
                    "Systemd watchdog enabled with {} millisecond interval",
100
                    watchdog_micro_seconds.div_ceil(1000),
101
                );
102
                watchdog_duration = Some(watchdog_duration_);
103
                Some(spawn_watchdog_task(watchdog_duration_))
104
            } else {
105
                tracing::debug!("Systemd watchdog not enabled, skipping watchdog thread");
106
                None
107
            };
108
        #[cfg(not(target_os = "linux"))]
109
        let watchdog_task = None;
110

            
111
        let db_connection_pool =
112
            Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?));
113

            
114
        let db_is_mariadb = {
115
            let connection = db_connection_pool.read().await;
116
            let version: String = sqlx::query_scalar("SELECT VERSION()")
117
                .fetch_one(&*connection)
118
                .await
119
                .context("Failed to query database version")?;
120

            
121
            let result = version.to_lowercase().contains("mariadb");
122
            tracing::debug!(
123
                "Connected to {} database server",
124
                if result { "MariaDB" } else { "MySQL" }
125
            );
126

            
127
            Arc::new(RwLock::new(result))
128
        };
129

            
130
        let task_tracker = TaskTracker::new();
131

            
132
        #[cfg(target_os = "linux")]
133
        let status_notifier_task = if systemd_mode {
134
            Some(spawn_status_notifier_task(task_tracker.clone()))
135
        } else {
136
            None
137
        };
138
        #[cfg(not(target_os = "linux"))]
139
        let status_notifier_task = None;
140

            
141
        let (tx, rx) = broadcast::channel(1);
142

            
143
        // TODO: try to detect systemd socket before using the provided socket path
144
        #[cfg(target_os = "linux")]
145
        let listener = Arc::new(RwLock::new(match config.socket_path {
146
            Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
147
            None => create_unix_listener_with_systemd_socket().await?,
148
        }));
149
        #[cfg(not(target_os = "linux"))]
150
        let listener = Arc::new(RwLock::new(
151
            create_unix_listener_with_socket_path(
152
                config
153
                    .socket_path
154
                    .as_ref()
155
                    .ok_or(anyhow!("Socket path must be set"))?
156
                    .clone(),
157
            )
158
            .await?,
159
        ));
160

            
161
        let (reload_tx, reload_rx) = broadcast::channel(1);
162
        let shutdown_cancel_token = CancellationToken::new();
163
        let signal_handler_task =
164
            spawn_signal_handler_task(reload_tx, shutdown_cancel_token.clone());
165

            
166
        let listener_clone = listener.clone();
167
        let task_tracker_clone = task_tracker.clone();
168
        let listener_task = {
169
            tokio::spawn(listener_task(
170
                listener_clone,
171
                task_tracker_clone,
172
                db_connection_pool.clone(),
173
                rx,
174
                db_is_mariadb.clone(),
175
                group_deny_list.clone(),
176
            ))
177
        };
178

            
179
        Ok(Self {
180
            config_path,
181
            config: Arc::new(Mutex::new(config)),
182
            group_deny_list,
183
            systemd_mode,
184
            reload_message_receiver: reload_rx,
185
            shutdown_cancel_token,
186
            signal_handler_task,
187
            db_connection_pool,
188
            db_is_mariadb,
189
            listener,
190
            listener_task,
191
            handler_task_tracker: task_tracker,
192
            supervisor_message_sender: tx,
193
            watchdog_timeout: watchdog_duration,
194
            systemd_watchdog_task: watchdog_task,
195
            status_notifier_task,
196
        })
197
    }
198

            
199
    fn stop_receiving_new_connections(&self) -> anyhow::Result<()> {
200
        self.handler_task_tracker.close();
201
        self.supervisor_message_sender
202
            .send(SupervisorMessage::StopAcceptingNewConnections)
203
            .context("Failed to send stop accepting new connections message to listener task")?;
204
        Ok(())
205
    }
206

            
207
    fn resume_receiving_new_connections(&self) -> anyhow::Result<()> {
208
        self.handler_task_tracker.reopen();
209
        self.supervisor_message_sender
210
            .send(SupervisorMessage::ResumeAcceptingNewConnections)
211
            .context("Failed to send resume accepting new connections message to listener task")?;
212
        Ok(())
213
    }
214

            
215
    async fn wait_for_existing_connections_to_finish(&self) -> anyhow::Result<()> {
216
        self.handler_task_tracker.wait().await;
217
        Ok(())
218
    }
219

            
220
    async fn reload_config(&self) -> anyhow::Result<()> {
221
        let new_config = ServerConfig::read_config_from_path(&self.config_path)
222
            .context("Failed to read server configuration")?;
223
        let mut config = self.config.clone().lock_owned().await;
224
        *config = new_config;
225

            
226
        let group_deny_list = if let Some(denylist_path) = &config.authorization.group_denylist_file
227
        {
228
            let denylist = read_and_parse_group_denylist(denylist_path)
229
                .context("Failed to read group denylist file")?;
230

            
231
            tracing::debug!(
232
                "Loaded group denylist with {} entries from {:?}",
233
                denylist.len(),
234
                denylist_path
235
            );
236
            denylist
237
        } else {
238
            tracing::debug!("No group denylist file specified, proceeding without a denylist");
239
            GroupDenylist::new()
240
        };
241
        let mut group_deny_list_lock = self.group_deny_list.write().await;
242
        *group_deny_list_lock = group_deny_list;
243
        Ok(())
244
    }
245

            
246
    async fn restart_db_connection_pool(&self) -> anyhow::Result<()> {
247
        let config = self.config.lock().await;
248
        let mut connection_pool = self.db_connection_pool.clone().write_owned().await;
249
        let mut db_is_mariadb_lock = self.db_is_mariadb.write().await;
250

            
251
        let new_db_pool = create_db_connection_pool(&config.mysql).await?;
252
        let db_is_mariadb = {
253
            let version: String = sqlx::query_scalar("SELECT VERSION()")
254
                .fetch_one(&new_db_pool)
255
                .await
256
                .context("Failed to query database version")?;
257

            
258
            let result = version.to_lowercase().contains("mariadb");
259
            tracing::debug!(
260
                "Connected to {} database server",
261
                if result { "MariaDB" } else { "MySQL" }
262
            );
263

            
264
            result
265
        };
266

            
267
        *connection_pool = new_db_pool;
268
        *db_is_mariadb_lock = db_is_mariadb;
269
        Ok(())
270
    }
271

            
272
    // NOTE: the listener task will block the write lock unless the task is cancelled
273
    //       first. Make sure to handle that appropriately to avoid a deadlock.
274
    async fn reload_listener(&self) -> anyhow::Result<()> {
275
        let config = self.config.lock().await;
276
        #[cfg(target_os = "linux")]
277
        let new_listener = match config.socket_path {
278
            Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
279
            None => create_unix_listener_with_systemd_socket().await?,
280
        };
281
        #[cfg(not(target_os = "linux"))]
282
        let new_listener = create_unix_listener_with_socket_path(
283
            config
284
                .socket_path
285
                .as_ref()
286
                .ok_or(anyhow!("Socket path must be set"))?
287
                .clone(),
288
        )
289
        .await?;
290

            
291
        let mut listener = self.listener.write().await;
292
        *listener = new_listener;
293
        Ok(())
294
    }
295

            
296
    pub async fn reload(&self) -> anyhow::Result<()> {
297
        #[cfg(target_os = "linux")]
298
        sd_notify::notify(
299
            false,
300
            &[
301
                sd_notify::NotifyState::Reloading,
302
                sd_notify::NotifyState::monotonic_usec_now()
303
                    .expect("Failed to get monotonic time to send to systemd while reloading"),
304
                sd_notify::NotifyState::Status("Reloading configuration"),
305
            ],
306
        )?;
307

            
308
        let previous_config = self.config.lock().await.clone();
309
        self.reload_config().await?;
310

            
311
        let mut listener_task_was_stopped = false;
312

            
313
        // NOTE: despite closing the existing db pool, any already acquired connections will remain valid until dropped,
314
        //       so we don't need to close existing connections here.
315
        if self.config.lock().await.mysql != previous_config.mysql {
316
            tracing::debug!("MySQL configuration has changed");
317

            
318
            tracing::debug!("Restarting database connection pool with new configuration");
319
            self.restart_db_connection_pool().await?;
320
        }
321

            
322
        if self.config.lock().await.socket_path != previous_config.socket_path {
323
            tracing::debug!("Socket path configuration has changed, reloading listener");
324
            if !listener_task_was_stopped {
325
                listener_task_was_stopped = true;
326
                tracing::debug!("Stop accepting new connections");
327
                self.stop_receiving_new_connections()?;
328

            
329
                tracing::debug!("Waiting for existing connections to finish");
330
                self.wait_for_existing_connections_to_finish().await?;
331
            }
332

            
333
            tracing::debug!("Reloading listener with new socket path");
334
            self.reload_listener().await?;
335
        }
336

            
337
        if listener_task_was_stopped {
338
            tracing::debug!("Resuming listener task");
339
            self.resume_receiving_new_connections()?;
340
        }
341

            
342
        #[cfg(target_os = "linux")]
343
        sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
344

            
345
        Ok(())
346
    }
347

            
348
    pub async fn shutdown(&self) -> anyhow::Result<()> {
349
        #[cfg(target_os = "linux")]
350
        sd_notify::notify(false, &[sd_notify::NotifyState::Stopping])?;
351

            
352
        tracing::debug!("Stop accepting new connections");
353
        self.stop_receiving_new_connections()?;
354

            
355
        let connection_count = self.handler_task_tracker.len();
356
        tracing::debug!(
357
            "Waiting for {} existing connections to finish",
358
            connection_count
359
        );
360
        self.wait_for_existing_connections_to_finish().await?;
361

            
362
        tracing::debug!("Shutting down listener task");
363
        self.supervisor_message_sender
364
            .send(SupervisorMessage::Shutdown)
365
            .unwrap_or_else(|e| {
366
                tracing::warn!("Failed to send shutdown message to listener task: {}", e);
367
                0
368
            });
369

            
370
        tracing::debug!("Shutting down database connection pool");
371
        self.db_connection_pool.read().await.close().await;
372

            
373
        tracing::debug!("Server shutdown complete");
374

            
375
        std::process::exit(0);
376
    }
377

            
378
    pub async fn run(&self) -> anyhow::Result<()> {
379
        loop {
380
            select! {
381
                biased;
382

            
383
                _ = async {
384
                  let mut rx = self.reload_message_receiver.resubscribe();
385
                  rx.recv().await
386
                } => {
387
                    tracing::info!("Reloading configuration");
388
                    match self.reload().await {
389
                        Ok(()) => {
390
                            tracing::info!("Configuration reloaded successfully");
391
                        }
392
                        Err(e) => {
393
                            tracing::error!("Failed to reload configuration: {}", e);
394
                        }
395
                    }
396
                }
397

            
398
                () = self.shutdown_cancel_token.cancelled() => {
399
                    tracing::info!("Shutting down server");
400
                    self.shutdown().await?;
401
                    break;
402
                }
403
            }
404
        }
405

            
406
        Ok(())
407
    }
408
}
409

            
410
#[cfg(target_os = "linux")]
411
fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
412
    tokio::spawn(async move {
413
        let mut interval = interval(duration.div_f32(2.0));
414
        tracing::debug!(
415
            "Starting systemd watchdog task, pinging every {} milliseconds",
416
            duration.div_f32(2.0).as_millis()
417
        );
418
        loop {
419
            interval.tick().await;
420
            if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
421
                tracing::warn!("Failed to notify systemd watchdog: {}", err);
422
            }
423
        }
424
    })
425
}
426

            
427
#[cfg(target_os = "linux")]
428
fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> {
429
    const STATUS_UPDATE_INTERVAL_SECS: Duration = Duration::from_secs(1);
430

            
431
    tokio::spawn(async move {
432
        let mut interval = interval(STATUS_UPDATE_INTERVAL_SECS);
433
        loop {
434
            interval.tick().await;
435
            let count = task_tracker.len();
436

            
437
            let message = if count > 0 {
438
                format!("Handling {count} connections")
439
            } else {
440
                "Waiting for connections".to_string()
441
            };
442

            
443
            if let Err(e) =
444
                sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())])
445
            {
446
                tracing::warn!("Failed to send systemd status notification: {}", e);
447
            }
448
        }
449
    })
450
}
451

            
452
async fn create_unix_listener_with_socket_path(
453
    socket_path: PathBuf,
454
) -> anyhow::Result<TokioUnixListener> {
455
    let parent_directory = socket_path.parent().unwrap();
456
    if !parent_directory.exists() {
457
        tracing::debug!("Creating directory {:?}", parent_directory);
458
        fs::create_dir_all(parent_directory)?;
459
    }
460

            
461
    tracing::info!("Listening on socket {:?}", socket_path);
462

            
463
    match fs::remove_file(socket_path.as_path()) {
464
        Ok(()) => {}
465
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
466
        Err(e) => return Err(e.into()),
467
    }
468

            
469
    let listener = TokioUnixListener::bind(socket_path)?;
470

            
471
    Ok(listener)
472
}
473

            
474
#[cfg(target_os = "linux")]
475
async fn create_unix_listener_with_systemd_socket() -> anyhow::Result<TokioUnixListener> {
476
    let fd = sd_notify::listen_fds()
477
        .context("Failed to get file descriptors from systemd")?
478
        .next()
479
        .context("No file descriptors received from systemd")?;
480

            
481
    debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {fd}");
482

            
483
    tracing::debug!(
484
        "Received file descriptor from systemd with id: '{}', assuming socket",
485
        fd
486
    );
487

            
488
    let std_unix_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
489
    std_unix_listener
490
        .set_nonblocking(true)
491
        .context("Failed to set non-blocking mode on systemd socket")?;
492
    let listener = TokioUnixListener::from_std(std_unix_listener)?;
493

            
494
    Ok(listener)
495
}
496

            
497
async fn create_db_connection_pool(config: &MysqlConfig) -> anyhow::Result<MySqlPool> {
498
    let mysql_config = config.as_mysql_connect_options()?;
499

            
500
    config.log_connection_notice();
501

            
502
    let pool = match tokio::time::timeout(
503
        Duration::from_secs(config.timeout),
504
        MySqlPool::connect_with(mysql_config),
505
    )
506
    .await
507
    {
508
        Ok(connection) => connection.context("Failed to connect to the database"),
509
        Err(_) => Err(anyhow!("Timed out after {} seconds", config.timeout))
510
            .context("Failed to connect to the database"),
511
    }?;
512

            
513
    let pool_opts = pool.options();
514
    tracing::debug!(
515
        "Successfully opened database connection pool with options (max_connections: {}, min_connections: {})",
516
        pool_opts.get_max_connections(),
517
        pool_opts.get_min_connections(),
518
    );
519

            
520
    Ok(pool)
521
}
522

            
523
fn spawn_signal_handler_task(
524
    reload_sender: broadcast::Sender<ReloadEvent>,
525
    shutdown_token: CancellationToken,
526
) -> JoinHandle<()> {
527
    tokio::spawn(async move {
528
        let mut sighup_stream =
529
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
530
                .expect("Failed to set up SIGHUP handler");
531
        let mut sigterm_stream =
532
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
533
                .expect("Failed to set up SIGTERM handler");
534

            
535
        loop {
536
            tokio::select! {
537
                _ = sighup_stream.recv() => {
538
                    tracing::info!("Received SIGHUP signal");
539
                    reload_sender.send(ReloadEvent).ok();
540
                }
541
                _ = sigterm_stream.recv() => {
542
                    tracing::info!("Received SIGTERM signal");
543
                    shutdown_token.cancel();
544
                    break;
545
                }
546
            }
547
        }
548
    })
549
}
550

            
551
async fn listener_task(
552
    listener: Arc<RwLock<TokioUnixListener>>,
553
    task_tracker: TaskTracker,
554
    db_pool: Arc<RwLock<MySqlPool>>,
555
    mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>,
556
    db_is_mariadb: Arc<RwLock<bool>>,
557
    group_denylist: Arc<RwLock<GroupDenylist>>,
558
) -> anyhow::Result<()> {
559
    #[cfg(target_os = "linux")]
560
    sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
561

            
562
    let connection_counter = AtomicU64::new(0);
563

            
564
    loop {
565
        tokio::select! {
566
            biased;
567

            
568
            Ok(message) = supervisor_message_receiver.recv() => {
569
                match message {
570
                    SupervisorMessage::StopAcceptingNewConnections => {
571
                        tracing::info!("Listener task received stop accepting new connections message, stopping listener");
572
                        while let Ok(msg) = supervisor_message_receiver.try_recv() {
573
                            if let SupervisorMessage::ResumeAcceptingNewConnections = msg {
574
                                tracing::info!("Listener task received resume accepting new connections message, resuming listener");
575
                                break;
576
                            }
577
                        }
578
                    }
579
                    SupervisorMessage::Shutdown => {
580
                        tracing::info!("Listener task received shutdown message, exiting listener task");
581
                        break;
582
                    }
583
                    _ => {}
584
                }
585
            }
586

            
587
            accept_result = async {
588
                let listener = listener.read().await;
589
                listener.accept().await
590
            } => {
591
                match accept_result {
592
                    Ok((conn, _addr)) => {
593
                        connection_counter.fetch_add(1, Ordering::Relaxed);
594
                        let conn_id = connection_counter.load(Ordering::Relaxed);
595

            
596
                        tracing::debug!("Got new connection, assigned session ID {}", conn_id);
597

            
598
                        let session_id = SessionId::new(conn_id);
599
                        let db_pool_clone = db_pool.clone();
600
                        let db_is_mariadb_clone = *db_is_mariadb.read().await;
601
                        let group_denylist_arc_clone = group_denylist.clone();
602
                        task_tracker.spawn(async move {
603
                            match session_handler(
604
                                conn,
605
                                session_id,
606
                                db_pool_clone,
607
                                db_is_mariadb_clone,
608
                                &*group_denylist_arc_clone.read().await,
609
                            ).await {
610
                                Ok(()) => {},
611
                                Err(e) => tracing::error!("Session {} failed: {}", conn_id, e),
612
                            }
613
                        });
614
                    },
615
                    Err(e) => tracing::error!("Failed to accept new connection: {}", e),
616
                }
617
            }
618
        }
619
    }
620

            
621
    Ok(())
622
}