Skip to main content

zlink_tokio/unix/
stream.rs

1use crate::{
2    Result,
3    connection::socket::{self, Socket},
4};
5use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
6use tokio::net::{UnixStream, unix};
7
8/// The connection type that uses Unix Domain Sockets for transport.
9pub type Connection = crate::Connection<Stream>;
10
11/// Connect to Unix Domain Socket at the given path.
12pub async fn connect<P>(path: P) -> Result<Connection>
13where
14    P: AsRef<std::path::Path>,
15{
16    UnixStream::connect(path)
17        .await
18        .map(Stream)
19        .map(Connection::new)
20        .map_err(Into::into)
21}
22
23/// The [`Socket`] implementation using Unix Domain Sockets.
24#[derive(Debug)]
25pub struct Stream(UnixStream);
26
27impl Socket for Stream {
28    type ReadHalf = ReadHalf;
29    type WriteHalf = WriteHalf;
30
31    const CAN_TRANSFER_FDS: bool = true;
32
33    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
34        let (read, write) = self.0.into_split();
35
36        (ReadHalf(read), WriteHalf(write))
37    }
38}
39
40impl From<UnixStream> for Stream {
41    fn from(stream: UnixStream) -> Self {
42        Self(stream)
43    }
44}
45
46impl socket::UnixSocket for Stream {}
47
48impl AsFd for Stream {
49    fn as_fd(&self) -> BorrowedFd<'_> {
50        self.0.as_fd()
51    }
52}
53
54/// The [`ReadHalf`] implementation using Unix Domain Sockets.
55#[derive(Debug)]
56pub struct ReadHalf(unix::OwnedReadHalf);
57
58impl socket::ReadHalf for ReadHalf {
59    async fn read(&mut self, buf: &mut [u8]) -> Result<(usize, Vec<OwnedFd>)> {
60        use std::{future::poll_fn, task::Poll};
61
62        poll_fn(|cx| {
63            loop {
64                let stream: &UnixStream = self.0.as_ref();
65                match stream.try_io(tokio::io::Interest::READABLE, || {
66                    crate::unix_utils::recvmsg(stream, buf)
67                }) {
68                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
69                        match stream.poll_read_ready(cx) {
70                            Poll::Pending => return Poll::Pending,
71                            Poll::Ready(res) => res?,
72                        }
73                    }
74                    v => return Poll::Ready(v.map_err(Into::into)),
75                }
76            }
77        })
78        .await
79    }
80}
81
82impl AsFd for ReadHalf {
83    fn as_fd(&self) -> BorrowedFd<'_> {
84        let stream: &UnixStream = self.0.as_ref();
85        stream.as_fd()
86    }
87}
88
89impl socket::UnixSocket for ReadHalf {}
90
91/// The [`WriteHalf`] implementation using Unix Domain Sockets.
92#[derive(Debug)]
93pub struct WriteHalf(unix::OwnedWriteHalf);
94
95impl socket::WriteHalf for WriteHalf {
96    async fn write(&mut self, buf: &[u8], fds: &[impl AsFd]) -> Result<()> {
97        use std::{future::poll_fn, task::Poll};
98
99        // Convert to BorrowedFd for rustix.
100        let borrowed_fds: Vec<BorrowedFd<'_>> = fds.iter().map(|f| f.as_fd()).collect();
101
102        let mut pos = 0;
103        while pos < buf.len() {
104            // Use FDs on first write, empty slice on subsequent writes.
105            let fds_to_send = if pos == 0 { &borrowed_fds[..] } else { &[] };
106
107            let n: usize = poll_fn(|cx| {
108                loop {
109                    let stream: &UnixStream = self.0.as_ref();
110                    match stream.try_io(tokio::io::Interest::WRITABLE, || {
111                        crate::unix_utils::sendmsg(stream, &buf[pos..], fds_to_send)
112                    }) {
113                        Ok(bytes_sent) => return Poll::Ready(Ok::<_, crate::Error>(bytes_sent)),
114                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
115                            match stream.poll_write_ready(cx) {
116                                Poll::Pending => return Poll::Pending,
117                                Poll::Ready(res) => res?,
118                            }
119                        }
120                        Err(e) => return Poll::Ready(Err(e.into())),
121                    }
122                }
123            })
124            .await?;
125
126            pos += n;
127        }
128
129        Ok(())
130    }
131}
132
133impl AsFd for WriteHalf {
134    fn as_fd(&self) -> BorrowedFd<'_> {
135        let stream: &UnixStream = self.0.as_ref();
136        stream.as_fd()
137    }
138}
139
140impl socket::UnixSocket for WriteHalf {}