Lines
0 %
Functions
use std::{
collections::BTreeSet,
fs,
os::unix::{io::FromRawFd, net::UnixListener as StdUnixListener},
path::PathBuf,
sync::Arc,
time::Duration,
};
use anyhow::Context;
use futures_util::{SinkExt, StreamExt};
use indoc::concatdoc;
use tokio::{
net::{UnixListener as TokioUnixListener, UnixStream as TokioUnixStream},
time::interval,
use sqlx::MySqlConnection;
use sqlx::prelude::*;
use crate::core::protocol::SetPasswordError;
use crate::server::sql::database_operations::list_databases;
use crate::{
core::{
common::{DEFAULT_SOCKET_PATH, UnixUser},
protocol::{
Request, Response, ServerToClientMessageStream, create_server_to_client_message_stream,
},
server::{
config::{ServerConfig, create_mysql_connection_from_config},
sql::{
database_operations::{create_databases, drop_databases, list_all_databases_for_user},
database_privilege_operations::{
apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
user_operations::{
create_database_users, drop_database_users, list_all_database_users_for_unix_user,
list_database_users, lock_database_users, set_password_for_database_user,
unlock_database_users,
// TODO: consider using a connection pool
/// Binds a Unix socket at `socket_path` and serves connections on it.
///
/// When `socket_path` is `None`, `DEFAULT_SOCKET_PATH` is used. The parent
/// directory of the socket is created if it does not exist, and any file
/// already present at the socket path is removed before binding. The bound
/// listener is then handed to `listen_for_incoming_connections_with_listener`.
///
/// # Errors
/// Returns an error if the parent directory cannot be created, if removing the
/// existing file fails for any reason other than `NotFound`, or if binding
/// fails. Otherwise returns whatever the listener loop returns.
///
/// # Panics
/// Panics if the socket path has no parent component (see `parent().unwrap()`).
pub async fn listen_for_incoming_connections_with_socket_path(
socket_path: Option<PathBuf>,
config: ServerConfig,
) -> anyhow::Result<()> {
// NOTE(review): `unwrap_or` builds the default `PathBuf` even when a path was
// supplied; `unwrap_or_else` would avoid that allocation.
let socket_path = socket_path.unwrap_or(PathBuf::from(DEFAULT_SOCKET_PATH));
// NOTE(review): `parent()` returns `None` for paths such as "/" or "", which
// makes this `unwrap()` panic on user-supplied input - consider returning an error.
let parent_directory = socket_path.parent().unwrap();
if !parent_directory.exists() {
log::debug!("Creating directory {:?}", parent_directory);
fs::create_dir_all(parent_directory)?;
}
// This is logged before the bind below has actually succeeded.
log::info!("Listening on socket {:?}", socket_path);
// Remove whatever currently sits at the socket path; a missing file is fine,
// any other removal error aborts startup.
match fs::remove_file(socket_path.as_path()) {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => return Err(e.into()),
let listener = TokioUnixListener::bind(socket_path)?;
listen_for_incoming_connections_with_listener(listener, config).await
/// Builds a Tokio Unix listener from the first file descriptor handed over by
/// systemd (socket activation), as reported by `sd_notify::listen_fds`.
///
/// Fails with a contextual error if the descriptor list cannot be obtained or
/// is empty, or if the conversion into a Tokio listener fails.
///
/// NOTE(review): presumably the listener is then passed on to
/// `listen_for_incoming_connections_with_listener`; that call is not visible
/// here - confirm.
pub async fn listen_for_incoming_connections_with_systemd_socket(
let fd = sd_notify::listen_fds()
.context("Failed to get file descriptors from systemd")?
// Only the first descriptor is used; any further ones are ignored.
.next()
.context("No file descriptors received from systemd")?;
// Sanity check that is only active when debug assertions are enabled.
// NOTE(review): 3 is presumably systemd's first passed descriptor
// (SD_LISTEN_FDS_START) - confirm against sd_listen_fds(3).
debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
log::debug!(
"Received file descriptor from systemd with id: '{}', assuming socket",
fd
);
// SAFETY (assumed, not verified here): `fd` must be an open, listening Unix
// socket that nothing else in the process owns. The code takes this on trust
// from systemd, as the log message above says.
let std_unix_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
// NOTE(review): tokio documents that `from_std` expects the listener to already
// be in non-blocking mode; no `set_nonblocking(true)` call is visible - confirm.
let listener = TokioUnixListener::from_std(std_unix_listener)?;
/// Accept loop of the server.
///
/// Spawns a background task that reports a status line to systemd once per
/// second, signals readiness to systemd, and then accepts connections from
/// `listener`. For each connection the peer's uid is read from the socket
/// credentials, resolved to a `UnixUser`, and a session is run via
/// `handle_requests_for_single_session`. Failures to identify the peer are
/// reported to the client and the connection is skipped; session errors are
/// logged. Returns `Ok(())` once the accept loop ends.
pub async fn listen_for_incoming_connections_with_listener(
listener: TokioUnixListener,
// The `Arc<()>` carries no data; only its strong count is used, as a counter
// of live connections.
let connection_counter = Arc::new(());
let connection_counter_for_log = Arc::clone(&connection_counter);
tokio::spawn(async move {
let mut interval = interval(Duration::from_secs(1));
loop {
interval.tick().await;
// Subtract the two long-lived handles (`connection_counter` and
// `connection_counter_for_log`); what remains are per-connection guards.
// NOTE(review): if `connection_counter` has been dropped while this task is
// still running, the strong count is 1 and this subtraction underflows - confirm.
let count = Arc::strong_count(&connection_counter_for_log) - 2;
let message = if count > 0 {
format!("Handling {} connections", count)
} else {
"Waiting for connections".to_string()
// Best effort: a failed status notification is discarded.
sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())]).ok();
});
// Best effort as well: the readiness notification result is discarded.
sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
// NOTE(review): the loop ends on the first `accept()` error, and that error is
// neither logged nor returned here.
while let Ok((conn, _addr)) = listener.accept().await {
let uid = match conn.peer_cred() {
Ok(cred) => cred.uid(),
Err(e) => {
log::error!("Failed to get peer credentials from socket: {}", e);
// Tell the client what went wrong (ignoring send failures), then move on
// to the next connection.
let mut message_stream = create_server_to_client_message_stream(conn);
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to get peer credentials from socket\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await
.ok();
continue;
// Holding this clone raises the strong count read by the status task above.
let _connection_counter_guard = Arc::clone(&connection_counter);
log::debug!("Accepted connection from uid {}", uid);
let unix_user = match UnixUser::from_uid(uid) {
Ok(user) => user,
log::error!("Failed to get username from uid: {}", e);
"Server failed to get user data from the system\n",
log::info!("Accepted connection from {}", unix_user.username);
// The session is awaited here rather than spawned, so the next `accept()`
// does not run until this session has finished.
match handle_requests_for_single_session(conn, &unix_user, &config).await {
Ok(()) => {}
log::error!("Failed to run server: {}", e);
Ok(())
/// Closes `db_connection`, consuming it.
///
/// A failure to close is logged at error level and otherwise ignored; nothing
/// is returned to the caller.
async fn close_or_ignore_db_connection(db_connection: MySqlConnection) {
if let Err(e) = db_connection.close().await {
log::error!("Failed to close database connection: {}", e);
log::error!("{}", e);
log::error!("Ignoring...");
/// Runs one client session on `socket` on behalf of `unix_user`.
///
/// Wraps the socket in a server-to-client message stream, opens a database
/// connection from `config.mysql`, checks it with a ping, and then delegates
/// to `handle_requests_for_single_session_with_db_connection`.
///
/// # Errors
/// Returns the connection error if the database cannot be reached (after
/// notifying the client), the ping error if the connection is not usable,
/// or whatever the request loop returns.
pub async fn handle_requests_for_single_session(
socket: TokioUnixStream,
unix_user: &UnixUser,
config: &ServerConfig,
let mut message_stream = create_server_to_client_message_stream(socket);
log::debug!("Opening connection to database");
let mut db_connection = match create_mysql_connection_from_config(&config.mysql).await {
Ok(connection) => connection,
Err(err) => {
// The client is informed before the error is returned. Unlike the notices
// sent from the accept loop, send/flush failures here propagate via `?`.
"Server failed to connect to database\n",
.await?;
message_stream.flush().await?;
return Err(err);
log::debug!("Verifying that database connection is valid");
if let Err(e) = db_connection.ping().await {
log::error!("Failed to ping database: {}", e);
// Close the unusable connection (close errors are only logged) before bailing out.
close_or_ignore_db_connection(db_connection).await;
return Err(e.into());
log::debug!("Successfully connected to database");
let result = handle_requests_for_single_session_with_db_connection(
message_stream,
unix_user,
&mut db_connection,
)
.await;
// NOTE(review): no explicit close of `db_connection` is visible on this path
// before returning - confirm the connection is closed after the session.
result
// TODO: ensure proper db_connection hygiene for functions that invoke
// this function
/// Serves requests from one client over `stream` using an already opened
/// `db_connection`.
///
/// Sends `Response::Ready` first, then reads requests from the stream,
/// dispatches each to the matching database/user/privilege operation, and
/// sends back the corresponding `Response`, flushing after each one. Passwords
/// in requests and MySQL error text in password responses are replaced with
/// `"<REDACTED>"` in the log output only; the values actually processed and
/// sent are unchanged.
///
/// # Errors
/// Returns an error if sending or flushing on the stream fails, or if reading
/// a request yields an error.
async fn handle_requests_for_single_session_with_db_connection(
mut stream: ServerToClientMessageStream,
db_connection: &mut MySqlConnection,
// Signal to the client that the server is ready to receive requests.
stream.send(Response::Ready).await?;
// TODO: better error handling
let request = match stream.next().await {
Some(Ok(request)) => request,
Some(Err(e)) => return Err(e.into()),
// The stream ended without an explicit `Request::Exit`.
None => {
log::warn!("Client disconnected without sending an exit message");
break;
// TODO: don't clone the request
// Build a copy of the request for logging, with any password masked.
let request_to_display = match &request {
Request::PasswdUser((db_user, _)) => {
Request::PasswdUser((db_user.to_owned(), "<REDACTED>".to_string()))
request => request.to_owned(),
log::info!("Received request: {:#?}", request_to_display);
// Dispatch: each request variant maps to one operation and one response variant.
let response = match request {
Request::CreateDatabases(databases_names) => {
let result = create_databases(databases_names, unix_user, db_connection).await;
Response::CreateDatabases(result)
Request::DropDatabases(databases_names) => {
let result = drop_databases(databases_names, unix_user, db_connection).await;
Response::DropDatabases(result)
// With explicit names only those are listed; the other path lists all
// databases for the user.
Request::ListDatabases(database_names) => match database_names {
Some(database_names) => {
let result = list_databases(database_names, unix_user, db_connection).await;
Response::ListDatabases(result)
let result = list_all_databases_for_user(unix_user, db_connection).await;
Response::ListAllDatabases(result)
// Same split as above: privileges of the named databases vs. all privileges.
Request::ListPrivileges(database_names) => match database_names {
let privilege_data =
get_databases_privilege_data(database_names, unix_user, db_connection)
Response::ListPrivileges(privilege_data)
get_all_database_privileges(unix_user, db_connection).await;
Response::ListAllPrivileges(privilege_data)
Request::ModifyPrivileges(database_privilege_diffs) => {
// The diffs are collected into a `BTreeSet` before being applied.
let result = apply_privilege_diffs(
BTreeSet::from_iter(database_privilege_diffs),
db_connection,
Response::ModifyPrivileges(result)
Request::CreateUsers(db_users) => {
let result = create_database_users(db_users, unix_user, db_connection).await;
Response::CreateUsers(result)
Request::DropUsers(db_users) => {
let result = drop_database_users(db_users, unix_user, db_connection).await;
Response::DropUsers(result)
Request::PasswdUser((db_user, password)) => {
let result =
set_password_for_database_user(&db_user, &password, unix_user, db_connection)
Response::SetUserPassword(result)
// Explicit user names vs. all database users of this unix user.
Request::ListUsers(db_users) => match db_users {
Some(db_users) => {
let result = list_database_users(db_users, unix_user, db_connection).await;
Response::ListUsers(result)
list_all_database_users_for_unix_user(unix_user, db_connection).await;
Response::ListAllUsers(result)
Request::LockUsers(db_users) => {
let result = lock_database_users(db_users, unix_user, db_connection).await;
Response::LockUsers(result)
Request::UnlockUsers(db_users) => {
let result = unlock_database_users(db_users, unix_user, db_connection).await;
Response::UnlockUsers(result)
// NOTE(review): the body of this arm is not visible; presumably it ends the
// session - confirm.
Request::Exit => {
// TODO: don't clone the response
// Build a copy of the response for logging; the MySQL error text of a failed
// password change is masked.
let response_to_display = match &response {
Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
Response::SetUserPassword(Err(SetPasswordError::MySqlError(
"<REDACTED>".to_string(),
)))
response => response.to_owned(),
log::info!("Response: {:#?}", response_to_display);
// The original (unmasked) response is what gets sent to the client.
stream.send(response).await?;
stream.flush().await?;
log::debug!("Successfully processed request");