1
use std::{
2
    collections::HashMap,
3
    os::fd::{AsRawFd, FromRawFd, OwnedFd},
4
    path::PathBuf,
5
    sync::Arc,
6
};
7

            
8
use anyhow::Context;
9
use clap::Parser;
10
use tokio::{net::UdpSocket, sync::RwLock};
11
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
12

            
13
use roowho2_lib::server::{
14
    config::DEFAULT_CONFIG_PATH,
15
    rwhod::{RwhodStatusStore, rwhod_packet_receiver_task, rwhod_packet_sender_task},
16
    varlink_api::varlink_client_server_task,
17
};
18

            
19
#[derive(Parser)]
20
#[command(
21
    author = "Programvareverkstedet <projects@pvv.ntnu.no>",
22
    about,
23
    version
24
)]
25
struct Args {
26
    /// Path to configuration file
27
    #[arg(
28
        short = 'c',
29
        long = "config",
30
        default_value = DEFAULT_CONFIG_PATH,
31
        value_name = "PATH"
32
    )]
33
    config_path: PathBuf,
34
}
35

            
36
#[tokio::main]
37
async fn main() -> anyhow::Result<()> {
38
    let args = Args::parse();
39

            
40
    tracing_subscriber::registry()
41
        .with(fmt::layer())
42
        .with(EnvFilter::from_default_env())
43
        .init();
44

            
45
    let config = toml::from_str::<roowho2_lib::server::config::Config>(
46
        &std::fs::read_to_string(&args.config_path).context(format!(
47
            "Failed to read configuration file {:?}",
48
            args.config_path
49
        ))?,
50
    )?;
51

            
52
    let fd_map: HashMap<String, OwnedFd> = HashMap::from_iter(
53
        sd_notify::listen_fds_with_names(false)?.map(|(fd_num, name)| {
54
            (
55
                name.clone(),
56
                // SAFETY: please don't mess around with file descriptors in random places
57
                //         around the codebase lol
58
                unsafe { std::os::fd::OwnedFd::from_raw_fd(fd_num) },
59
            )
60
        }),
61
    );
62

            
63
    let mut join_set = tokio::task::JoinSet::new();
64

            
65
    let whod_status_store = Arc::new(RwLock::new(HashMap::new()));
66

            
67
    if config.rwhod.enable {
68
        tracing::info!("Starting RWHOD server");
69

            
70
        let socket = fd_map
71
            .get("rwhod_socket")
72
            .map(|fd| {
73
                // SAFETY: see above
74
                let std_socket = unsafe { std::net::UdpSocket::from_raw_fd(fd.as_raw_fd()) };
75
                std_socket.set_nonblocking(true)?;
76
                UdpSocket::from_std(std_socket)
77
            })
78
            .context("RWHOD server is enabled, but socket fd not provided by systemd")??;
79

            
80
        join_set.spawn(rwhod_server(socket, whod_status_store.clone()));
81
    } else {
82
        tracing::debug!("RWHOD server is disabled in configuration");
83
    }
84

            
85
    join_set.spawn(client_server(
86
        fd_map
87
            .get("client_socket")
88
            .context("RWHOD client-server socket fd not provided by systemd")?
89
            .try_clone()
90
            .context("Failed to clone RWHOD client-server socket fd")?,
91
        whod_status_store.clone(),
92
    ));
93

            
94
    join_set.spawn(ctrl_c_handler());
95

            
96
    join_set.join_next().await.unwrap()??;
97

            
98
    Ok(())
99
}
100

            
101
async fn ctrl_c_handler() -> anyhow::Result<()> {
102
    tokio::signal::ctrl_c()
103
        .await
104
        .map_err(|e| anyhow::anyhow!("Failed to listen for Ctrl-C: {}", e))
105
}
106

            
107
async fn rwhod_server(
108
    socket: UdpSocket,
109
    whod_status_store: RwhodStatusStore,
110
) -> anyhow::Result<()> {
111
    let socket = Arc::new(socket);
112

            
113
    let interfaces = roowho2_lib::server::rwhod::determine_relevant_interfaces()?;
114
    let sender_task = rwhod_packet_sender_task(socket.clone(), interfaces);
115

            
116
    let receiver_task = rwhod_packet_receiver_task(socket.clone(), whod_status_store);
117

            
118
    tokio::select! {
119
        res = sender_task => res?,
120
        res = receiver_task => res?,
121
    }
122

            
123
    Ok(())
124
}
125

            
126
async fn client_server(
127
    socket_fd: OwnedFd,
128
    whod_status_store: RwhodStatusStore,
129
) -> anyhow::Result<()> {
130
    // SAFETY: see above
131
    let std_socket =
132
        unsafe { std::os::unix::net::UnixListener::from_raw_fd(socket_fd.as_raw_fd()) };
133
    std_socket.set_nonblocking(true)?;
134
    let zlink_listener = zlink::unix::Listener::try_from(OwnedFd::from(std_socket))?;
135
    let client_server_task = varlink_client_server_task(zlink_listener, whod_status_store);
136

            
137
    client_server_task.await?;
138

            
139
    Ok(())
140
}