1
use nix::{ifaddrs::getifaddrs, net::if_::InterfaceFlags};
2
use std::{
3
    collections::HashSet,
4
    net::{IpAddr, SocketAddr},
5
    sync::Arc,
6
};
7
use tokio::{
8
    net::UdpSocket,
9
    time::{Duration as TokioDuration, interval},
10
};
11

            
12
use crate::{proto::Whod, server::rwhod::rwhod_status::generate_rwhod_status_update};
13

            
14
/// Destination UDP port for outgoing rwhod status packets.
pub const RWHOD_BROADCAST_PORT: u16 = 513;
16

            
17
/// A destination for outgoing rwhod packets, tied to one network interface.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RwhodSendTarget {
    /// Name of the network interface.
    pub name: String,

    /// Address to send rwhod packets to.
    /// This is either the broadcast address (for broadcast interfaces)
    /// or the point-to-point destination address (for point-to-point interfaces).
    pub addr: IpAddr,
}
27

            
28
/// Find all networks network interfaces suitable for rwhod communication.
29
1
pub fn determine_relevant_interfaces() -> anyhow::Result<Vec<RwhodSendTarget>> {
30
1
    getifaddrs().map_err(|e| e.into()).map(|ifaces| {
31
1
        ifaces
32
            // interface must be up
33
6
            .filter(|iface| iface.flags.contains(InterfaceFlags::IFF_UP))
34
            // interface must be broadcast or point-to-point
35
6
            .filter(|iface| {
36
6
                iface
37
6
                    .flags
38
6
                    .intersects(InterfaceFlags::IFF_BROADCAST | InterfaceFlags::IFF_POINTOPOINT)
39
6
            })
40
3
            .filter_map(|iface| {
41
3
                let neighbor_addr = if iface.flags.contains(InterfaceFlags::IFF_BROADCAST) {
42
3
                    iface.broadcast
43
                } else if iface.flags.contains(InterfaceFlags::IFF_POINTOPOINT) {
44
                    iface.destination
45
                } else {
46
                    None
47
                };
48

            
49
3
                match neighbor_addr {
50
2
                    Some(addr) => addr
51
2
                        .as_sockaddr_in()
52
2
                        .map(|sa| IpAddr::V4(sa.ip()))
53
2
                        .or_else(|| addr.as_sockaddr_in6().map(|sa| IpAddr::V6(sa.ip())))
54
2
                        .map(|ip_addr| RwhodSendTarget {
55
1
                            name: iface.interface_name,
56
1
                            addr: ip_addr,
57
1
                        }),
58
1
                    None => None,
59
                }
60
3
            })
61
            // keep first occurrence per interface name
62
1
            .scan(HashSet::new(), |seen, n| {
63
1
                if seen.insert(n.name.clone()) {
64
1
                    Some(n)
65
                } else {
66
                    None
67
                }
68
1
            })
69
1
            .collect::<Vec<RwhodSendTarget>>()
70
1
    })
71
1
}
72

            
73
pub async fn send_rwhod_packet_to_interface(
74
    socket: Arc<UdpSocket>,
75
    interface: &RwhodSendTarget,
76
    packet: &Whod,
77
) -> anyhow::Result<()> {
78
    let serialized_packet = packet.to_bytes();
79

            
80
    // TODO: the old rwhod daemon doesn't actually ever listen to ipv6, maybe remove it
81
    let target_addr = match interface.addr {
82
        IpAddr::V4(addr) => SocketAddr::new(IpAddr::V4(addr), RWHOD_BROADCAST_PORT),
83
        IpAddr::V6(addr) => SocketAddr::new(IpAddr::V6(addr), RWHOD_BROADCAST_PORT),
84
    };
85

            
86
    tracing::debug!(
87
        "Sending rwhod packet to interface {} at address {}",
88
        interface.name,
89
        target_addr
90
    );
91

            
92
    socket
93
        .send_to(&serialized_packet, &target_addr)
94
        .await
95
        .map_err(|e| anyhow::anyhow!("Failed to send rwhod packet: {}", e))?;
96

            
97
    Ok(())
98
}
99

            
100
pub async fn rwhod_packet_sender_task(
101
    socket: Arc<UdpSocket>,
102
    interfaces: Vec<RwhodSendTarget>,
103
) -> anyhow::Result<()> {
104
    let mut interval = interval(TokioDuration::from_secs(60));
105

            
106
    loop {
107
        interval.tick().await;
108

            
109
        let status_update = generate_rwhod_status_update()?;
110

            
111
        tracing::debug!("Generated rwhod packet: {:?}", status_update);
112

            
113
        let packet = status_update
114
            .try_into()
115
            .map_err(|e| anyhow::anyhow!("{}", e))?;
116

            
117
        for interface in &interfaces {
118
            if let Err(e) = send_rwhod_packet_to_interface(socket.clone(), interface, &packet).await
119
            {
120
                tracing::error!(
121
                    "Failed to send rwhod packet on interface {}: {}",
122
                    interface.name,
123
                    e
124
                );
125
            }
126
        }
127
    }
128
}
129

            
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test against the host's real interfaces. Enumeration must succeed, and
    /// each interface name may appear at most once, as `determine_relevant_interfaces`
    /// promises.
    #[test]
    fn test_determine_relevant_interfaces() {
        let interfaces = determine_relevant_interfaces().unwrap();

        let mut names = std::collections::HashSet::new();
        for interface in &interfaces {
            println!("Interface: {} Address: {}", interface.name, interface.addr);
            assert!(
                names.insert(interface.name.as_str()),
                "interface {} listed more than once",
                interface.name
            );
        }
    }
}