1
use std::{
2
    collections::{HashMap, HashSet},
3
    net::{IpAddr, SocketAddr},
4
    path::Path,
5
    sync::Arc,
6
};
7

            
8
use anyhow::Context;
9
use chrono::{DateTime, Duration, Timelike, Utc};
10
use nix::{
11
    ifaddrs::getifaddrs,
12
    net::if_::InterfaceFlags,
13
    sys::{stat::stat, sysinfo::sysinfo},
14
    unistd::gethostname,
15
};
16
use tokio::{
17
    net::UdpSocket,
18
    sync::RwLock,
19
    time::{Duration as TokioDuration, interval},
20
};
21
use uucore::utmpx::Utmpx;
22

            
23
use crate::proto::{Whod, WhodStatusUpdate, WhodUserEntry};
24

            
25
/// Default UDP port that rwhod status packets are addressed to.
pub const RWHOD_BROADCAST_PORT: u16 = 513;
27

            
28
/// Shared, lock-protected map holding the latest status update stored for each
/// host, keyed by the hostname carried in that update.
pub type RwhodStatusStore = Arc<RwLock<HashMap<String, WhodStatusUpdate>>>;
29

            
30
/// Reads utmp entries to determine currently logged-in users.
31
pub fn generate_rwhod_user_entries(now: DateTime<Utc>) -> anyhow::Result<Vec<WhodUserEntry>> {
32
    Utmpx::iter_all_records()
33
        .filter(|entry| entry.is_user_process())
34
        .map(|entry| {
35
            let login_time = entry
36
                .login_time()
37
                .checked_to_utc()
38
                .and_then(|t| DateTime::<Utc>::from_timestamp_secs(t.unix_timestamp()))
39
                .ok_or_else(|| anyhow::anyhow!("Failed to convert login time to UTC"))?;
40

            
41
            let idle_time = stat(&Path::new("/dev").join(entry.tty_device()))
42
                .ok()
43
                .and_then(|st| {
44
                    let last_active = DateTime::<Utc>::from_timestamp_secs(st.st_atime)?;
45
                    Some((now - last_active).max(Duration::zero()))
46
                })
47
                .unwrap_or(Duration::zero());
48

            
49
            debug_assert!(
50
                idle_time.num_seconds() >= 0,
51
                "Idle time should never be negative"
52
            );
53

            
54
            Ok(WhodUserEntry::new(
55
                entry.tty_device(),
56
                entry.user(),
57
                login_time,
58
                idle_time,
59
            ))
60
        })
61
        .collect()
62
}
63

            
64
/// Generate a rwhod status update packet representing the current system state.
65
pub fn generate_rwhod_status_update() -> anyhow::Result<WhodStatusUpdate> {
66
    let sysinfo = sysinfo().unwrap();
67
    let load_average = sysinfo.load_average();
68
    let uptime = sysinfo.uptime();
69
    let hostname = gethostname()?.to_str().unwrap().to_string();
70
    let now = Utc::now().with_nanosecond(0).unwrap_or(Utc::now());
71

            
72
    let result = WhodStatusUpdate::new(
73
        now,
74
        None,
75
        hostname,
76
        (
77
            (load_average.0 * 100.0).abs() as i32,
78
            (load_average.1 * 100.0).abs() as i32,
79
            (load_average.2 * 100.0).abs() as i32,
80
        ),
81
        now - uptime,
82
        generate_rwhod_user_entries(now)?,
83
    );
84

            
85
    Ok(result)
86
}
87

            
88
/// A network interface paired with the address that rwhod packets for that
/// interface are sent to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RwhodSendTarget {
    /// Name of the network interface.
    pub name: String,

    /// Address to send rwhod packets to.
    ///
    /// This is either the broadcast address (for broadcast interfaces)
    /// or the point-to-point destination address (for point-to-point interfaces).
    pub addr: IpAddr,
}
98

            
99
/// Find all networks network interfaces suitable for rwhod communication.
100
pub fn determine_relevant_interfaces() -> anyhow::Result<Vec<RwhodSendTarget>> {
101
    getifaddrs().map_err(|e| e.into()).map(|ifaces| {
102
        ifaces
103
            // interface must be up
104
            .filter(|iface| iface.flags.contains(InterfaceFlags::IFF_UP))
105
            // interface must be broadcast or point-to-point
106
            .filter(|iface| {
107
                iface
108
                    .flags
109
                    .intersects(InterfaceFlags::IFF_BROADCAST | InterfaceFlags::IFF_POINTOPOINT)
110
            })
111
            .filter_map(|iface| {
112
                let neighbor_addr = if iface.flags.contains(InterfaceFlags::IFF_BROADCAST) {
113
                    iface.broadcast
114
                } else if iface.flags.contains(InterfaceFlags::IFF_POINTOPOINT) {
115
                    iface.destination
116
                } else {
117
                    None
118
                };
119

            
120
                match neighbor_addr {
121
                    Some(addr) => addr
122
                        .as_sockaddr_in()
123
                        .map(|sa| IpAddr::V4(sa.ip()))
124
                        .or_else(|| addr.as_sockaddr_in6().map(|sa| IpAddr::V6(sa.ip())))
125
                        .map(|ip_addr| RwhodSendTarget {
126
                            name: iface.interface_name,
127
                            addr: ip_addr,
128
                        }),
129
                    None => None,
130
                }
131
            })
132
            // keep first occurrence per interface name
133
            .scan(HashSet::new(), |seen, n| {
134
                if seen.insert(n.name.clone()) {
135
                    Some(n)
136
                } else {
137
                    None
138
                }
139
            })
140
            .collect::<Vec<RwhodSendTarget>>()
141
    })
142
}
143

            
144
pub async fn send_rwhod_packet_to_interface(
145
    socket: Arc<UdpSocket>,
146
    interface: &RwhodSendTarget,
147
    packet: &Whod,
148
) -> anyhow::Result<()> {
149
    let serialized_packet = packet.to_bytes();
150

            
151
    // TODO: the old rwhod daemon doesn't actually ever listen to ipv6, maybe remove it
152
    let target_addr = match interface.addr {
153
        IpAddr::V4(addr) => SocketAddr::new(IpAddr::V4(addr), RWHOD_BROADCAST_PORT),
154
        IpAddr::V6(addr) => SocketAddr::new(IpAddr::V6(addr), RWHOD_BROADCAST_PORT),
155
    };
156

            
157
    tracing::debug!(
158
        "Sending rwhod packet to interface {} at address {}",
159
        interface.name,
160
        target_addr
161
    );
162

            
163
    socket
164
        .send_to(&serialized_packet, &target_addr)
165
        .await
166
        .map_err(|e| anyhow::anyhow!("Failed to send rwhod packet: {}", e))?;
167

            
168
    Ok(())
169
}
170

            
171
pub async fn rwhod_packet_receiver_task(
172
    socket: Arc<UdpSocket>,
173
    whod_status_store: RwhodStatusStore,
174
) -> anyhow::Result<()> {
175
    let mut buf = [0u8; Whod::MAX_SIZE];
176

            
177
    loop {
178
        let (len, src) = socket.recv_from(&mut buf).await?;
179

            
180
        tracing::debug!("Received rwhod packet of length {} bytes from {}", len, src);
181

            
182
        if len < Whod::HEADER_SIZE {
183
            tracing::error!(
184
                "Received too short packet from {src}: {len} bytes (needs to be at least {} bytes)",
185
                Whod::HEADER_SIZE
186
            );
187
            continue;
188
        }
189

            
190
        let result = Whod::from_bytes(&buf[..len])
191
            .context("Failed to parse whod packet")?
192
            .try_into()
193
            .map(|mut status_update: WhodStatusUpdate| {
194
                let timestamp = Utc::now().with_nanosecond(0).unwrap_or(Utc::now());
195
                status_update.recvtime = Some(timestamp);
196
                status_update
197
            })
198
            .map_err(|e| anyhow::anyhow!("Invalid whod packet: {}", e));
199

            
200
        match result {
201
            Ok(status_update) => {
202
                tracing::debug!("Processed whod packet from {src}: {:?}", status_update);
203

            
204
                let mut store = whod_status_store.write().await;
205
                store.insert(status_update.hostname.clone(), status_update);
206
            }
207
            Err(err) => {
208
                tracing::error!("Error processing whod packet from {src}: {err}");
209
            }
210
        }
211
    }
212
}
213

            
214
pub async fn rwhod_packet_sender_task(
215
    socket: Arc<UdpSocket>,
216
    interfaces: Vec<RwhodSendTarget>,
217
) -> anyhow::Result<()> {
218
    let mut interval = interval(TokioDuration::from_secs(60));
219

            
220
    loop {
221
        interval.tick().await;
222

            
223
        let status_update = generate_rwhod_status_update()?;
224

            
225
        tracing::debug!("Generated rwhod packet: {:?}", status_update);
226

            
227
        let packet = status_update
228
            .try_into()
229
            .map_err(|e| anyhow::anyhow!("{}", e))?;
230

            
231
        for interface in &interfaces {
232
            if let Err(e) = send_rwhod_packet_to_interface(socket.clone(), interface, &packet).await
233
            {
234
                tracing::error!(
235
                    "Failed to send rwhod packet on interface {}: {}",
236
                    interface.name,
237
                    e
238
                );
239
            }
240
        }
241
    }
242
}