Index: .efiles ================================================================== --- .efiles +++ .efiles @@ -1,6 +1,13 @@ Cargo.toml README.md www/index.md www/changelog.md +src/err.rs src/lib.rs src/tokio.rs +src/tokio/server.rs +src/tokio/server/listener.rs +src/tokio/client.rs +src/tokio/client/connector.rs +tests/listener-acceptor.rs +examples/listener-acceptor.rs Index: Cargo.toml ================================================================== --- Cargo.toml +++ Cargo.toml @@ -1,33 +1,49 @@ [package] name = "protwrap" -version = "0.2.2" +version = "0.3.0" edition = "2021" license = "0BSD" +# https://crates.io/category_slugs categories = [ "asynchronous", "network-programming" ] keywords = [ "network", "wrapper" ] repository = "https://repos.qrnch.tech/pub/protwrap" -description = "Thin protocol agnostic wrapper for network applications." +description = "Thin protocol wrapper for network applications." exclude = [ ".fossil-settings", ".efiles", ".fslckout", "rustfmt.toml", "www" ] +# https://doc.rust-lang.org/cargo/reference/manifest.html#the-badges-section +[badges] +maintenance = { status = "experimental" } + [features] -tokio = ["dep:tokio", "dep:tokio-util"] +tls = ["dep:tokio-rustls"] +tokio = ["dep:tokio", "dep:tokio-util", "dep:async-trait", "dep:killswitch"] [dependencies] -tokio-util = { version = "0.7.9", optional = true } +async-trait = { version = "0.1.80", optional = true } +killswitch = { version = "0.4.2", optional = true } +tokio = { version = "1.37.0", optional = true, features = [ + "macros", "net", "rt" +] } +tokio-rustls = { version = "0.24.0", optional = true, features = [ + "dangerous_configuration" +] } +tokio-util = { version = "0.7.11", optional = true } + +[target.'cfg(unix)'.dependencies] +tokio = { version = "1.38.0", optional = true, features = ["fs"] } -[dependencies.tokio] -package = "tokio" -version = "1.32.0" -features = ["net"] -optional = true +[dev-dependencies] +tokio = { version = "1.38.0", features = [ + "io-util", "rt-multi-thread", "time" +] } [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs", "--generate-link-to-definition"] ADDED examples/listener-acceptor.rs Index: examples/listener-acceptor.rs ================================================================== --- /dev/null +++ examples/listener-acceptor.rs @@ -0,0 +1,140 @@ +#[cfg(feature = "tokio")] +mod tok { + + pub(super) use protwrap::tokio::{ + client::connector, + server::listener::{ + async_trait, Acceptor, KillSwitch, Listener, SockAddr + }, + ServerStream + }; + + pub(super) use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::oneshot + }; + + pub(super) struct MyAcceptor { + pub(super) tx_port: Option>, + pub(super) ks: KillSwitch + } + + #[async_trait] + impl Acceptor for MyAcceptor { + async fn bound(&mut self, _listener: &Listener, sa: SockAddr) { + // + // The listener has been successfully bound to a socket address + // + // Retreive the system-allocated port number and send it to the client + // ask using the one-shot channel. + // + let sa = sa.unwrap_std(); + println!("Bound to {:?}", sa); + let port = sa.port(); + let Some(tx) = self.tx_port.take() else { + panic!("Channel end-point missing"); + }; + tx.send(port).unwrap(); + } + + async fn unbound(&mut self, _listener: &Listener) { + println!("Unbound"); + } + + async fn connected(&mut self, sa: SockAddr, mut strm: ServerStream) { + let sa = sa.unwrap_std(); + println!( + "server listener: Received an incoming connection from {:?}", + sa + ); + + let killswitch = self.ks.clone(); + tokio::task::spawn(async move { + let mut buf = [0u8; 5]; + + println!("client: Waiting for 'hello' from client"); + let n = strm.read(&mut buf[..]).await.unwrap(); + assert_eq!(n, 5); + + println!("client: Sending 'world' to client"); + let n = strm.write("world".as_bytes()).await.unwrap(); + assert_eq!(n, 5); + + println!("client: Triggering killswitch to terminate listener"); + killswitch.trigger(); + }); + } + } +} + +#[cfg(feature = "tokio")] +use {std::str::FromStr, tok::*}; + +#[cfg(feature = "tokio")] +#[tokio::main] +async fn main() { + // channel used to pass port number from the server task to the client task. + let (tx, rx) = oneshot::channel(); + + // + // Prepare server task. + // + // We binding to port 0, which means the operating system should allocate + // the port number. The Acceptor::bound() callback will receive a call once + // the server port has been bound, and we use it to pass the port number to + // the client task. + // + let listener = Listener::from_str("127.0.0.1:0").unwrap(); + + let ks = KillSwitch::new(); + + let acceptor = MyAcceptor { + tx_port: Some(tx), + ks: ks.clone() + }; + + let killswitch = ks.clone(); + let jh_server = tokio::task::spawn(async move { + listener.run(killswitch, acceptor).await.unwrap(); + }); + + // + // Set up and spawn client task + // + let jh_client = tokio::task::spawn(async move { + let port = rx.await.unwrap(); + + let inf = connector::TcpConnInfo { + addr: format!("127.0.0.1:{}", port) + }; + let c = connector::Connector::Tcp(inf); + + let mut strm = c.connect().await.unwrap(); + + println!("server: Sending 'hello' to client"); + let n = strm.write("hello".as_bytes()).await.unwrap(); + assert_eq!(n, 5); + + println!("server: Waiting for 'world' reply from server"); + let mut buf = [0u8; 5]; + let n = strm.read(&mut buf[..]).await.unwrap(); + assert_eq!(n, 5); + }); + + println!("main: Wait for killswitch to trigger"); + ks.wait().await; + + println!("main: killswitch was triggered"); + + jh_client.await.unwrap(); + jh_server.await.unwrap(); + + println!("main: server and client tasks have terminated"); +} + +#[cfg(not(feature = "tokio"))] +fn main() { + println!("example requires 'tokio' feature"); +} + +// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 : ADDED src/err.rs Index: src/err.rs ================================================================== --- /dev/null +++ src/err.rs @@ -0,0 +1,38 @@ +use std::{fmt, io}; + +/// Crate-specific errors. +#[derive(Debug)] +pub enum Error { + /// Invalid protocol specifier. + BadProtSpec(String), + IO(String) +} + +impl Error { + pub fn bad_protspec(s: S) -> Self { + Error::BadProtSpec(s.to_string()) + } +} + +impl std::error::Error for Error {} + +impl From for Error { + fn from(err: io::Error) -> Self { + Error::IO(err.to_string()) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::BadProtSpec(s) => { + write!(f, "Unable to parse protocol specifier string; {}", s) + } + Error::IO(s) => { + write!(f, "I/O error; {}", s) + } + } + } +} + +// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 : Index: src/lib.rs ================================================================== --- src/lib.rs +++ src/lib.rs @@ -1,17 +1,22 @@ +//! Wrappers around common network primitives to facilitate writing +//! client/server end-points. + #![cfg_attr(docsrs, feature(doc_cfg))] + +mod err; #[cfg(feature = "tokio")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] pub mod tokio; -use std::fmt; -use std::str::FromStr; +use std::{fmt, str::FromStr}; #[cfg(unix)] use std::path::PathBuf; +pub use err::Error; /// Protocol selection enum. #[derive(Debug, Clone)] pub enum ProtAddr { /// Connect over TCP/IP. The `String` is a socket address in the form Index: src/tokio.rs ================================================================== --- src/tokio.rs +++ src/tokio.rs @@ -1,135 +1,31 @@ -#[cfg(unix)] -use std::fs; - -#[cfg(unix)] -use std::os::unix::fs::FileTypeExt; - -#[cfg(unix)] -use std::path::Path; - -use tokio::net::{TcpListener, TcpStream}; - -#[cfg(unix)] -use tokio::net::{UnixListener, UnixStream}; - -#[cfg(unix)] -use tokio_util::either::Either; - -#[cfg(unix)] -pub type Stream = Either; - -#[cfg(windows)] -pub type Stream = TcpStream; - -use crate::ProtAddr; - -pub async fn connect(pa: &ProtAddr) -> Result { - let strm = match pa { - ProtAddr::Tcp(sa) => connect_tcp(sa).await?, - - #[cfg(unix)] - ProtAddr::Uds(sa) => connect_uds(sa).await? - }; - - Ok(strm) -} - -/// Attempt to establish a TCP/IP socket connection. -async fn connect_tcp(addr: &str) -> Result { - let stream = TcpStream::connect(addr).await?; - - #[cfg(unix)] - return Ok(Either::Left(stream)); - - #[cfg(windows)] - return Ok(stream); -} - -/// Attempt to establish a unix domain socket connection. -/// Currently only available on unix-like platforms. -#[cfg(unix)] -async fn connect_uds(addr: &Path) -> Result { - let addr = match addr.to_str() { - Some(a) => a.to_string(), - None => unreachable!() - }; - let stream = UnixStream::connect(addr).await?; - Ok(Either::Right(stream)) -} - -pub enum Listener { - #[cfg(unix)] - Unix(UnixListener), - Tcp(TcpListener) -} - -#[derive(Debug)] -pub enum SockAddr { - Std(std::net::SocketAddr), - - #[cfg(unix)] - TokioUnix(tokio::net::unix::SocketAddr) -} - -impl Listener { - pub async fn accept(&self) -> Result<(Stream, SockAddr), tokio::io::Error> { - match self { - #[cfg(unix)] - Listener::Unix(u) => { - let (stream, sa) = u.accept().await?; - - let sa = SockAddr::TokioUnix(sa); - - return Ok((Either::Right(stream), sa)); - } - Listener::Tcp(t) => { - let (stream, sa) = t.accept().await?; - - let sa = SockAddr::Std(sa); - - #[cfg(unix)] - return Ok((Either::Left(stream), sa)); - - #[cfg(windows)] - return Ok((stream, sa)); - } - } - } -} - -pub async fn bind(pa: &ProtAddr) -> Result { - let listener = match pa { - ProtAddr::Tcp(sa) => Listener::Tcp(TcpListener::bind(sa).await?), - - #[cfg(unix)] - ProtAddr::Uds(sa) => Listener::Unix(UnixListener::bind(Path::new(sa))?) - }; - - Ok(listener) -} - -pub async fn force_bind(pa: &ProtAddr) -> Result { - let listener = match pa { - ProtAddr::Tcp(_) => bind(pa).await?, - - #[cfg(unix)] - ProtAddr::Uds(sa) => { - if sa.exists() { - let md = fs::metadata(sa)?; - let ft = md.file_type(); - if !ft.is_socket() { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Not a socket" - )); - } - fs::remove_file(sa)?; - } - Listener::Unix(UnixListener::bind(Path::new(sa))?) - } - }; - - Ok(listener) +//! Utility functions specific to tokio. + +pub mod client; +pub mod server; + +use tokio::io::Result; + +pub use client::Stream as ClientStream; +pub use server::Stream as ServerStream; + +/// Unified type covering both [`ServerStream`] and [`ClientStream`] types. +pub type Stream = tokio_util::either::Either; + +#[deprecated( + since = "0.3.0", + note = "Use `client::Connector::connect()` instead" +)] +pub async fn connect(pa: &super::ProtAddr) -> Result { + #[allow(irrefutable_let_patterns)] + let super::ProtAddr::Tcp(addr) = pa + else { + panic!("Not TCP"); + }; + let inf = client::connector::TcpConnInfo { + addr: addr.to_string() + }; + let c = client::Connector::Tcp(inf); + c.connect().await } // vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 : ADDED src/tokio/client.rs Index: src/tokio/client.rs ================================================================== --- /dev/null +++ src/tokio/client.rs @@ -0,0 +1,86 @@ +//! Helpers for working on the end-points initiating connection requests. + +pub mod connector; + +use std::{ + pin::Pin, + task::{Context, Poll} +}; + +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf, Result}, + net::TcpStream +}; + +#[cfg(unix)] +use tokio::net::UnixStream; + +#[cfg(feature = "tls")] +use tokio_rustls::client::TlsStream; + +pub use connector::Connector; + +/// Representation of a stream acting as a client end-point (actively +/// established connection). +pub enum Stream { + /// TCP-based client stream. + Tcp(TcpStream), + + /// Unix local domain client stream. + #[cfg(unix)] + Uds(UnixStream), + + /// TLS, based on TCP, client stream. + #[cfg(feature = "tls")] + TlsTcp(TlsStream) +} + +macro_rules! delegate_call { + ($self:ident.$method:ident($($args:ident),+)) => { + unsafe { + match $self.get_unchecked_mut() { + Self::Tcp(s) => Pin::new_unchecked(s).$method($($args),+), + #[cfg(unix)] + Self::Uds(s) => Pin::new_unchecked(s).$method($($args),+), + #[cfg(feature = "tls")] + Self::TlsTcp(s) => Pin::new_unchecked(s).$method($($args),+), + } + } + } +} + +impl AsyncRead for Stream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_> + ) -> Poll> { + delegate_call!(self.poll_read(cx, buf)) + } +} + +impl AsyncWrite for Stream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8] + ) -> Poll> { + delegate_call!(self.poll_write(cx, buf)) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_> + ) -> Poll> { + delegate_call!(self.poll_flush(cx)) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_> + ) -> Poll> { + delegate_call!(self.poll_shutdown(cx)) + } +} + +// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 : ADDED src/tokio/client/connector.rs Index: src/tokio/client/connector.rs ================================================================== --- /dev/null +++ src/tokio/client/connector.rs @@ -0,0 +1,196 @@ +//! Utility functions for establishing connections for common stream types. + +use std::str::FromStr; + +#[cfg(unix)] +use {std::path::PathBuf, tokio::net::UnixStream}; + +use tokio::net::TcpStream; + +#[cfg(feature = "tls")] +use { + std::{sync::Arc, time::SystemTime}, + tokio_rustls::{ + rustls::{ + self, + client::{ServerCertVerified, ServerCertVerifier, ServerName}, + Certificate + }, + TlsConnector + } +}; + +use super::Stream; + +use crate::err::Error; + +/// Context used to establish TCP connections. +pub struct TcpConnInfo { + /// Socket address. + pub addr: String +} + +impl FromStr for TcpConnInfo { + type Err = Error; + + fn from_str(s: &str) -> Result { + Ok(Self { + addr: s.to_string() + }) + } +} + + +/// Context used to establish unix local domain connections. +#[cfg(unix)] +pub struct UdsConnInfo { + /// Socket address pathname. + pub fname: PathBuf +} + +#[cfg(unix)] +impl FromStr for UdsConnInfo { + type Err = Error; + + fn from_str(s: &str) -> Result { + Ok(Self { + fname: PathBuf::from(s) + }) + } +} + + +/// Context used to establish TLS (based on TCP) connections. +// ToDo: Add key/cert fields +#[cfg(feature = "tls")] +pub struct TlsTcpConnInfo { + /// Socket address. + pub addr: String +} + +#[cfg(feature = "tls")] +impl FromStr for TlsTcpConnInfo { + type Err = Error; + + fn from_str(s: &str) -> Result { + Ok(Self { + addr: s.to_string() + }) + } +} + + +/// Protocol-specific connector helper. +pub enum Connector { + Tcp(TcpConnInfo), + #[cfg(unix)] + Uds(UdsConnInfo), + #[cfg(feature = "tls")] + TlsTcp(TlsTcpConnInfo) +} + +impl Connector { + /// Create a TCP listener from a string. + pub fn tcp(s: &str) -> Result { + Ok(Connector::Tcp(TcpConnInfo::from_str(s)?)) + } + + /// Create an unix domain socket listener from a string. + #[cfg(unix)] + pub fn uds(s: &str) -> Result { + Ok(Connector::Uds(UdsConnInfo::from_str(s)?)) + } + + #[cfg(feature = "tls")] + pub fn tls_tcp(s: &str) -> Result { + Ok(Connector::TlsTcp(TlsTcpConnInfo::from_str(s)?)) + } +} + + +// ToDo: Add tls/tcp parameters parsing +impl FromStr for Connector { + type Err = Error; + fn from_str(s: &str) -> Result { + #[cfg(unix)] + if s.find('/').is_some() { + // Assume unix domain socket + Ok(Connector::Uds(UdsConnInfo::from_str(s)?)) + } else { + // Assume IP socket address + Ok(Connector::Tcp(TcpConnInfo::from_str(s)?)) + } + + #[cfg(windows)] + Ok(Connector::Tcp(TcpConnInfo::from_str(s)?)) + } +} + +impl Connector { + pub async fn connect(&self) -> Result { + match self { + Self::Tcp(info) => { + let strm = TcpStream::connect(&info.addr).await?; + Ok(Stream::Tcp(strm)) + } + + #[cfg(unix)] + Self::Uds(info) => { + let strm = UnixStream::connect(&info.fname).await?; + Ok(Stream::Uds(strm)) + } + + #[cfg(feature = "tls")] + Self::TlsTcp(info) => { + // Connect to server, without SNI and with a custom certificate + // validation (which does nothing) + let versions = rustls::DEFAULT_VERSIONS.to_vec(); + let cfg = rustls::ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&versions) + .expect("inconsistent cipher-suite/versions selected") + .with_custom_certificate_verifier(Arc::new(CertVerifier {})) + .with_no_client_auth(); + + let connector = TlsConnector::from(Arc::new(cfg)); + + let raw_stream = TcpStream::connect(&info.addr).await.unwrap(); + + let domain = rustls::ServerName::try_from("localhost").unwrap(); + /* + map_err(|_| { + io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname") + })?; + */ + + let strm = connector.connect(domain, raw_stream).await.unwrap(); + + Ok(Stream::TlsTcp(strm)) + } + } + } +} + + +/// Place-holder for a "Null" cert verifier, usable for prototyping. +#[cfg(feature = "tls")] +struct CertVerifier {} + +#[cfg(feature = "tls")] +impl ServerCertVerifier for CertVerifier { + fn verify_server_cert( + &self, + _end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime + ) -> Result { + //tracing::debug!("Verify server certificate"); + Ok(ServerCertVerified::assertion()) + } +} + +// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 : ADDED src/tokio/server.rs Index: src/tokio/server.rs ================================================================== --- /dev/null +++ src/tokio/server.rs @@ -0,0 +1,85 @@ +//! Helpers for working on the end-points receiving connection requests. + +pub mod listener; + +use std::{ + pin::Pin, + task::{Context, Poll} +}; + +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf, Result}, + net::TcpStream +}; + +#[cfg(unix)] +use tokio::net::UnixStream; + +#[cfg(feature = "tls")] +use tokio_rustls::server::TlsStream; + + +/// Representation of a stream acting as a server end-point (passively +/// established connection). +pub enum Stream { + /// TCP-based server stream. + Tcp(TcpStream), + + /// Unix local domain-based server stream. + #[cfg(unix)] + Uds(UnixStream), + + /// TLS, based on TCP, based server stream. + #[cfg(feature = "tls")] + TlsTcp(TlsStream) +} + +macro_rules! delegate_call { + ($self:ident.$method:ident($($args:ident),+)) => { + unsafe { + match $self.get_unchecked_mut() { + Self::Tcp(s) => Pin::new_unchecked(s).$method($($args),+), + #[cfg(unix)] + Self::Uds(s) => Pin::new_unchecked(s).$method($($args),+), + #[cfg(feature = "tls")] + Self::TlsTcp(s) => Pin::new_unchecked(s).$method($($args),+), + } + } + } +} + +impl AsyncRead for Stream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_> + ) -> Poll> { + delegate_call!(self.poll_read(cx, buf)) + } +} + +impl AsyncWrite for Stream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8] + ) -> Poll> { + delegate_call!(self.poll_write(cx, buf)) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_> + ) -> Poll> { + delegate_call!(self.poll_flush(cx)) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_> + ) -> Poll> { + delegate_call!(self.poll_shutdown(cx)) + } +} + +// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 : ADDED src/tokio/server/listener.rs Index: src/tokio/server/listener.rs ================================================================== --- /dev/null +++ src/tokio/server/listener.rs @@ -0,0 +1,294 @@ +//! Utilities for running abortable listeners. + +use std::{path::PathBuf, str::FromStr}; + +use tokio::net::TcpListener; + +#[cfg(unix)] +use { + std::os::unix::fs::FileTypeExt, + tokio::{fs, net::UnixListener} +}; + +#[cfg(feature = "tokio")] +pub use {async_trait::async_trait, killswitch::KillSwitch}; + +use super::Stream; + +use crate::err::Error; + +/// Abstraction around std's [`SocketAddr`](std::net::SocketAddr) (for +/// IPv4/IPv6) and tokio's (unix local domain) +/// [`SocketAddr`](tokio::net::unix::SocketAddr). +/// +/// In an idea world, this would not be needed (or, at least, this create would +/// not need to define it), but this is a less-than ideal world. +#[derive(Debug)] +pub enum SockAddr { + Std(std::net::SocketAddr), + + #[cfg(unix)] + TokioUnix(tokio::net::unix::SocketAddr) +} + +impl SockAddr { + /// Unwrap the [`std::net::SocketAddr`] (i.e. IPv4/IPv6) case. + /// + /// # Panics + /// Will panic if the type is not `SockAddr::Std`. + pub fn unwrap_std(self) -> std::net::SocketAddr { + #[allow(irrefutable_let_patterns)] + let SockAddr::Std(s) = self + else { + panic!("Not SockAddr::Std()"); + }; + s + } + + pub fn try_as_std(&self) -> Option<&std::net::SocketAddr> { + #[allow(irrefutable_let_patterns)] + if let SockAddr::Std(s) = self { + Some(s) + } else { + None + } + } + + /// Unwrap the [`tokio::net::unix::SocketAddr`] (i.e. unix local domain + /// socket) case. + /// + /// # Panics + /// Will panic if the type is not `SockAddr::TokioUnix`. + #[cfg(unix)] + pub fn unwrap_tokunix(self) -> tokio::net::unix::SocketAddr { + let SockAddr::TokioUnix(s) = self else { + panic!("Not SockAddr::TokioUnix()"); + }; + s + } + + #[cfg(unix)] + pub fn try_as_tokunix(&self) -> Option<&tokio::net::unix::SocketAddr> { + if let SockAddr::TokioUnix(s) = self { + Some(s) + } else { + None + } + } +} + +impl TryFrom for std::net::SocketAddr { + type Error = SockAddr; + + fn try_from(orig: SockAddr) -> Result { + match orig { + SockAddr::Std(sa) => Ok(sa), + #[allow(unreachable_patterns)] + a => Err(a) + } + } +} + +#[cfg(unix)] +impl TryFrom for tokio::net::unix::SocketAddr { + type Error = SockAddr; + + fn try_from(orig: SockAddr) -> Result { + match orig { + SockAddr::TokioUnix(sa) => Ok(sa), + a => Err(a) + } + } +} + + +/// Callbacks for the [`Listener`] type. +#[async_trait] +pub trait Acceptor { + /// Called once the listener has successfully bound. + async fn bound(&mut self, listener: &Listener, sa: SockAddr); + + /// Called when the listener has terminated. + async fn unbound(&mut self, listener: &Listener); + + /// Called when the listener has accepted a client connection request. + async fn connected(&mut self, sa: SockAddr, strm: Stream); +} + + +/// Context used to define a TCP listener. +pub struct TcpListenerInfo { + /// Socket address to bind listener to. + pub addr: String +} + +impl FromStr for TcpListenerInfo { + type Err = Error; + + fn from_str(s: &str) -> Result { + Ok(Self { + addr: s.to_string() + }) + } +} + + +/// Context used to define a unix local domain listener. +pub struct UdsListenerInfo { + /// Socket pathname to bind listener to. + pub fname: PathBuf, + + /// Create directory for socket file, if required. + pub mkdir: bool, + + /// If socket file already exists, then remove it before binding. + pub force: bool +} + +impl FromStr for UdsListenerInfo { + type Err = Error; + + fn from_str(s: &str) -> Result { + Ok(Self { + fname: PathBuf::from(s), + mkdir: false, + force: false + }) + } +} + + +/// Wrapper around common protocol-specific listener specifiers. +pub enum Listener { + Tcp(TcpListenerInfo), + #[cfg(unix)] + Uds(UdsListenerInfo) +} + +impl Listener { + /// Create a TCP listener from a string. + pub fn tcp(s: &str) -> Result { + Ok(Listener::Tcp(TcpListenerInfo::from_str(s)?)) + } + + /// Create an unix domain socket listener from a string. + #[cfg(unix)] + pub fn uds(s: &str) -> Result { + Ok(Listener::Uds(UdsListenerInfo::from_str(s)?)) + } +} + +impl FromStr for Listener { + type Err = Error; + fn from_str(s: &str) -> Result { + #[cfg(unix)] + if s.find('/').is_some() { + // Assume unix domain socket + Ok(Listener::Uds(UdsListenerInfo::from_str(s)?)) + } else { + // Assume IP socket address + Ok(Listener::Tcp(TcpListenerInfo::from_str(s)?)) + } + + #[cfg(windows)] + Ok(Listener::Tcp(TcpListenerInfo::from_str(s)?)) + } +} + +impl Listener { + /// Run a listener loop. + /// + /// If the socket bind is successful the [`Acceptor::bound()`] of `acceptor` + /// will be called, where the bound socket address will be passed as an + /// argument. (This can be used to retreive the port number if the + /// application requested the port number to be automatically assigned. + /// + /// Each time a client has connected the acceptor will call + /// [`Acceptor::connected()`] to allow the application to process the + /// connection. The ownership of the newly established connection will be + /// passed to the `connected()` method. + /// + /// # Unix domain sockets + /// If the listener is a unix domain socket, the socket file will be removed + /// if the listener is aborted. + pub async fn run( + &self, + ks: KillSwitch, + mut acceptor: impl Acceptor + ) -> Result<(), std::io::Error> { + match self { + Listener::Tcp(info) => { + let listener = TcpListener::bind(&info.addr).await?; + + let sa = listener.local_addr()?; + acceptor.bound(self, SockAddr::Std(sa)).await; + + loop { + tokio::select! { + ret = listener.accept() => { + let (strm, sa) = ret?; + let sa = SockAddr::Std(sa); + acceptor.connected(sa, Stream::Tcp(strm)).await; + } + _ = ks.wait() => { + break; + } + } + } + + drop(listener); + + acceptor.unbound(self).await; + } + + #[cfg(unix)] + Listener::Uds(info) => { + if info.mkdir { + if let Some(dir) = info.fname.parent() { + fs::create_dir_all(dir).await?; + } + } + if info.force && info.fname.exists() { + let md = fs::metadata(&info.fname).await?; + let ft = md.file_type(); + if !ft.is_socket() { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Not a socket" + )); + } + fs::remove_file(&info.fname).await?; + } + + let listener = UnixListener::bind(&info.fname)?; + + let sa = listener.local_addr()?; + acceptor.bound(self, SockAddr::TokioUnix(sa)).await; + + loop { + tokio::select! { + ret = listener.accept() => { + let (strm, sa) = ret?; + let sa = SockAddr::TokioUnix(sa); + acceptor.connected(sa, Stream::Uds(strm)).await; + } + _ = ks.wait() => { + break; + } + } + } + + drop(listener); + + // Don't abort here, because unbound should be called before doing so. + let res = fs::remove_file(&info.fname).await; + acceptor.unbound(self).await; + res?; + } + } + + Ok(()) + } +} + +// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 : ADDED tests/listener-acceptor.rs Index: tests/listener-acceptor.rs ================================================================== --- /dev/null +++ tests/listener-acceptor.rs @@ -0,0 +1,108 @@ +#[cfg(feature = "tokio")] +mod tokio_tests { + use std::str::FromStr; + + use protwrap::tokio::{ + client::connector, + server::listener::{ + async_trait, Acceptor, KillSwitch, Listener, SockAddr + }, + ServerStream + }; + + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::oneshot + }; + + struct MyAcceptor { + tx_port: Option>, + ks: KillSwitch, + did_bind: bool, + did_connect: bool, + did_unbind: bool + } + + #[async_trait] + impl Acceptor for MyAcceptor { + async fn bound(&mut self, _listener: &Listener, sa: SockAddr) { + self.did_bind = true; + + let port = sa.unwrap_std().port(); + let Some(tx) = self.tx_port.take() else { + panic!("Channel end-point missing"); + }; + tx.send(port).unwrap(); + } + + async fn unbound(&mut self, _listener: &Listener) { + self.did_unbind = true; + } + + async fn connected(&mut self, _sa: SockAddr, mut strm: ServerStream) { + self.did_connect = true; + + let killswitch = self.ks.clone(); + tokio::task::spawn(async move { + let mut buf = [0u8; 5]; + + let n = strm.read(&mut buf[..]).await.unwrap(); + assert_eq!(n, 5); + assert_eq!(buf, "hello".as_bytes()); + + let n = strm.write("world".as_bytes()).await.unwrap(); + assert_eq!(n, 5); + + killswitch.trigger(); + }); + } + } + + #[tokio::test] + async fn main() { + let (tx, rx) = oneshot::channel(); + + let listener = Listener::from_str("127.0.0.1:0").unwrap(); + + let ks = KillSwitch::new(); + + let acceptor = MyAcceptor { + tx_port: Some(tx), + ks: ks.clone(), + did_bind: false, + did_connect: false, + did_unbind: false + }; + + let killswitch = ks.clone(); + let jh_server = tokio::task::spawn(async move { + listener.run(killswitch, acceptor).await.unwrap(); + }); + + let jh_client = tokio::task::spawn(async move { + // Use side-channel to receive port number from server + let port = rx.await.unwrap(); + + let addr = format!("127.0.0.1:{}", port); + let c = connector::Connector::from_str(&addr).unwrap(); + + let mut strm = c.connect().await.unwrap(); + + println!("server: Sending 'hello' to client"); + let n = strm.write("hello".as_bytes()).await.unwrap(); + assert_eq!(n, 5); + + let mut buf = [0u8; 5]; + let n = strm.read(&mut buf[..]).await.unwrap(); + assert_eq!(n, 5); + assert_eq!(buf, "world".as_bytes()); + }); + + ks.wait().await; + + jh_client.await.unwrap(); + jh_server.await.unwrap(); + } +} + +// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 : Index: www/changelog.md ================================================================== --- www/changelog.md +++ www/changelog.md @@ -1,16 +1,41 @@ # Change Log ## [Unreleased] +[Details](/vdiff?from=protwrap-0.3.0&to=trunk) + ### Added ### Changed ### Removed + +--- + +## [0.3.0] - 2024-05-31 + +[Details](/vdiff?from=protwrap-0.2.2&to=protwrap-0.3.0) + +This is a major rewrite. + +### Added + +- `Acceptor::unbound()`. +- Listener, and their "info" buffers, gained `FromStr` parsers. + +### Changed + +- Add `&Listener` parameter to `Acceptor::bound()`. +- uds listener removes socket file when loop terminated. +- Put `connector` under `client` and `listener` under `server`. + +--- ## [0.2.2] - 2023-10-03 + +[Details](/vdiff?from=protwrap-0.2.1&to=protwrap-0.2.2) ### Changed - Fix fallout after earlier feature/dependency rename. Index: www/index.md ================================================================== --- www/index.md +++ www/index.md @@ -3,10 +3,20 @@ Protocol Wrapper is a thin wrapper on top of common low-level network protocol API's to allow developers to easily support common protocols like TcpStream and UnixStream without having to explicitly write support for them in application code. + +## Feature labels in documentation + +The crate's documentation uses automatically generated feature labels, which +currently requires nightly featuers. To build the documentation locally use: + +``` +RUSTFLAGS="--cfg docsrs" RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features +``` + ## Change log The details of changes can always be found in the timeline, but for a high-level view of changes between released versions there's a manually