1
use std::{os::fd::OwnedFd, time::Duration};
2

            
3
use anyhow::Context;
4
use itertools::Itertools;
5
use serde::{Deserialize, Serialize};
6
use tokio::time::timeout;
7
use zlink::{ReplyError, service::MethodReply};
8

            
9
use crate::{
10
    proto::{WhodStatusUpdate, WhodUserEntry, finger_protocol::FingerResponseUserEntry},
11
    server::{
12
        fingerd::{self, FingerRequestInfo, FingerRequestNetworking, finger_utmp_users},
13
        rwhod::RwhodStatusStore,
14
    },
15
};
16

            
17
// Types for 'no.ntnu.pvv.roowho2.rwhod'
18

            
19
#[zlink::proxy("no.ntnu.pvv.roowho2.rwhod")]
20
pub trait VarlinkRwhodClientProxy {
21
    async fn rwho(
22
        &mut self,
23
        all: bool,
24
    ) -> zlink::Result<Result<VarlinkRwhoResponse, VarlinkRwhodClientError>>;
25

            
26
    async fn ruptime(
27
        &mut self,
28
    ) -> zlink::Result<Result<VarlinkRuptimeResponse, VarlinkRwhodClientError>>;
29
}
30

            
31
#[derive(Debug, Deserialize)]
32
#[serde(tag = "method", content = "parameters")]
33
pub enum VarlinkRwhodClientRequest {
34
    #[serde(rename = "no.ntnu.pvv.roowho2.rwhod.Rwho")]
35
    Rwho {
36
        /// Retrieve all users, even those that have been idle for a long time.
37
        all: bool,
38
    },
39

            
40
    #[serde(rename = "no.ntnu.pvv.roowho2.rwhod.Ruptime")]
41
    Ruptime,
42
}
43

            
44
#[derive(Debug, Clone, PartialEq, Serialize)]
45
#[serde(untagged)]
46
pub enum VarlinkRwhodClientResponse {
47
    Rwho(VarlinkRwhoResponse),
48
    Ruptime(VarlinkRuptimeResponse),
49
}
50

            
51
pub type VarlinkRwhoResponse = Vec<(String, WhodUserEntry)>;
52
pub type VarlinkRuptimeResponse = Vec<WhodStatusUpdate>;
53

            
54
#[derive(Debug, Clone, PartialEq, ReplyError)]
55
#[zlink(interface = "no.ntnu.pvv.roowho2.rwhod")]
56
pub enum VarlinkRwhodClientError {
57
    InvalidRequest,
58
    TimedOut,
59
}
60

            
61
// Types for 'no.ntnu.pvv.roowho2.finger'
62

            
63
#[zlink::proxy("no.ntnu.pvv.roowho2.finger")]
64
pub trait VarlinkFingerClientProxy {
65
    async fn finger(
66
        &mut self,
67
        user_queries: Option<Vec<String>>,
68
        match_fullnames: bool,
69
        request_info: FingerRequestInfo,
70
        request_networking: FingerRequestNetworking,
71
        disable_user_account_db: bool,
72
        raw_remote_output: bool,
73
    ) -> zlink::Result<Result<VarlinkFingerResponse, VarlinkFingerClientError>>;
74
}
75

            
76
#[derive(Debug, Deserialize)]
77
#[serde(tag = "method", content = "parameters")]
78
pub enum VarlinkFingerClientRequest {
79
    #[serde(rename = "no.ntnu.pvv.roowho2.finger.Finger")]
80
    Finger {
81
        user_queries: Option<Vec<String>>,
82
        match_fullnames: bool,
83
        request_info: FingerRequestInfo,
84
        request_networking: FingerRequestNetworking,
85
        disable_user_account_db: bool,
86
        raw_remote_output: bool,
87
    },
88
}
89

            
90
#[derive(Debug, Serialize)]
91
#[serde(untagged)]
92
pub enum VarlinkFingerClientResponse {
93
    Finger(VarlinkFingerResponse),
94
}
95

            
96
/// Reply to `Finger`: one entry per user found.
pub type VarlinkFingerResponse = Vec<FingerResponseUserEntry>;
97

            
98
#[derive(Debug, Clone, PartialEq, ReplyError)]
99
#[zlink(interface = "no.ntnu.pvv.roowho2.finger")]
100
pub enum VarlinkFingerClientError {
101
    InvalidRequest,
102
    TimedOut,
103
}
104

            
105
// -------------------- Combined dispatch types and service implementation --------------------
106

            
107
#[derive(Debug, Deserialize)]
108
#[serde(untagged)]
109
#[allow(unused)]
110
pub enum VarlinkMethod {
111
    Rwhod(VarlinkRwhodClientRequest),
112
    Finger(VarlinkFingerClientRequest),
113
}
114

            
115
#[derive(Debug, Serialize)]
116
#[serde(untagged)]
117
#[allow(unused)]
118
pub enum VarlinkReply {
119
    Rwhod(VarlinkRwhodClientResponse),
120
    Finger(VarlinkFingerClientResponse),
121
}
122

            
123
#[derive(Debug, Clone, PartialEq, Serialize)]
124
#[serde(untagged)]
125
#[allow(unused)]
126
pub enum VarlinkReplyError {
127
    Rwhod(VarlinkRwhodClientError),
128
    Finger(VarlinkFingerClientError),
129
}
130

            
131
#[derive(Debug, Clone)]
132
pub struct VarlinkRoowhoo2ClientServer {
133
    whod_status_store: RwhodStatusStore,
134
}
135

            
136
impl VarlinkRoowhoo2ClientServer {
137
    pub fn new(whod_status_store: RwhodStatusStore) -> Self {
138
        Self { whod_status_store }
139
    }
140
}
141

            
142
impl VarlinkRoowhoo2ClientServer {
143
    // TODO: handle 'all' parameter
144
    async fn handle_rwho_request(&self, _all: bool) -> VarlinkRwhoResponse {
145
        tracing::debug!(all = _all, "Handling Rwho request");
146
        let store = self.whod_status_store.read().await;
147

            
148
        let mut all_user_entries = Vec::with_capacity(store.len());
149
        for status_update in store.values() {
150
            all_user_entries.extend_from_slice(
151
                &status_update
152
                    .users
153
                    .iter()
154
                    .map(|user| (status_update.hostname.clone(), user.clone()))
155
                    .collect::<Vec<(String, WhodUserEntry)>>(),
156
            );
157
        }
158

            
159
        all_user_entries
160
    }
161

            
162
    async fn handle_ruptime_request(&self) -> VarlinkRuptimeResponse {
163
        tracing::debug!("Handling Ruptime request");
164
        let store = self.whod_status_store.read().await;
165
        store.values().cloned().collect()
166
    }
167

            
168
    async fn handle_finger_request(
169
        &self,
170
        user_queries: Option<Vec<String>>,
171
        match_fullnames: bool,
172
        request_info: FingerRequestInfo,
173
        _request_networking: FingerRequestNetworking,
174
        _disable_user_account_db: bool,
175
        _raw_remote_output: bool,
176
    ) -> VarlinkFingerResponse {
177
        tracing::debug!(
178
          user_queries = ?user_queries,
179
          match_fullnames = match_fullnames,
180
          request_info = ?request_info,
181
          "Handling Finger request",
182
        );
183
        match user_queries {
184
            Some(usernames) => usernames
185
                .into_iter()
186
                .flat_map::<Vec<_>, _>(|username| {
187
                    fingerd::search_for_user(&username, match_fullnames, &request_info)
188
                        .into_iter()
189
                        .map(|res| (username.clone(), res))
190
                        .collect()
191
                })
192
                .dedup_by(|a, b| match (&a.1, &b.1) {
193
                    (Ok(user_a), Ok(user_b)) => user_a.username == user_b.username,
194
                    _ => false,
195
                })
196
                .filter_map(|(username, user)| match user {
197
                    Ok(user_info) => Some(user_info),
198
                    Err(err) => {
199
                        tracing::error!(
200
                            "Error retrieving local user information for '{}': {}",
201
                            username,
202
                            err
203
                        );
204
                        None
205
                    }
206
                })
207
                .map(Box::new)
208
                .map(FingerResponseUserEntry::Structured)
209
                .collect(),
210
            None => finger_utmp_users(&request_info)
211
                .into_iter()
212
                .filter_map(|res| match res {
213
                    Ok(user_info) => Some(user_info),
214
                    Err(err) => {
215
                        tracing::error!("Error retrieving local user information: {}", err);
216
                        None
217
                    }
218
                })
219
                .map(Box::new)
220
                .map(FingerResponseUserEntry::Structured)
221
                .collect(),
222
        }
223
    }
224
}
225

            
226
impl zlink::Service<zlink::unix::Stream> for VarlinkRoowhoo2ClientServer {
227
    type MethodCall<'de> = VarlinkMethod;
228
    type ReplyParams<'se> = VarlinkReply;
229
    type ReplyStreamParams = ();
230
    type ReplyStream = futures_util::stream::Empty<(zlink::Reply<()>, Vec<OwnedFd>)>;
231
    type ReplyError<'se> = VarlinkReplyError;
232

            
233
    async fn handle<'service>(
234
        &'service mut self,
235
        call: &'service zlink::Call<Self::MethodCall<'_>>,
236
        _conn: &mut zlink::Connection<zlink::unix::Stream>,
237
        _fds: Vec<std::os::fd::OwnedFd>,
238
    ) -> zlink::service::HandleResult<
239
        Self::ReplyParams<'service>,
240
        Self::ReplyStream,
241
        Self::ReplyError<'service>,
242
    > {
243
        match call.method() {
244
            VarlinkMethod::Rwhod(VarlinkRwhodClientRequest::Rwho { all }) => {
245
                let result =
246
                    match timeout(Duration::from_secs(2), self.handle_rwho_request(*all)).await {
247
                        Ok(response) => response,
248
                        Err(_) => {
249
                            tracing::error!("Rwho request timed out after 2 seconds");
250
                            return (
251
                                MethodReply::Error(VarlinkReplyError::Rwhod(
252
                                    VarlinkRwhodClientError::TimedOut,
253
                                )),
254
                                Default::default(),
255
                            );
256
                        }
257
                    };
258

            
259
                (
260
                    MethodReply::Single(Some(VarlinkReply::Rwhod(
261
                        VarlinkRwhodClientResponse::Rwho(result),
262
                    ))),
263
                    Default::default(),
264
                )
265
            }
266
            VarlinkMethod::Rwhod(VarlinkRwhodClientRequest::Ruptime) => {
267
                let result =
268
                    match timeout(Duration::from_secs(2), self.handle_ruptime_request()).await {
269
                        Ok(response) => response,
270
                        Err(_) => {
271
                            tracing::error!("Ruptime request timed out after 2 seconds");
272
                            return (
273
                                MethodReply::Error(VarlinkReplyError::Rwhod(
274
                                    VarlinkRwhodClientError::TimedOut,
275
                                )),
276
                                Default::default(),
277
                            );
278
                        }
279
                    };
280

            
281
                (
282
                    MethodReply::Single(Some(VarlinkReply::Rwhod(
283
                        VarlinkRwhodClientResponse::Ruptime(result),
284
                    ))),
285
                    Default::default(),
286
                )
287
            }
288
            VarlinkMethod::Finger(VarlinkFingerClientRequest::Finger {
289
                user_queries,
290
                match_fullnames,
291
                request_info,
292
                request_networking,
293
                disable_user_account_db,
294
                raw_remote_output,
295
            }) => {
296
                let result = match timeout(
297
                    Duration::from_secs(2),
298
                    self.handle_finger_request(
299
                        user_queries.clone(),
300
                        *match_fullnames,
301
                        request_info.clone(),
302
                        request_networking.clone(),
303
                        *disable_user_account_db,
304
                        *raw_remote_output,
305
                    ),
306
                )
307
                .await
308
                {
309
                    Ok(response) => response,
310
                    Err(_) => {
311
                        tracing::error!("Finger request timed out after 2 seconds");
312
                        return (
313
                            MethodReply::Error(VarlinkReplyError::Finger(
314
                                VarlinkFingerClientError::TimedOut,
315
                            )),
316
                            Default::default(),
317
                        );
318
                    }
319
                };
320

            
321
                (
322
                    MethodReply::Single(Some(VarlinkReply::Finger(
323
                        VarlinkFingerClientResponse::Finger(result),
324
                    ))),
325
                    Default::default(),
326
                )
327
            }
328
        }
329
    }
330
}
331

            
332
pub async fn varlink_client_server_task(
333
    socket: zlink::unix::Listener,
334
    whod_status_store: RwhodStatusStore,
335
) -> anyhow::Result<()> {
336
    let service = VarlinkRoowhoo2ClientServer::new(whod_status_store);
337

            
338
    let server = zlink::Server::new(socket, service);
339

            
340
    tracing::info!("Starting Rwhod client API server");
341

            
342
    server
343
        .run()
344
        .await
345
        .context("Rwhod client API server failed")?;
346

            
347
    Ok(())
348
}