use std::{net::SocketAddr, sync::Arc};
use n0_error::stack_error;
use n0_future::time::Duration;
use quinn::{VarInt, crypto::rustls::QuicClientConfig};
use tokio::sync::watch;
/// ALPN protocol identifier used for QUIC address discovery (QAD) connections.
///
/// Both the server (`QuicServer::spawn`) and the client (`QuicClient::new`) overwrite their TLS
/// ALPN list with exactly this value.
pub const ALPN_QUIC_ADDR_DISC: &[u8] = b"/iroh-qad/0";
/// Application error code used to close a QAD connection once it is finished.
///
/// The server treats a peer-initiated application close with this code as a successful
/// completion; any other close reason is reported as an error.
pub const QUIC_ADDR_DISC_CLOSE_CODE: VarInt = VarInt::from_u32(1);
/// Reason phrase sent together with [`QUIC_ADDR_DISC_CLOSE_CODE`].
pub const QUIC_ADDR_DISC_CLOSE_REASON: &[u8] = b"finished";
/// Server side of QUIC address discovery (QAD); only compiled with the `server` feature.
#[cfg(feature = "server")]
pub(crate) mod server {
    use n0_error::e;
    use quinn::{
        ApplicationClose, ConnectionError,
        crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
    };
    use tokio::task::JoinSet;
    use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
    use tracing::{Instrument, debug, info, info_span};

    use super::*;
    pub use crate::server::QuicConfig;

    /// A running QUIC address discovery server.
    ///
    /// The accept loop runs in a spawned task held by an [`AbortOnDropHandle`], so dropping the
    /// server aborts that task. Use [`QuicServer::shutdown`] (or [`ServerHandle::shutdown`]) to
    /// stop it gracefully instead.
    pub struct QuicServer {
        /// Address the endpoint is bound to, as reported by the endpoint after binding.
        bind_addr: SocketAddr,
        /// Cancelling this token makes the accept loop close the endpoint and exit.
        cancel: CancellationToken,
        /// Handle of the spawned accept-loop task.
        handle: AbortOnDropHandle<()>,
    }

    /// Errors that can occur while spawning a [`QuicServer`].
    #[allow(missing_docs)]
    #[stack_error(derive, add_meta)]
    #[non_exhaustive]
    pub enum QuicSpawnError {
        /// The TLS server config could not be turned into a QUIC server crypto config.
        #[error(transparent)]
        NoInitialCipherSuite {
            #[error(std_err, from)]
            source: NoInitialCipherSuite,
        },
        /// Creating the QUIC server endpoint failed.
        #[error("Unable to spawn a QUIC endpoint server")]
        EndpointServer {
            #[error(std_err)]
            source: std::io::Error,
        },
        /// The endpoint's local address could not be queried.
        #[error("Unable to get the local address from the endpoint")]
        LocalAddr {
            #[error(std_err)]
            source: std::io::Error,
        },
    }

    impl QuicServer {
        /// Returns a cloneable [`ServerHandle`] that can request shutdown of this server.
        pub fn handle(&self) -> ServerHandle {
            ServerHandle {
                cancel_token: self.cancel.clone(),
            }
        }

        /// Mutable access to the handle of the accept-loop task (e.g. to await or abort it).
        pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
            &mut self.handle
        }

        /// The local address the QUIC endpoint is bound to.
        pub fn bind_addr(&self) -> SocketAddr {
            self.bind_addr
        }

        /// Binds a QUIC endpoint for address discovery and spawns its accept loop.
        ///
        /// The TLS config's ALPN list is replaced by [`ALPN_QUIC_ADDR_DISC`]. Peers are allowed
        /// zero concurrent uni- and bidirectional streams, and observed-address reports are
        /// enabled.
        ///
        /// # Errors
        ///
        /// - [`QuicSpawnError::NoInitialCipherSuite`] if the TLS config cannot be used for QUIC.
        /// - [`QuicSpawnError::EndpointServer`] if the endpoint cannot be created.
        /// - [`QuicSpawnError::LocalAddr`] if the bound address cannot be queried.
        pub(crate) fn spawn(mut quic_config: QuicConfig) -> Result<Self, QuicSpawnError> {
            quic_config.server_config.alpn_protocols =
                vec![crate::quic::ALPN_QUIC_ADDR_DISC.to_vec()];
            let server_config = QuicServerConfig::try_from(quic_config.server_config)?;
            let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_config));
            let transport_config =
                Arc::get_mut(&mut server_config.transport).expect("not used yet");
            // Peers may not open any streams; the observed address is sent to them through
            // observed-address reports rather than application data.
            transport_config
                .max_concurrent_uni_streams(0_u8.into())
                .max_concurrent_bidi_streams(0_u8.into())
                .send_observed_address_reports(true);
            let endpoint = quinn::Endpoint::server(server_config, quic_config.bind_addr)
                .map_err(|err| e!(QuicSpawnError::EndpointServer, err))?;
            let bind_addr = endpoint
                .local_addr()
                .map_err(|err| e!(QuicSpawnError::LocalAddr, err))?;
            info!(?bind_addr, "QUIC server listening on");

            let cancel = CancellationToken::new();
            let cancel_accept_loop = cancel.clone();
            let task = tokio::task::spawn(
                async move {
                    let mut set = JoinSet::new();
                    debug!("waiting for connections...");
                    loop {
                        tokio::select! {
                            // Check cancellation first so shutdown is not starved by new
                            // connections or finishing tasks.
                            biased;
                            _ = cancel_accept_loop.cancelled() => {
                                break;
                            }
                            Some(res) = set.join_next() => {
                                // Connection errors are logged inside the task itself; only
                                // panics and cancellation surface here.
                                if let Err(err) = res {
                                    if err.is_panic() {
                                        panic!("task panicked: {err:#?}");
                                    } else {
                                        debug!("error accepting incoming connection: {err:#?}");
                                    }
                                }
                            }
                            res = endpoint.accept() => match res {
                                Some(incoming) => {
                                    debug!("accepting connection");
                                    let remote_addr = incoming.remote_address();
                                    set.spawn(
                                        async move {
                                            // Report the outcome inside the per-connection span
                                            // so the error is logged with its remote address.
                                            if let Err(err) = handle_connection(incoming).await {
                                                debug!("connection error: {err:#}");
                                            }
                                        }
                                        .instrument(info_span!("qad-conn", %remote_addr)),
                                    );
                                }
                                None => {
                                    debug!("endpoint closed");
                                    break;
                                }
                            }
                        }
                    }
                    endpoint.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
                    endpoint.wait_idle().await;
                    set.abort_all();
                    while !set.is_empty() {
                        _ = set.join_next().await;
                    }
                    debug!("quic endpoint has been shutdown.");
                }
                .instrument(info_span!("quic-endpoint")),
            );
            Ok(Self {
                bind_addr,
                cancel,
                handle: AbortOnDropHandle::new(task),
            })
        }

        /// Requests shutdown and waits for the accept-loop task to finish.
        ///
        /// The task's own outcome (e.g. a propagated panic) is ignored.
        pub async fn shutdown(mut self) {
            self.cancel.cancel();
            if !self.task_handle().is_finished() {
                _ = self.task_handle().await;
            }
        }
    }

    /// A cloneable handle that can request shutdown of a [`QuicServer`].
    #[derive(Debug, Clone)]
    pub struct ServerHandle {
        cancel_token: CancellationToken,
    }

    impl ServerHandle {
        /// Signals the server's accept loop to stop.
        ///
        /// This only requests shutdown and returns immediately; it does not wait for the
        /// endpoint to close.
        pub fn shutdown(&self) {
            self.cancel_token.cancel()
        }
    }

    /// Drives a single incoming QAD connection until it closes.
    ///
    /// Completes the handshake, then waits for the connection to close. Returns `Ok(())` only
    /// when the peer closed it with [`QUIC_ADDR_DISC_CLOSE_CODE`]. Any other outcome, including
    /// a failed handshake, is returned as the corresponding [`ConnectionError`].
    async fn handle_connection(incoming: quinn::Incoming) -> Result<(), ConnectionError> {
        let connection = incoming.await?;
        debug!("established");
        match connection.closed().await {
            ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. })
                if error_code == QUIC_ADDR_DISC_CLOSE_CODE =>
            {
                Ok(())
            }
            other => Err(other),
        }
    }
}
/// Errors returned by [`QuicClient`] operations.
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources, std_sources)]
#[non_exhaustive]
pub enum Error {
/// Starting the connection attempt failed (error from `quinn::Endpoint::connect_with`).
#[error(transparent)]
Connect {
#[error(std_err)]
source: quinn::ConnectError,
},
/// The connection failed while it was being established.
#[error(transparent)]
Connection {
#[error(std_err)]
source: quinn::ConnectionError,
},
/// Waiting on the observed-external-address watch channel failed.
#[error(transparent)]
WatchRecv {
#[error(std_err)]
source: watch::error::RecvError,
},
}
/// Client for QUIC address discovery (QAD).
///
/// Opens QAD connections from an existing [`quinn::Endpoint`] using a client config prepared by
/// [`QuicClient::new`].
#[derive(Debug, Clone)]
pub struct QuicClient {
/// Endpoint used to open outgoing QAD connections.
ep: quinn::Endpoint,
/// Client config with the QAD ALPN and the QAD transport settings applied by `new`.
client_config: quinn::ClientConfig,
}
impl QuicClient {
    /// Creates a QAD client that connects out of `ep`.
    ///
    /// The ALPN list of `client_config` is replaced by [`ALPN_QUIC_ADDR_DISC`]. A dedicated
    /// transport config is installed with:
    /// - an initial RTT of 111ms;
    /// - observed-address reports enabled;
    /// - a 25s keep-alive;
    /// - a 35s idle timeout.
    ///
    /// # Panics
    ///
    /// Panics if `client_config` cannot be converted into a QUIC client crypto config (no
    /// usable initial cipher suite).
    pub fn new(ep: quinn::Endpoint, mut client_config: rustls::ClientConfig) -> Self {
        client_config.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.into()];
        let mut client_config = quinn::ClientConfig::new(Arc::new(
            QuicClientConfig::try_from(client_config).expect("known ciphersuite"),
        ));

        let mut transport = quinn_proto::TransportConfig::default();
        // NOTE(review): the low initial RTT presumably shortens handshake/close timers toward
        // unresponsive servers; `test_qad_client_closes_unresponsive_fast` depends on this value.
        transport.initial_rtt(Duration::from_millis(111));
        transport.receive_observed_address_reports(true);
        // Keep-alives are sent well within the idle timeout, so an otherwise silent connection
        // does not time out.
        transport.keep_alive_interval(Some(Duration::from_secs(25)));
        transport.max_idle_timeout(Some(
            Duration::from_secs(35).try_into().expect("known value"),
        ));
        client_config.transport_config(Arc::new(transport));

        Self { ep, client_config }
    }

    /// Performs one QAD round trip against `server_addr`.
    ///
    /// Connects using `host` as the TLS server name and waits until the server reports our
    /// observed external address. Returns that address, canonicalized so that IPv4-mapped IPv6
    /// becomes plain IPv4, together with the RTT of path zero (zero if unavailable).
    ///
    /// The connection is closed with [`QUIC_ADDR_DISC_CLOSE_CODE`] before returning, both on
    /// success and when waiting for the address fails.
    #[cfg(all(test, feature = "server"))]
    async fn get_addr_and_latency(
        &self,
        server_addr: SocketAddr,
        host: &str,
    ) -> Result<(SocketAddr, std::time::Duration), Error> {
        use quinn_proto::PathId;

        let conn = self
            .ep
            .connect_with(self.client_config.clone(), server_addr, host)?
            .await?;
        let mut external_addresses = conn.observed_external_addr();
        // Copy the address out of the borrow guard immediately: a `watch::Ref` holds the
        // channel's read lock and would block the sender for as long as it is alive.
        let observed_addr = match external_addresses.wait_for(|addr| addr.is_some()).await {
            Ok(guard) => (*guard).expect("checked"),
            Err(err) => {
                conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
                return Err(err.into());
            }
        };
        let observed_addr =
            SocketAddr::new(observed_addr.ip().to_canonical(), observed_addr.port());
        let latency = conn.rtt(PathId::ZERO).unwrap_or_default();
        conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
        Ok((observed_addr, latency))
    }

    /// Opens a QAD connection to `server_addr`, using `host` as the TLS server name.
    ///
    /// The connection is returned open. NOTE(review): callers presumably close it with
    /// [`QUIC_ADDR_DISC_CLOSE_CODE`] once the observed address is known; the server only treats
    /// that code as a clean finish.
    ///
    /// # Errors
    ///
    /// [`Error::Connect`] if the attempt cannot be started; [`Error::Connection`] if the
    /// handshake fails.
    pub async fn create_conn(
        &self,
        server_addr: SocketAddr,
        host: &str,
    ) -> Result<quinn::Connection, Error> {
        let conn = self
            .ep
            .connect_with(self.client_config.clone(), server_addr, host)?
            .await?;
        Ok(conn)
    }
}
// Tests for QUIC address discovery; they need the `server` feature for `QuicServer` and
// `QuicClient::get_addr_and_latency`.
#[cfg(all(test, feature = "server"))]
mod tests {
use std::net::Ipv4Addr;
use n0_error::{Result, StdResultExt};
use n0_future::{
task::AbortOnDropHandle,
time::{self, Instant},
};
use n0_tracing_test::traced_test;
use quinn::crypto::rustls::QuicServerConfig;
use tracing::{Instrument, debug, info, info_span};
use webpki_types::PrivatePkcs8KeyDer;
use super::*;
/// End-to-end QAD against a real `QuicServer` on localhost.
///
/// The address the server reports back must equal the client endpoint's own local address.
#[tokio::test]
#[traced_test]
#[cfg(feature = "test-utils")]
async fn quic_endpoint_basic() -> Result {
use super::server::{QuicConfig, QuicServer};
let host: Ipv4Addr = "127.0.0.1".parse().unwrap();
let (_, server_config) = super::super::server::testing::self_signed_tls_certs_and_config();
// Port 0: the OS picks a free port; `quic_server.bind_addr()` reports the real one.
let bind_addr = SocketAddr::new(host.into(), 0);
let quic_server = QuicServer::spawn(QuicConfig {
server_config,
bind_addr,
})?;
let client_endpoint =
quinn::Endpoint::client(SocketAddr::new(host.into(), 0)).std_context("client")?;
let client_addr = client_endpoint.local_addr().std_context("local addr")?;
let client_config = crate::client::make_dangerous_client_config();
let quic_client = QuicClient::new(client_endpoint.clone(), client_config);
let (addr, _latency) = quic_client
.get_addr_and_latency(quic_server.bind_addr(), &host.to_string())
.await?;
client_endpoint.wait_idle().await;
quic_server.shutdown().await;
assert_eq!(client_addr, addr);
Ok(())
}
/// Closing a client endpoint whose QAD request targets an unresponsive server finishes after a
/// bounded, deterministic delay (time is paused).
#[tokio::test(start_paused = true)]
#[traced_test]
async fn test_qad_client_closes_unresponsive_fast() -> Result {
let client_endpoint =
quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
.std_context("client")?;
// A plain UDP socket that is bound but never read from, so the "server" never answers.
let server_socket =
tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
.await
.std_context("bind")?;
let server_addr = server_socket.local_addr().std_context("local addr")?;
let client_config = crate::client::make_dangerous_client_config();
let quic_client = QuicClient::new(client_endpoint.clone(), client_config);
let task = AbortOnDropHandle::new(tokio::spawn({
async move {
quic_client
.get_addr_and_latency(server_addr, "localhost")
.await
}
}));
// With paused time this advances the virtual clock; the request must still be pending.
tokio::time::sleep(Duration::from_millis(1000)).await;
assert!(!task.is_finished());
let before = Instant::now();
client_endpoint.close(0u32.into(), b"byeeeee");
client_endpoint.wait_idle().await;
let time = Instant::now().duration_since(before);
// Paused time makes the elapsed duration exact. NOTE(review): 999ms presumably equals
// 3 x PTO for the 111ms initial RTT set in `QuicClient::new` — confirm before changing either.
assert_eq!(time, Duration::from_millis(999));
Ok(())
}
/// QAD still succeeds when the server drops every packet for the first two seconds and only
/// then starts a QUIC endpoint on the same socket.
#[tokio::test]
async fn test_qad_connect_delayed() -> Result {
tracing_subscriber::fmt::try_init().ok();
let socket = tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
.await
.std_context("bind")?;
let server_addr = socket.local_addr().std_context("local addr")?;
info!(addr = ?server_addr, "server socket bound");
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
.std_context("self signed")?;
let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
let mut server_crypto = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert.cert.into()], key.into())
.std_context("tls")?;
// Writes TLS secrets to the file named by `SSLKEYLOGFILE` (if set), to aid debugging.
server_crypto.key_log = Arc::new(rustls::KeyLogFile::new());
server_crypto.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.to_vec()];
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
QuicServerConfig::try_from(server_crypto).std_context("config")?,
));
let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
transport_config.send_observed_address_reports(true);
let start = Instant::now();
let server_task = tokio::spawn(
async move {
// Phase 1: read and discard everything for 2s so the client's first packets are lost.
info!("Dropping all packets");
time::timeout(Duration::from_secs(2), async {
let mut buf = [0u8; 1500];
loop {
let (len, src) = socket.recv_from(&mut buf).await.unwrap();
debug!(%len, ?src, "Dropped a packet");
}
})
.await
.ok();
// Phase 2: reuse the same socket for a real quinn server endpoint.
info!("starting server");
let server = quinn::Endpoint::new(
Default::default(),
Some(server_config),
socket.into_std().unwrap(),
Arc::new(quinn::TokioRuntime),
)
.std_context("endpoint new")?;
info!("accepting conn");
let incoming = server.accept().await.expect("missing conn");
info!("incoming!");
let conn = incoming.await.std_context("incoming")?;
conn.closed().await;
server.wait_idle().await;
n0_error::Ok(())
}
.instrument(info_span!("server")),
);
let server_task = AbortOnDropHandle::new(server_task);
info!("starting client");
let client_endpoint =
quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
.std_context("client")?;
let client_config = crate::client::make_dangerous_client_config();
let quic_client = QuicClient::new(client_endpoint.clone(), client_config);
info!("making QAD request");
let (addr, latency) = time::timeout(
Duration::from_secs(10),
quic_client.get_addr_and_latency(server_addr, "localhost"),
)
.await
.std_context("timeout")??;
let duration = start.elapsed();
info!(?duration, ?addr, ?latency, "QAD succeeded");
assert!(duration >= Duration::from_secs(1));
time::timeout(Duration::from_secs(10), server_task)
.await
.std_context("timeout")?
.std_context("server task")??;
Ok(())
}
}