1
use std::{
2
    fs,
3
    os::unix::fs::FileTypeExt,
4
    path::{Path, PathBuf},
5
    sync::Arc,
6
    time::Duration,
7
};
8

            
9
use anyhow::{Context, anyhow};
10
use clap_verbosity_flag::{InfoLevel, Verbosity};
11
use nix::{
12
    libc::{EXIT_SUCCESS, exit},
13
    unistd::{AccessFlags, access},
14
};
15
use sqlx::mysql::MySqlPoolOptions;
16
use std::os::unix::net::UnixStream as StdUnixStream;
17
use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock};
18
use tracing_subscriber::prelude::*;
19

            
20
use crate::{
21
    core::{
22
        common::{DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executing_in_suid_sgid_mode},
23
        protocol::request_validation::GroupDenylist,
24
    },
25
    server::{
26
        authorization::read_and_parse_group_denylist,
27
        config::{MysqlConfig, ServerConfig},
28
        landlock::landlock_restrict_server,
29
        session_handler::{self, SessionId},
30
    },
31
};
32

            
33
/// Determine whether we will make a connection to an external server
34
/// or start an internal server with elevated privileges.
35
///
36
/// If neither is feasible, an error is returned.
37
fn will_connect_to_external_server(
38
    server_socket_path: Option<&PathBuf>,
39
    // This parameter is only used in suid-sgid-mode
40
    #[allow(unused_variables)] config_path: Option<&PathBuf>,
41
) -> anyhow::Result<bool> {
42
    if server_socket_path.is_some() {
43
        return Ok(true);
44
    }
45

            
46
    #[cfg(feature = "suid-sgid-mode")]
47
    if config_path.is_some() {
48
        return Ok(false);
49
    }
50

            
51
    if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
52
        return Ok(true);
53
    }
54

            
55
    #[cfg(feature = "suid-sgid-mode")]
56
    if fs::metadata(DEFAULT_CONFIG_PATH).is_ok() {
57
        return Ok(false);
58
    }
59

            
60
    #[cfg(feature = "suid-sgid-mode")]
61
    anyhow::bail!("No socket path or config path provided, and no default socket or config found");
62

            
63
    #[cfg(not(feature = "suid-sgid-mode"))]
64
    anyhow::bail!("No socket path provided, and no default socket found");
65
}
66

            
67
/// This function is used to bootstrap the connection to the server.
68
/// This can happen in two ways:
69
///
70
/// 1. If a socket path is provided, or exists in the default location,
71
///    the function will connect to the socket and authenticate with the
72
///    server to ensure that the server knows the uid of the client.
73
///
74
/// 2. If a config path is provided, or exists in the default location,
75
///    and the config is readable, the function will assume it is either
76
///    setuid or setgid, and will fork a child process to run the server
77
///    with the provided config. The server will exit silently by itself
78
///    when it is done, and this function will only return for the client
79
///    with the socket for the server.
80
///
81
/// If neither of these options are available, the function will fail.
82
///
83
/// Note that this function is also responsible for setting up logging,
84
/// because in the case of an internal server, we need to drop privileges
85
/// before we can initialize logging.
86
///
87
/// **WARNING:** This function may be run with elevated privileges.
88
pub fn bootstrap_server_connection_and_drop_privileges(
89
    server_socket_path: Option<PathBuf>,
90
    config: Option<PathBuf>,
91
    verbose: Verbosity<InfoLevel>,
92
) -> anyhow::Result<StdUnixStream> {
93
    if will_connect_to_external_server(server_socket_path.as_ref(), config.as_ref())? {
94
        assert!(
95
            !executing_in_suid_sgid_mode()?,
96
            "The executable should not be SUID or SGID when connecting to an external server"
97
        );
98

            
99
        let subscriber = tracing_subscriber::Registry::default()
100
            .with(verbose.tracing_level_filter())
101
            .with(
102
                tracing_subscriber::fmt::layer()
103
                    .with_line_number(cfg!(debug_assertions))
104
                    .with_target(cfg!(debug_assertions))
105
                    .with_thread_ids(false)
106
                    .with_thread_names(false),
107
            );
108

            
109
        tracing::subscriber::set_global_default(subscriber)
110
            .context("Failed to set global default tracing subscriber")?;
111

            
112
        connect_to_external_server(server_socket_path)
113
    } else if cfg!(feature = "suid-sgid-mode") {
114
        // NOTE: We need to be really careful with the code up until this point,
115
        //       as we might be running with elevated privileges.
116
        let server_connection = bootstrap_internal_server_and_drop_privs(config)?;
117

            
118
        let subscriber = tracing_subscriber::Registry::default()
119
            .with(verbose.tracing_level_filter())
120
            .with(
121
                tracing_subscriber::fmt::layer()
122
                    .with_line_number(cfg!(debug_assertions))
123
                    .with_target(cfg!(debug_assertions))
124
                    .with_thread_ids(false)
125
                    .with_thread_names(false),
126
            );
127

            
128
        tracing::subscriber::set_global_default(subscriber)
129
            .context("Failed to set global default tracing subscriber")?;
130

            
131
        Ok(server_connection)
132
    } else {
133
        anyhow::bail!("SUID/SGID support is not enabled, cannot start internal server");
134
    }
135
}
136

            
137
fn socket_path_is_ok(path: &Path) -> anyhow::Result<()> {
138
    fs::metadata(path)
139
        .context(format!("Failed to get metadata for {:?}", path))
140
        .and_then(|meta| {
141
            if !meta.file_type().is_socket() {
142
                anyhow::bail!("{:?} is not a unix socket", path);
143
            }
144

            
145
            access(path, AccessFlags::R_OK | AccessFlags::W_OK)
146
                .with_context(|| format!("Socket at {:?} is not readable/writable", path))?;
147

            
148
            Ok(())
149
        })
150
}
151

            
152
/// Connect to an already-running server over a Unix socket.
///
/// Uses `server_socket_path` if given, otherwise `DEFAULT_SOCKET_PATH` if
/// something exists there. The socket is validated with `socket_path_is_ok`
/// before connecting.
///
/// # Errors
///
/// Returns an error if no socket is available, if validation fails, or if the
/// connection cannot be established.
fn connect_to_external_server(
    server_socket_path: Option<PathBuf>,
) -> anyhow::Result<StdUnixStream> {
    if let Some(socket_path) = server_socket_path {
        tracing::trace!("Checking socket at {:?}", socket_path);
        socket_path_is_ok(&socket_path)?;

        tracing::debug!("Connecting to socket at {:?}", socket_path);
        return connect_to_socket_at(&socket_path);
    }

    let default_socket_path = Path::new(DEFAULT_SOCKET_PATH);
    if fs::metadata(default_socket_path).is_ok() {
        tracing::trace!("Checking socket at {:?}", DEFAULT_SOCKET_PATH);
        socket_path_is_ok(default_socket_path)?;

        tracing::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH);
        return connect_to_socket_at(default_socket_path);
    }

    anyhow::bail!(
        "No socket path provided, and no socket found at default location {DEFAULT_SOCKET_PATH}"
    );
}

/// Open a stream to the Unix socket at `socket_path`, mapping common I/O
/// failures to short user-facing messages.
fn connect_to_socket_at(socket_path: &Path) -> anyhow::Result<StdUnixStream> {
    StdUnixStream::connect(socket_path).map_err(|e| match e.kind() {
        std::io::ErrorKind::NotFound => anyhow::anyhow!("Socket not found"),
        std::io::ErrorKind::PermissionDenied => anyhow::anyhow!("Permission denied"),
        _ => anyhow::anyhow!("Failed to connect to socket: {e}"),
    })
}
189

            
190
// TODO: this function is security critical, it should be integration tested
191
//       in isolation.
192
/// Drop privileges to the real user and group of the process.
193
/// If the process is not running with elevated privileges, this function
194
/// is a no-op.
195
pub fn drop_privs() -> anyhow::Result<()> {
196
    tracing::debug!("Dropping privileges");
197
    let real_uid = nix::unistd::getuid();
198
    let real_gid = nix::unistd::getgid();
199

            
200
    nix::unistd::setuid(real_uid).context("Failed to drop privileges")?;
201
    nix::unistd::setgid(real_gid).context("Failed to drop privileges")?;
202

            
203
    debug_assert_eq!(nix::unistd::getuid(), real_uid);
204
    debug_assert_eq!(nix::unistd::getgid(), real_gid);
205

            
206
    tracing::debug!("Privileges dropped successfully");
207
    Ok(())
208
}
209

            
210
/// Bootstrap an internal server by forking a child process to run the server, giving it
211
/// the other half of a Unix socket pair to communicate with the client process.
212
fn bootstrap_internal_server_and_drop_privs(
213
    config_path: Option<PathBuf>,
214
) -> anyhow::Result<StdUnixStream> {
215
    if let Some(config_path) = config_path {
216
        if !executing_in_suid_sgid_mode()? {
217
            anyhow::bail!("Executable is not SUID/SGID - refusing to start internal sever");
218
        }
219

            
220
        // ensure config exists and is readable
221
        if fs::metadata(&config_path).is_err() {
222
            return Err(anyhow::anyhow!("Config file not found or not readable"));
223
        }
224

            
225
        tracing::debug!("Starting server with config at {:?}", config_path);
226
        let socket = invoke_server_with_config(&config_path)?;
227
        drop_privs()?;
228
        return Ok(socket);
229
    }
230

            
231
    let config_path = PathBuf::from(DEFAULT_CONFIG_PATH);
232
    if fs::metadata(&config_path).is_ok() {
233
        if !executing_in_suid_sgid_mode()? {
234
            anyhow::bail!("Executable is not SUID/SGID - refusing to start internal sever");
235
        }
236
        tracing::debug!("Starting server with default config at {:?}", config_path);
237
        let socket = invoke_server_with_config(&config_path)?;
238
        drop_privs()?;
239
        return Ok(socket);
240
    }
241

            
242
    anyhow::bail!("No config path provided, and no default config found");
243
}
244

            
245
// TODO: we should somehow ensure that the forked process is killed on completion,
246
//       just in case the client does not behave properly.
247
/// Fork a child process to run the server with the provided config.
248
/// The server will exit silently by itself when it is done, and this function
249
/// will only return for the client with the socket for the server.
250
fn invoke_server_with_config(config_path: &Path) -> anyhow::Result<StdUnixStream> {
251
    let (server_socket, client_socket) = StdUnixStream::pair()?;
252
    let unix_user = UnixUser::from_uid(nix::unistd::getuid().as_raw())?;
253

            
254
    match (unsafe { nix::unistd::fork() }).context("Failed to fork")? {
255
        nix::unistd::ForkResult::Parent { child } => {
256
            tracing::debug!("Forked child process with PID {}", child);
257
            Ok(client_socket)
258
        }
259
        nix::unistd::ForkResult::Child => {
260
            tracing::debug!("Running server in child process");
261

            
262
            landlock_restrict_server(Some(config_path))
263
                .context("Failed to apply Landlock restrictions to the server process")?;
264

            
265
            match run_forked_server(config_path, server_socket, &unix_user) {
266
                Err(e) => Err(e),
267
                Ok(()) => unreachable!(),
268
            }
269
        }
270
    }
271
}
272

            
273
/// Construct a `MySQL` connection pool that consists of exactly one connection.
274
///
275
/// This is used for the internal server in SUID/SGID mode, where the server session
276
/// only ever will get a single client.
277
async fn construct_single_connection_mysql_pool(
278
    config: &MysqlConfig,
279
) -> anyhow::Result<sqlx::MySqlPool> {
280
    let mysql_config = config.as_mysql_connect_options()?;
281

            
282
    let pool_opts = MySqlPoolOptions::new()
283
        .max_connections(1)
284
        .min_connections(1);
285

            
286
    config.log_connection_notice();
287

            
288
    let pool = match tokio::time::timeout(
289
        Duration::from_secs(config.timeout),
290
        pool_opts.connect_with(mysql_config),
291
    )
292
    .await
293
    {
294
        Ok(connection) => connection.context("Failed to connect to the database"),
295
        Err(_) => Err(anyhow!("Timed out after {} seconds", config.timeout))
296
            .context("Failed to connect to the database"),
297
    }?;
298

            
299
    Ok(pool)
300
}
301

            
302
/// Run a single server session in the forked process.
303
///
304
/// This function will not return, but will exit the process with a success code.
305
/// The function assumes that it's caller has already forked the process.
306
fn run_forked_server(
307
    config_path: &Path,
308
    server_socket: StdUnixStream,
309
    unix_user: &UnixUser,
310
) -> anyhow::Result<()> {
311
    let config = ServerConfig::read_config_from_path(config_path)
312
        .context("Failed to read server config in forked process")?;
313

            
314
    let group_denylist = if let Some(denylist_path) = &config.authorization.group_denylist_file {
315
        read_and_parse_group_denylist(denylist_path)
316
            .context("Failed to read and parse group denylist")?
317
    } else {
318
        GroupDenylist::new()
319
    };
320

            
321
    let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread()
322
        .enable_all()
323
        .build()
324
        .context("Failed to start Tokio runtime")?
325
        .block_on(async {
326
            let socket = TokioUnixStream::from_std(server_socket)?;
327
            let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?;
328
            let db_is_mariadb = {
329
                let mut conn = db_pool.acquire().await?;
330
                let version_row: String = sqlx::query_scalar("SELECT VERSION()")
331
                    .fetch_one(&mut *conn)
332
                    .await
333
                    .context("Failed to query MySQL version")?;
334
                version_row.to_lowercase().contains("mariadb")
335
            };
336

            
337
            let session_id = SessionId::new(0);
338
            let db_pool = Arc::new(RwLock::new(db_pool));
339
            session_handler::session_handler_with_unix_user(
340
                socket,
341
                session_id,
342
                unix_user,
343
                db_pool,
344
                db_is_mariadb,
345
                &group_denylist,
346
            )
347
            .await?;
348
            Ok(())
349
        });
350

            
351
    result?;
352

            
353
    unsafe {
354
        exit(EXIT_SUCCESS);
355
    }
356
}