1
use std::{
2
    collections::HashMap,
3
    os::fd::{AsRawFd, FromRawFd, OwnedFd},
4
    path::PathBuf,
5
    sync::Arc,
6
};
7

            
8
use anyhow::Context;
9
use clap::Parser;
10
use tokio::{net::UdpSocket, sync::RwLock};
11
use tokio_util::sync::CancellationToken;
12
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
13

            
14
use roowho2_lib::server::{
15
    config::DEFAULT_CONFIG_PATH,
16
    rwhod::{RwhodStatusStore, rwhod_packet_receiver_task, rwhod_packet_sender_task},
17
    varlink_api::varlink_client_server_task,
18
};
19

            
20
#[derive(Parser)]
21
#[command(
22
    author = "Programvareverkstedet <projects@pvv.ntnu.no>",
23
    about,
24
    version
25
)]
26
struct Args {
27
    /// Path to configuration file
28
    #[arg(
29
        short = 'c',
30
        long = "config",
31
        default_value = DEFAULT_CONFIG_PATH,
32
        value_name = "PATH"
33
    )]
34
    config_path: PathBuf,
35
}
36

            
37
#[tokio::main]
38
async fn main() -> anyhow::Result<()> {
39
    let args = Args::parse();
40

            
41
    tracing_subscriber::registry()
42
        .with(fmt::layer())
43
        .with(EnvFilter::from_default_env())
44
        .init();
45

            
46
    let config = toml::from_str::<roowho2_lib::server::config::Config>(
47
        &std::fs::read_to_string(&args.config_path).context(format!(
48
            "Failed to read configuration file {:?}",
49
            args.config_path
50
        ))?,
51
    )?;
52

            
53
    let fd_map: HashMap<String, OwnedFd> = HashMap::from_iter(
54
        sd_notify::listen_fds_with_names(false)?.map(|(fd_num, name)| {
55
            (
56
                name.clone(),
57
                // SAFETY: please don't mess around with file descriptors in random places
58
                //         around the codebase lol
59
                unsafe { std::os::fd::OwnedFd::from_raw_fd(fd_num) },
60
            )
61
        }),
62
    );
63

            
64
    let mut join_set = tokio::task::JoinSet::new();
65

            
66
    let whod_status_store = Arc::new(RwLock::new(HashMap::new()));
67

            
68
    let client_server_token = CancellationToken::new();
69
    let client_server_token_ = client_server_token.clone();
70
    tokio::spawn(async move {
71
        client_server_token_.cancelled().await;
72
        tracing::info!("RWHOD client-server is now accepting connections");
73
        #[cfg(feature = "systemd")]
74
        sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
75
        Ok::<(), anyhow::Error>(())
76
    });
77

            
78
    if config.rwhod.enable {
79
        tracing::info!("Starting RWHOD server");
80

            
81
        let socket = fd_map
82
            .get("rwhod_socket")
83
            .map(|fd| {
84
                // SAFETY: see above
85
                let std_socket = unsafe { std::net::UdpSocket::from_raw_fd(fd.as_raw_fd()) };
86
                std_socket.set_nonblocking(true)?;
87
                UdpSocket::from_std(std_socket)
88
            })
89
            .context("RWHOD server is enabled, but socket fd not provided by systemd")??;
90

            
91
        join_set.spawn(rwhod_server(socket, whod_status_store.clone()));
92
    } else {
93
        tracing::debug!("RWHOD server is disabled in configuration");
94
    }
95

            
96
    join_set.spawn(client_server(
97
        fd_map
98
            .get("client_socket")
99
            .context("RWHOD client-server socket fd not provided by systemd")?
100
            .try_clone()
101
            .context("Failed to clone RWHOD client-server socket fd")?,
102
        whod_status_store.clone(),
103
        client_server_token,
104
    ));
105

            
106
    join_set.spawn(ctrl_c_handler());
107

            
108
    join_set.join_next().await.unwrap()??;
109

            
110
    Ok(())
111
}
112

            
113
async fn ctrl_c_handler() -> anyhow::Result<()> {
114
    tokio::signal::ctrl_c()
115
        .await
116
        .map_err(|e| anyhow::anyhow!("Failed to listen for Ctrl-C: {}", e))
117
}
118

            
119
async fn rwhod_server(
120
    socket: UdpSocket,
121
    whod_status_store: RwhodStatusStore,
122
) -> anyhow::Result<()> {
123
    let socket = Arc::new(socket);
124

            
125
    let interfaces = roowho2_lib::server::rwhod::determine_relevant_interfaces()?;
126
    let sender_task = rwhod_packet_sender_task(socket.clone(), interfaces);
127

            
128
    let receiver_task = rwhod_packet_receiver_task(socket.clone(), whod_status_store);
129

            
130
    tokio::select! {
131
        res = sender_task => res?,
132
        res = receiver_task => res?,
133
    }
134

            
135
    Ok(())
136
}
137

            
138
async fn client_server(
139
    socket_fd: OwnedFd,
140
    whod_status_store: RwhodStatusStore,
141
    startup_token: CancellationToken,
142
) -> anyhow::Result<()> {
143
    // SAFETY: see above
144
    let std_socket =
145
        unsafe { std::os::unix::net::UnixListener::from_raw_fd(socket_fd.as_raw_fd()) };
146
    std_socket.set_nonblocking(true)?;
147
    let zlink_listener = zlink::unix::Listener::try_from(OwnedFd::from(std_socket))?;
148
    let client_server_task = varlink_client_server_task(zlink_listener, whod_status_store);
149

            
150
    startup_token.cancel();
151

            
152
    client_server_task.await?;
153

            
154
    Ok(())
155
}