1
use std::array;
2

            
3
use bytes::{Buf, BufMut, BytesMut};
4
use chrono::{DateTime, Duration, Utc};
5
use serde::{Deserialize, Serialize};
6

            
7
/// Classic C struct for utmp data for a single user session.
8
///
9
/// This struct is used in the rwhod protocol by being interpreted as raw bytes to be sent over UDP.
10
#[derive(Debug, Clone, PartialEq, Eq)]
11
#[repr(C)]
12
pub struct Outmp {
13
    /// tty name
14
    pub out_line: [u8; Self::MAX_TTY_NAME_LEN],
15
    /// user id
16
    pub out_name: [u8; Self::MAX_USER_ID_LEN],
17
    /// time on
18
    pub out_time: i32,
19
}
20

            
21
impl Outmp {
22
    pub const MAX_TTY_NAME_LEN: usize = 8;
23
    pub const MAX_USER_ID_LEN: usize = 8;
24
}
25

            
26
/// Classic C struct for a single user session.
27
///
28
/// This struct is used in the rwhod protocol by being interpreted as raw bytes to be sent over UDP.
29
#[derive(Debug, Clone, PartialEq, Eq)]
30
#[repr(C)]
31
pub struct Whoent {
32
    /// active tty info
33
    pub we_utmp: Outmp,
34
    /// tty idle time
35
    pub we_idle: i32,
36
}
37

            
38
impl Whoent {
39
    pub const SIZE: usize = std::mem::size_of::<Self>();
40

            
41
43
    fn zeroed() -> Self {
42
43
        Self {
43
43
            we_utmp: Outmp {
44
43
                out_line: [0u8; Outmp::MAX_TTY_NAME_LEN],
45
43
                out_name: [0u8; Outmp::MAX_USER_ID_LEN],
46
43
                out_time: 0,
47
43
            },
48
43
            we_idle: 0,
49
43
        }
50
43
    }
51

            
52
6
    fn is_zeroed(&self) -> bool {
53
20
        self.we_utmp.out_line.iter().all(|&b| b == 0)
54
16
            && self.we_utmp.out_name.iter().all(|&b| b == 0)
55
2
            && self.we_utmp.out_time == 0
56
2
            && self.we_idle == 0
57
6
    }
58
}
59

            
60
/// Classic C struct for a rwhod status update.
61
///
62
/// This struct is used in the rwhod protocol by being interpreted as raw bytes to be sent over UDP.
63
#[derive(Debug, Clone, PartialEq, Eq)]
64
#[repr(C)]
65
pub struct Whod {
66
    /// protocol version
67
    pub wd_vers: u8,
68
    /// packet type, see below
69
    pub wd_type: u8,
70
    pub wd_pad: [u8; 2],
71
    /// time stamp by sender
72
    pub wd_sendtime: i32,
73
    /// time stamp applied by receiver
74
    pub wd_recvtime: i32,
75
    /// host's name
76
    pub wd_hostname: [u8; Self::MAX_HOSTNAME_LEN],
77
    /// load average as in uptime
78
    pub wd_loadav: [i32; 3],
79
    /// time system booted
80
    pub wd_boottime: i32,
81
    pub wd_we: [Whoent; Self::MAX_WHOENTRIES],
82
}
83

            
84
impl Whod {
85
    pub const HEADER_SIZE: usize = 1 + 1 + 2 + 4 + 4 + Self::MAX_HOSTNAME_LEN + 4 * 3 + 4;
86
    pub const MAX_SIZE: usize = std::mem::size_of::<Self>();
87

            
88
    pub const MAX_HOSTNAME_LEN: usize = 32;
89
    pub const MAX_WHOENTRIES: usize = 1024 / std::mem::size_of::<Whoent>();
90

            
91
    pub const WHODVERSION: u8 = 1;
92

            
93
    // NOTE: there was probably meant to be more packet types, but only status is defined.
94
    pub const WHODTYPE_STATUS: u8 = 1;
95

            
96
1
    pub fn new(
97
1
        sendtime: i32,
98
1
        recvtime: i32,
99
1
        hostname: [u8; Self::MAX_HOSTNAME_LEN],
100
1
        loadav: [i32; 3],
101
1
        boottime: i32,
102
1
        whoentries: [Whoent; Self::MAX_WHOENTRIES],
103
1
    ) -> Self {
104
1
        debug_assert!(
105
            whoentries
106
                .iter()
107
                .skip_while(|entry| !entry.is_zeroed())
108
                .all(|entry| entry.is_zeroed())
109
        );
110

            
111
1
        Self {
112
1
            wd_vers: Self::WHODVERSION,
113
1
            wd_type: Self::WHODTYPE_STATUS,
114
1
            wd_pad: [0u8; 2],
115
1
            wd_sendtime: sendtime,
116
1
            wd_recvtime: recvtime,
117
1
            wd_hostname: hostname,
118
1
            wd_loadav: loadav,
119
1
            wd_boottime: boottime,
120
1
            wd_we: whoentries,
121
1
        }
122
1
    }
123

            
124
1
    pub fn to_bytes(&self) -> Vec<u8> {
125
1
        let mut buf = BytesMut::with_capacity(Whod::MAX_SIZE);
126
1
        buf.put_u8(self.wd_vers);
127
1
        buf.put_u8(self.wd_type);
128
1
        buf.put_slice(&self.wd_pad);
129
1
        buf.put_i32(self.wd_sendtime);
130
1
        buf.put_i32(self.wd_recvtime);
131
1
        buf.put_slice(&self.wd_hostname);
132
1
        buf.put_i32(self.wd_loadav[0]);
133
1
        buf.put_i32(self.wd_loadav[1]);
134
1
        buf.put_i32(self.wd_loadav[2]);
135
1
        buf.put_i32(self.wd_boottime);
136

            
137
3
        for whoent in self.wd_we.iter().take_while(|entry| !entry.is_zeroed()) {
138
2
            buf.put_slice(&whoent.we_utmp.out_line);
139
2
            buf.put_slice(&whoent.we_utmp.out_name);
140
2
            buf.put_i32(whoent.we_utmp.out_time);
141
2
            buf.put_i32(whoent.we_idle);
142
2
        }
143

            
144
1
        buf.to_vec()
145
1
    }
146

            
147
6
    pub fn from_bytes(input: &[u8]) -> anyhow::Result<Self> {
148
6
        if input.len() < Self::HEADER_SIZE {
149
1
            return Err(anyhow::anyhow!(
150
1
                "Not enough bytes to parse packet header: {} < {}",
151
1
                input.len(),
152
1
                Self::HEADER_SIZE
153
1
            ));
154
5
        }
155

            
156
5
        if input.len() > Self::MAX_SIZE {
157
1
            return Err(anyhow::anyhow!(
158
1
                "Too many bytes to parse packet: {} > {}",
159
1
                input.len(),
160
1
                Self::MAX_SIZE
161
1
            ));
162
4
        }
163

            
164
4
        if !(input.len() - Self::HEADER_SIZE).is_multiple_of(Whoent::SIZE) {
165
1
            return Err(anyhow::anyhow!(
166
1
                "Invalid packet length: {} (not aligned with struct sizes, should be {} + N * {})",
167
1
                input.len(),
168
1
                Self::HEADER_SIZE,
169
1
                Whoent::SIZE,
170
1
            ));
171
3
        }
172

            
173
3
        let mut bytes = bytes::Bytes::copy_from_slice(input);
174

            
175
3
        let wd_vers = bytes.get_u8();
176
3
        if wd_vers != Self::WHODVERSION {
177
1
            return Err(anyhow::anyhow!(
178
1
                "Unsupported whod protocol version: {}",
179
1
                wd_vers
180
1
            ));
181
2
        }
182

            
183
2
        let wd_type = bytes.get_u8();
184
2
        if wd_type != Self::WHODTYPE_STATUS {
185
1
            return Err(anyhow::anyhow!("Unsupported whod packet type: {}", wd_type));
186
1
        }
187

            
188
1
        bytes.advance(2); // skip wd_pad
189

            
190
1
        let wd_sendtime = bytes.get_i32();
191
1
        let wd_recvtime = bytes.get_i32();
192
1
        let mut wd_hostname = [0u8; Self::MAX_HOSTNAME_LEN];
193
1
        bytes.copy_to_slice(&mut wd_hostname);
194
1
        let wd_loadav = [bytes.get_i32(), bytes.get_i32(), bytes.get_i32()];
195
1
        let wd_boottime = bytes.get_i32();
196

            
197
1
        debug_assert!(bytes.remaining() + Self::HEADER_SIZE == input.len());
198

            
199
42
        let mut wd_we = array::from_fn(|_| Whoent::zeroed());
200

            
201
2
        for (byte_chunk, whoent) in bytes.chunks_exact(Whoent::SIZE).zip(wd_we.iter_mut()) {
202
2
            let mut chunk_bytes = bytes::Bytes::copy_from_slice(byte_chunk);
203
2

            
204
2
            let mut out_line = [0u8; Outmp::MAX_TTY_NAME_LEN];
205
2
            chunk_bytes.copy_to_slice(&mut out_line);
206
2
            let mut out_name = [0u8; Outmp::MAX_USER_ID_LEN];
207
2
            chunk_bytes.copy_to_slice(&mut out_name);
208
2
            let out_time = chunk_bytes.get_i32();
209
2

            
210
2
            let we_utmp = Outmp {
211
2
                out_line,
212
2
                out_name,
213
2
                out_time,
214
2
            };
215
2
            let we_idle = chunk_bytes.get_i32();
216
2

            
217
2
            *whoent = Whoent { we_utmp, we_idle };
218
2
        }
219

            
220
1
        let result = Whod::new(
221
1
            wd_sendtime,
222
1
            wd_recvtime,
223
1
            wd_hostname,
224
1
            wd_loadav,
225
1
            wd_boottime,
226
1
            wd_we,
227
        );
228

            
229
1
        Ok(result)
230
6
    }
231
}
232

            
233
// ------------------------------------------------
234

            
235
/// Load averages over the last 1, 5, and 15 minutes, as reported by `uptime`.
///
/// All values are multiplied by 100 (e.g. a load of `0.25` is stored as `25`).
pub type LoadAverage = (i32, i32, i32);
238

            
239
/// High-level representation of a rwhod status update.
240
///
241
/// This struct is intended for easier use in Rust code, with proper types and dynamic arrays.
242
/// It can be converted to and from the low-level [`Whod`] struct used for network transmission.
243
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
244
pub struct WhodStatusUpdate {
245
    // NOTE: there is only one defined packet type, so we just omit it here
246
    /// Timestamp by sender
247
    pub sendtime: DateTime<Utc>,
248

            
249
    /// Timestamp applied by receiver
250
    pub recvtime: Option<DateTime<Utc>>,
251

            
252
    /// Name of the host sending the status update (max 32 characters)
253
    pub hostname: String,
254

            
255
    /// load average over 5, 10, and 15 minutes multiplied by 100
256
    pub load_average: LoadAverage,
257

            
258
    /// Which time the system was booted
259
    pub boot_time: DateTime<Utc>,
260

            
261
    /// List of users currently logged in to the host (max 42 entries)
262
    pub users: Vec<WhodUserEntry>,
263
}
264

            
265
impl WhodStatusUpdate {
266
1
    pub fn new(
267
1
        sendtime: DateTime<Utc>,
268
1
        recvtime: Option<DateTime<Utc>>,
269
1
        hostname: String,
270
1
        load_average: LoadAverage,
271
1
        boot_time: DateTime<Utc>,
272
1
        users: Vec<WhodUserEntry>,
273
1
    ) -> Self {
274
1
        Self {
275
1
            sendtime,
276
1
            recvtime,
277
1
            hostname,
278
1
            load_average,
279
1
            boot_time,
280
1
            users,
281
1
        }
282
1
    }
283
}
284

            
285
/// High-level representation of a single user session in a rwhod status update.
286
///
287
/// This struct is intended for easier use in Rust code, with proper types.
288
/// It can be converted to and from the low-level [`Whoent`] struct used for network transmission.
289
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
290
pub struct WhodUserEntry {
291
    /// TTY name (max 8 characters)
292
    pub tty: String,
293

            
294
    /// User ID (max 8 characters)
295
    pub user_id: String,
296

            
297
    /// Time when the user logged in
298
    pub login_time: DateTime<Utc>,
299

            
300
    /// How long since the user last typed on the TTY
301
    pub idle_time: Duration,
302
}
303

            
304
impl WhodUserEntry {
305
2
    pub fn new(
306
2
        tty: String,
307
2
        user_id: String,
308
2
        login_time: DateTime<Utc>,
309
2
        idle_time: Duration,
310
2
    ) -> Self {
311
2
        Self {
312
2
            tty,
313
2
            user_id,
314
2
            login_time,
315
2
            idle_time,
316
2
        }
317
2
    }
318
}
319

            
320
impl TryFrom<Whoent> for WhodUserEntry {
321
    type Error = String;
322

            
323
2
    fn try_from(value: Whoent) -> Result<Self, Self::Error> {
324
2
        let tty_end = value
325
2
            .we_utmp
326
2
            .out_line
327
2
            .iter()
328
10
            .position(|&c| c == 0)
329
2
            .unwrap_or(value.we_utmp.out_line.len());
330
2
        let tty = String::from_utf8(value.we_utmp.out_line[..tty_end].to_vec())
331
2
            .map_err(|e| format!("Invalid UTF-8 in TTY name: {}", e))?;
332

            
333
2
        let user_id_end = value
334
2
            .we_utmp
335
2
            .out_name
336
2
            .iter()
337
12
            .position(|&c| c == 0)
338
2
            .unwrap_or(value.we_utmp.out_name.len());
339
2
        let user_id = String::from_utf8(value.we_utmp.out_name[..user_id_end].to_vec())
340
2
            .map_err(|e| format!("Invalid UTF-8 in user ID: {}", e))?;
341

            
342
2
        let login_time = DateTime::from_timestamp_secs(value.we_utmp.out_time as i64).ok_or(
343
2
            format!("Invalid login time timestamp: {}", value.we_utmp.out_time),
344
        )?;
345

            
346
2
        Ok(WhodUserEntry {
347
2
            tty,
348
2
            user_id,
349
2
            login_time,
350
2
            idle_time: Duration::seconds(value.we_idle as i64),
351
2
        })
352
2
    }
353
}
354

            
355
impl TryFrom<Whod> for WhodStatusUpdate {
356
    type Error = String;
357

            
358
1
    fn try_from(value: Whod) -> Result<Self, Self::Error> {
359
1
        if value.wd_vers != Whod::WHODVERSION {
360
            return Err(format!(
361
                "Unsupported whod protocol version: {}",
362
                value.wd_vers
363
            ));
364
1
        }
365

            
366
1
        let sendtime = DateTime::from_timestamp_secs(value.wd_sendtime as i64).ok_or(format!(
367
            "Invalid send time timestamp: {}",
368
            value.wd_sendtime
369
        ))?;
370

            
371
1
        let recvtime = if value.wd_recvtime == 0 {
372
            None
373
        } else {
374
            Some(
375
1
                DateTime::from_timestamp_secs(value.wd_recvtime as i64).ok_or(format!(
376
                    "Invalid receive time timestamp: {}",
377
                    value.wd_recvtime
378
                ))?,
379
            )
380
        };
381

            
382
1
        let hostname_end = value
383
1
            .wd_hostname
384
1
            .iter()
385
9
            .position(|&c| c == 0)
386
1
            .unwrap_or(value.wd_hostname.len());
387
1
        let hostname = String::from_utf8(value.wd_hostname[..hostname_end].to_vec())
388
1
            .map_err(|e| format!("Invalid UTF-8 in hostname: {}", e))?;
389

            
390
1
        let boot_time = DateTime::from_timestamp_secs(value.wd_boottime as i64).ok_or(format!(
391
            "Invalid boot time timestamp: {}",
392
            value.wd_boottime
393
        ))?;
394

            
395
1
        let users = value
396
1
            .wd_we
397
1
            .iter()
398
3
            .take_while(|whoent| !whoent.is_zeroed())
399
1
            .cloned()
400
1
            .map(WhodUserEntry::try_from)
401
1
            .collect::<Result<Vec<WhodUserEntry>, String>>()?;
402

            
403
1
        Ok(WhodStatusUpdate {
404
1
            sendtime,
405
1
            recvtime,
406
1
            hostname,
407
1
            load_average: value.wd_loadav.into(),
408
1
            boot_time,
409
1
            users,
410
1
        })
411
1
    }
412
}
413

            
414
impl TryFrom<WhodUserEntry> for Whoent {
415
    type Error = String;
416

            
417
2
    fn try_from(value: WhodUserEntry) -> Result<Self, Self::Error> {
418
2
        let mut out_line = [0u8; Outmp::MAX_TTY_NAME_LEN];
419
2
        let tty_bytes = value.tty.as_bytes();
420
2
        out_line[..tty_bytes.len().min(Outmp::MAX_TTY_NAME_LEN)].copy_from_slice(tty_bytes);
421

            
422
2
        let mut out_name = [0u8; Outmp::MAX_USER_ID_LEN];
423
2
        let user_id_bytes = value.user_id.as_bytes();
424
2
        out_name[..user_id_bytes.len().min(Outmp::MAX_USER_ID_LEN)].copy_from_slice(user_id_bytes);
425

            
426
2
        let out_time = value
427
2
            .login_time
428
2
            .timestamp()
429
2
            .clamp(i32::MIN as i64, i32::MAX as i64) as i32;
430

            
431
2
        let we_idle = value
432
2
            .idle_time
433
2
            .num_seconds()
434
2
            .clamp(i32::MIN as i64, i32::MAX as i64) as i32;
435

            
436
2
        Ok(Whoent {
437
2
            we_utmp: Outmp {
438
2
                out_line,
439
2
                out_name,
440
2
                out_time,
441
2
            },
442
2
            we_idle,
443
2
        })
444
2
    }
445
}
446

            
447
impl TryFrom<WhodStatusUpdate> for Whod {
448
    type Error = String;
449

            
450
1
    fn try_from(value: WhodStatusUpdate) -> Result<Self, Self::Error> {
451
1
        let mut wd_hostname = [0u8; Whod::MAX_HOSTNAME_LEN];
452
1
        let hostname_bytes = value.hostname.as_bytes();
453
1
        wd_hostname[..hostname_bytes.len().min(Whod::MAX_HOSTNAME_LEN)]
454
1
            .copy_from_slice(hostname_bytes);
455

            
456
1
        let wd_sendtime = value
457
1
            .sendtime
458
1
            .timestamp()
459
1
            .clamp(i32::MIN as i64, i32::MAX as i64) as i32;
460

            
461
1
        let wd_recvtime = value.recvtime.map_or(0, |dt| {
462
1
            dt.timestamp().clamp(i32::MIN as i64, i32::MAX as i64) as i32
463
1
        });
464

            
465
1
        let wd_boottime = value
466
1
            .boot_time
467
1
            .timestamp()
468
1
            .clamp(i32::MIN as i64, i32::MAX as i64) as i32;
469

            
470
1
        let wd_we = value
471
1
            .users
472
1
            .into_iter()
473
1
            .map(Whoent::try_from)
474
1
            .chain(std::iter::repeat(Ok(Whoent::zeroed())))
475
1
            .take(Whod::MAX_WHOENTRIES)
476
1
            .collect::<Result<Vec<Whoent>, String>>()?
477
1
            .try_into()
478
1
            .expect("Length mismatch, this should never happen");
479

            
480
1
        Ok(Whod {
481
1
            wd_vers: Whod::WHODVERSION,
482
1
            wd_type: Whod::WHODTYPE_STATUS,
483
1
            wd_pad: [0u8; 2],
484
1
            wd_sendtime,
485
1
            wd_recvtime,
486
1
            wd_hostname,
487
1
            wd_loadav: value.load_average.into(),
488
1
            wd_boottime,
489
1
            wd_we,
490
1
        })
491
1
    }
492
}
493

            
494
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;

    /// Shorthand for a UTC timestamp with zero seconds.
    fn utc(year: i32, month: u32, day: u32, hour: u32, minute: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(year, month, day, hour, minute, 0).unwrap()
    }

    /// A status update survives `WhodStatusUpdate -> Whod -> bytes -> Whod -> WhodStatusUpdate`.
    #[test]
    fn test_whod_serialization_roundtrip() {
        let users = vec![
            WhodUserEntry::new(
                "tty1".to_string(),
                "user1".to_string(),
                utc(2024, 6, 1, 10, 0),
                Duration::minutes(5),
            ),
            WhodUserEntry::new(
                "tty2".to_string(),
                "user2".to_string(),
                utc(2024, 6, 1, 11, 0),
                Duration::minutes(10),
            ),
        ];
        let original_status = WhodStatusUpdate::new(
            utc(2024, 6, 1, 12, 0),
            Some(utc(2024, 6, 1, 12, 5)),
            "testhost".to_string(),
            (25, 20, 18),
            utc(2024, 5, 31, 8, 0),
            users,
        );

        let wire = Whod::try_from(original_status.clone())
            .expect("Conversion to Whod failed")
            .to_bytes();
        let decoded = Whod::from_bytes(&wire).expect("Parsing from bytes failed");
        let final_status =
            WhodStatusUpdate::try_from(decoded).expect("Conversion from Whod failed");

        assert_eq!(original_status, final_status);
    }

    /// Malformed packets are rejected rather than parsed.
    #[test]
    fn test_parser_invalid_bytes() {
        let rejects = |packet: &[u8]| Whod::from_bytes(packet).is_err();

        // Too short
        assert!(rejects(&[0u8; Whod::HEADER_SIZE - 1]));

        // Too long
        assert!(rejects(&[0u8; Whod::MAX_SIZE + 1]));

        // Misaligned length
        assert!(rejects(&[0u8; Whod::HEADER_SIZE + 1]));

        // Invalid version
        let mut packet = [0u8; Whod::HEADER_SIZE];
        packet[0] = 99;
        assert!(rejects(&packet));

        // Invalid packet type (with a valid version)
        packet[0] = Whod::WHODVERSION;
        packet[1] = 99;
        assert!(rejects(&packet));
    }
}