1
use std::{
2
    collections::HashMap,
3
    os::fd::{AsRawFd, FromRawFd, OwnedFd},
4
    path::PathBuf,
5
    sync::Arc,
6
};
7

            
8
use anyhow::Context;
9
use clap::Parser;
10
use tokio::{net::UdpSocket, sync::RwLock};
11
use tokio_util::sync::CancellationToken;
12
use tracing::level_filters::LevelFilter;
13
use tracing_subscriber::layer::SubscriberExt;
14

            
15
use roowho2_lib::server::{
16
    config::{DEFAULT_CONFIG_PATH, LogLevel},
17
    rwhod::{RwhodStatusStore, rwhod_packet_receiver_task, rwhod_packet_sender_task},
18
    varlink_api::varlink_client_server_task,
19
};
20

            
21
#[derive(Parser)]
22
#[command(
23
    author = "Programvareverkstedet <projects@pvv.ntnu.no>",
24
    about,
25
    version
26
)]
27
struct Args {
28
    /// Path to configuration file
29
    #[arg(
30
        short = 'c',
31
        long = "config",
32
        default_value = DEFAULT_CONFIG_PATH,
33
        value_name = "PATH"
34
    )]
35
    config_path: PathBuf,
36
}
37

            
38
#[tokio::main]
39
async fn main() -> anyhow::Result<()> {
40
    let args = Args::parse();
41

            
42
    let config = toml::from_str::<roowho2_lib::server::config::Config>(
43
        &std::fs::read_to_string(&args.config_path).context(format!(
44
            "Failed to read configuration file {:?}",
45
            args.config_path
46
        ))?,
47
    )?;
48

            
49
    let log_filter = match config.log_level.unwrap_or(LogLevel::Info) {
50
        LogLevel::Info => LevelFilter::INFO,
51
        LogLevel::Debug => LevelFilter::DEBUG,
52
        LogLevel::Trace => LevelFilter::TRACE,
53
    };
54

            
55
    let subscriber = tracing_subscriber::registry()
56
        .with(log_filter)
57
        .with(tracing_journald::layer()?);
58

            
59
    tracing::subscriber::set_global_default(subscriber)
60
        .context("Failed to set global default tracing subscriber")?;
61

            
62
    let fd_map: HashMap<String, OwnedFd> =
63
        HashMap::from_iter(sd_notify::listen_fds_with_names()?.map(|(fd_num, name)| {
64
            (
65
                name.clone(),
66
                // SAFETY: please don't mess around with file descriptors in random places
67
                //         around the codebase lol
68
                unsafe { std::os::fd::OwnedFd::from_raw_fd(fd_num) },
69
            )
70
        }));
71

            
72
    let mut join_set = tokio::task::JoinSet::new();
73

            
74
    let whod_status_store = Arc::new(RwLock::new(HashMap::new()));
75

            
76
    let client_server_token = CancellationToken::new();
77
    let client_server_token_ = client_server_token.clone();
78
    tokio::spawn(async move {
79
        client_server_token_.cancelled().await;
80
        tracing::info!("RWHOD client-server is now accepting connections");
81
        #[cfg(feature = "systemd")]
82
        sd_notify::notify(&[sd_notify::NotifyState::Ready]).ok();
83
        Ok::<(), anyhow::Error>(())
84
    });
85

            
86
    if config.rwhod.enable {
87
        tracing::info!("Starting RWHOD server");
88

            
89
        let socket = fd_map
90
            .get("rwhod_socket")
91
            .map(|fd| {
92
                // SAFETY: see above
93
                let std_socket = unsafe { std::net::UdpSocket::from_raw_fd(fd.as_raw_fd()) };
94
                std_socket.set_nonblocking(true)?;
95
                UdpSocket::from_std(std_socket)
96
            })
97
            .context("RWHOD server is enabled, but socket fd not provided by systemd")??;
98

            
99
        join_set.spawn(rwhod_server(socket, whod_status_store.clone()));
100
    } else {
101
        tracing::debug!("RWHOD server is disabled in configuration");
102
    }
103

            
104
    join_set.spawn(client_server(
105
        fd_map
106
            .get("client_socket")
107
            .context("RWHOD client-server socket fd not provided by systemd")?
108
            .try_clone()
109
            .context("Failed to clone RWHOD client-server socket fd")?,
110
        whod_status_store.clone(),
111
        client_server_token,
112
    ));
113

            
114
    join_set.spawn(ctrl_c_handler());
115

            
116
    join_set.join_next().await.unwrap()??;
117

            
118
    Ok(())
119
}
120

            
121
async fn ctrl_c_handler() -> anyhow::Result<()> {
122
    tokio::signal::ctrl_c()
123
        .await
124
        .map_err(|e| anyhow::anyhow!("Failed to listen for Ctrl-C: {}", e))
125
}
126

            
127
async fn rwhod_server(
128
    socket: UdpSocket,
129
    whod_status_store: RwhodStatusStore,
130
) -> anyhow::Result<()> {
131
    let socket = Arc::new(socket);
132

            
133
    let interfaces = roowho2_lib::server::rwhod::determine_relevant_interfaces()?;
134
    let sender_task = rwhod_packet_sender_task(socket.clone(), interfaces);
135

            
136
    let receiver_task = rwhod_packet_receiver_task(socket.clone(), whod_status_store);
137

            
138
    tokio::select! {
139
        res = sender_task => res?,
140
        res = receiver_task => res?,
141
    }
142

            
143
    Ok(())
144
}
145

            
146
async fn client_server(
147
    socket_fd: OwnedFd,
148
    whod_status_store: RwhodStatusStore,
149
    startup_token: CancellationToken,
150
) -> anyhow::Result<()> {
151
    // SAFETY: see above
152
    let std_socket =
153
        unsafe { std::os::unix::net::UnixListener::from_raw_fd(socket_fd.as_raw_fd()) };
154
    std_socket.set_nonblocking(true)?;
155
    let zlink_listener = zlink::unix::Listener::try_from(OwnedFd::from(std_socket))?;
156
    let client_server_task = varlink_client_server_task(zlink_listener, whod_status_store);
157

            
158
    startup_token.cancel();
159

            
160
    client_server_task.await?;
161

            
162
    Ok(())
163
}