use std::{fmt, future::Future, net::SocketAddr, num::NonZeroU32, pin::Pin, sync::Arc};
use derive_more::Debug;
use http::{
HeaderMap, HeaderValue, Method, Request, Response, StatusCode, header::InvalidHeaderValue,
response::Builder as ResponseBuilder,
};
use hyper::body::Incoming;
use iroh_base::EndpointId;
#[cfg(feature = "test-utils")]
use iroh_base::RelayUrl;
use n0_error::{e, stack_error};
use n0_future::{StreamExt, future::Boxed};
use serde::Serialize;
use tokio::{
net::TcpListener,
task::{JoinError, JoinSet},
};
use tokio_util::task::AbortOnDropHandle;
use tracing::{Instrument, debug, error, info, info_span, instrument};
use crate::{
defaults::DEFAULT_KEY_CACHE_CAPACITY,
http::RELAY_PROBE_PATH,
quic::server::{QuicServer, QuicSpawnError, ServerHandle as QuicServerHandle},
};
pub mod client;
pub mod clients;
pub mod http_server;
mod metrics;
pub(crate) mod resolver;
pub mod streams;
#[cfg(feature = "test-utils")]
pub mod testing;
pub use self::{
http_server::{Handlers, RelayService},
metrics::{Metrics, RelayMetrics},
resolver::{DEFAULT_CERT_RELOAD_INTERVAL, ReloadingResolver},
};
/// Request header carrying the client's challenge token for the `/generate_204` check.
const NO_CONTENT_CHALLENGE_HEADER: &str = "X-Iroh-Challenge";
/// Response header that echoes a valid challenge back as `response <challenge>`.
const NO_CONTENT_RESPONSE_HEADER: &str = "X-Iroh-Response";
/// Body of the `404 Not Found` response sent by the captive-portal service.
const NOTFOUND: &[u8] = b"Not Found";
/// Contents of `/robots.txt`: disallows crawling of the whole site.
const ROBOTS_TXT: &[u8] = b"User-agent: *\nDisallow: /\n";
/// HTML landing page served at `/` and `/index.html`.
const INDEX: &[u8] = br#"<html><body>
<h1>Iroh Relay</h1>
<p>
This is an <a href="https://iroh.computer/">Iroh</a> Relay server.
</p>
"#;
/// Security headers (name, value) handed to the relay server builder via `.headers(..)`.
///
/// `Server::spawn` parses and inserts these whether or not TLS is configured.
// NOTE(review): `block-all-mixed-content` and `plugin-types` are deprecated CSP directives
// that current browsers ignore; consider dropping them.
const TLS_HEADERS: [(&str, &str); 2] = [
    (
        "Strict-Transport-Security",
        "max-age=63072000; includeSubDomains",
    ),
    (
        "Content-Security-Policy",
        "default-src 'none'; frame-ancestors 'none'; form-action 'none'; base-uri 'self'; block-all-mixed-content; plugin-types 'none'",
    ),
];
/// Response body type used by every handler in this module: one in-memory `Bytes` chunk.
type BytesBody = http_body_util::Full<hyper::body::Bytes>;
/// Boxed, thread-safe error type returned by the hyper handlers in this module.
type HyperError = Box<dyn std::error::Error + Send + Sync>;
/// Result alias for handler return values.
type HyperResult<T> = std::result::Result<T, HyperError>;
fn body_empty() -> BytesBody {
http_body_util::Full::new(hyper::body::Bytes::new())
}
/// Configuration for [`Server::spawn`].
///
/// Each optional section enables one service; sections left as `None` are not started.
/// If neither `relay` nor `quic` is set and no other task (such as the metrics server)
/// is running, the supervisor exits with [`SupervisorError::NoRelayServicesEnabled`].
#[derive(Debug, Default)]
pub struct ServerConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
    /// Relay HTTP(S) server configuration; `None` disables the relay server.
    pub relay: Option<RelayConfig<EC, EA>>,
    /// QUIC server configuration; `None` disables the QUIC server.
    pub quic: Option<QuicConfig>,
    /// Address for the metrics HTTP server; `None` disables it.
    #[cfg(feature = "metrics")]
    pub metrics_addr: Option<SocketAddr>,
}
/// Configuration for the relay HTTP(S) server.
#[derive(Debug)]
pub struct RelayConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
    /// Plain-HTTP bind address.
    ///
    /// Without `tls`, the relay server itself listens here. With `tls`, the relay listens
    /// on [`TlsConfig::https_bind_addr`] and this address serves only the captive-portal
    /// endpoint (`/generate_204`).
    pub http_bind_addr: SocketAddr,
    /// TLS settings; `None` serves the relay over plain HTTP.
    pub tls: Option<TlsConfig<EC, EA>>,
    /// Rate limits for the relay server.
    pub limits: Limits,
    /// Capacity of the relay's key cache; `None` uses `DEFAULT_KEY_CACHE_CAPACITY`.
    pub key_cache_capacity: Option<usize>,
    /// Which endpoints are allowed to use the relay.
    pub access: AccessConfig,
}
/// Controls which endpoints are allowed to use the relay.
#[derive(derive_more::Debug)]
pub enum AccessConfig {
    /// Every endpoint is allowed.
    Everyone,
    /// Each endpoint is checked by the given async callback and admitted only if the
    /// callback returns [`Access::Allow`]. Its `Debug` output is just `restricted`.
    #[debug("restricted")]
    Restricted(Box<dyn Fn(EndpointId) -> Boxed<Access> + Send + Sync + 'static>),
}
impl AccessConfig {
    /// Returns whether `endpoint` may use the relay.
    ///
    /// [`AccessConfig::Everyone`] admits every endpoint. [`AccessConfig::Restricted`]
    /// awaits its callback and admits the endpoint only on [`Access::Allow`].
    pub async fn is_allowed(&self, endpoint: EndpointId) -> bool {
        let verdict = match self {
            Self::Everyone => return true,
            Self::Restricted(check) => check(endpoint).await,
        };
        verdict == Access::Allow
    }
}
/// Verdict returned by an [`AccessConfig::Restricted`] callback.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Access {
    /// The endpoint may use the relay.
    Allow,
    /// The endpoint is refused.
    Deny,
}
/// Configuration for the QUIC server spawned by [`Server::spawn`].
#[derive(Debug)]
pub struct QuicConfig {
    /// Address the QUIC server is started on.
    pub bind_addr: SocketAddr,
    /// rustls configuration handed to the QUIC server.
    pub server_config: rustls::ServerConfig,
}
/// TLS settings for the relay server.
#[derive(Debug)]
pub struct TlsConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
    /// Address the relay server listens on for HTTPS.
    pub https_bind_addr: SocketAddr,
    /// QUIC bind address.
    // NOTE(review): not read by `Server::spawn`; the QUIC server is configured via
    // `QuicConfig::bind_addr`. Confirm whether this field is still used anywhere.
    pub quic_bind_addr: SocketAddr,
    /// Where the certificates come from.
    pub cert: CertConfig<EC, EA>,
    /// rustls configuration used to build the HTTPS acceptor.
    pub server_config: rustls::ServerConfig,
}
/// Rate limits for the relay server.
#[derive(Debug, Default)]
pub struct Limits {
    /// Rate limit for accepting new connections.
    // NOTE(review): not read in this module; confirm where it is applied and its unit.
    pub accept_conn_limit: Option<f64>,
    /// Burst size for accepting new connections.
    // NOTE(review): not read in this module; confirm where it is applied.
    pub accept_conn_burst: Option<usize>,
    /// Per-client receive rate limit. When `None`, no receive limit is configured on the
    /// relay server builder.
    pub client_rx: Option<ClientRateLimit>,
}
/// Rate limit for data the relay receives from a single client.
#[derive(Debug, Copy, Clone)]
pub struct ClientRateLimit {
    /// Sustained rate in bytes per second.
    pub bytes_per_second: NonZeroU32,
    /// Maximum burst size in bytes; `None` presumably leaves the burst to the relay
    /// server's default. Confirm in `http_server`.
    pub max_burst_bytes: Option<NonZeroU32>,
}
/// Source of the relay server's TLS certificates.
#[derive(derive_more::Debug)]
pub enum CertConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
    /// Certificates are obtained through ACME (Let's Encrypt).
    ///
    /// The ACME event stream is driven by a supervised task; if the stream ends, the
    /// supervisor stops the whole server.
    LetsEncrypt {
        /// ACME state; its acceptor terminates incoming TLS connections.
        #[debug("AcmeConfig")]
        state: tokio_rustls_acme::AcmeState<EC, EA>,
    },
    /// Manually supplied certificates.
    ///
    /// They are exposed through [`Server::certificates`]. The TLS acceptor itself is built
    /// from `TlsConfig::server_config` only.
    Manual {
        /// Certificates returned by [`Server::certificates`].
        certs: Vec<rustls::pki_types::CertificateDer<'static>>,
    },
    /// TLS is handled entirely by `TlsConfig::server_config`. Presumably this is backed by
    /// a certificate resolver such as [`ReloadingResolver`]; confirm with callers.
    Reloading,
}
/// A running relay server: the relay HTTP(S) server, an optional QUIC server and optional
/// side tasks, all watched by one supervisor task.
///
/// The supervisor is held in an [`AbortOnDropHandle`], so dropping the `Server` aborts it.
#[derive(Debug)]
pub struct Server {
    /// Plain-HTTP address. Without TLS this is the relay's own address; with TLS it is
    /// the captive-portal listener's address.
    http_addr: Option<SocketAddr>,
    /// HTTPS address of the relay; only set when TLS is configured.
    https_addr: Option<SocketAddr>,
    /// Address of the QUIC server, if enabled.
    quic_addr: Option<SocketAddr>,
    /// Handle used to request shutdown of the relay HTTP(S) server.
    relay_handle: Option<http_server::ServerHandle>,
    /// Handle used to request shutdown of the QUIC server.
    quic_handle: Option<QuicServerHandle>,
    /// Supervisor task; completes once the first supervised service stops.
    supervisor: AbortOnDropHandle<Result<(), SupervisorError>>,
    /// Manually configured certificates (`CertConfig::Manual`), if any.
    certificates: Option<Vec<rustls::pki_types::CertificateDer<'static>>>,
    /// Metrics collected by this server.
    metrics: RelayMetrics,
}
/// Errors that can occur while starting a [`Server`].
#[allow(missing_docs)]
#[stack_error(derive, add_meta, std_sources)]
#[non_exhaustive]
pub enum SpawnError {
    // NOTE(review): not constructed in this module; presumably produced by the relay
    // server builder.
    #[error("Unable to get local address")]
    LocalAddr { source: std::io::Error },
    // Spawning the QUIC server failed.
    #[error("Failed to bind QAD listener")]
    QuicSpawn { source: QuicSpawnError },
    // An entry of `TLS_HEADERS` could not be parsed as a header value.
    #[error("Failed to parse TLS header")]
    TlsHeaderParse { source: InvalidHeaderValue },
    // Binding the plain-HTTP captive-portal listener (TLS mode) failed.
    #[error("Failed to bind TcpListener")]
    BindTlsListener { source: std::io::Error },
    // The captive-portal listener's local address could not be read.
    #[error("No local address")]
    NoLocalAddr { source: std::io::Error },
    // NOTE(review): not constructed in this module; presumably produced by the relay
    // server builder.
    #[error("Failed to bind server socket to {addr}")]
    BindTcpListener {
        source: std::io::Error,
        addr: SocketAddr,
    },
}
/// Reasons the supervisor task of a [`Server`] stopped with an error.
#[allow(missing_docs)]
#[stack_error(derive, add_meta)]
#[non_exhaustive]
pub enum SupervisorError {
    // The metrics server task failed.
    #[error("Error starting metrics server")]
    Metrics {
        #[error(std_err)]
        source: std::io::Error,
    },
    // The ACME (Let's Encrypt) event stream ended, which is never expected.
    #[error("Acme event stream finished")]
    AcmeEventStreamFinished {},
    // Joining the supervisor task itself failed; surfaced by `Server::shutdown`.
    #[error(transparent)]
    JoinError {
        #[error(from, std_err)]
        source: JoinError,
    },
    // There was nothing to supervise: no relay server, no QUIC server, no side tasks.
    #[error("No relay services are enabled")]
    NoRelayServicesEnabled {},
    // A supervised task was cancelled.
    #[error("Task cancelled")]
    TaskCancelled {},
}
impl Server {
    /// Starts every service enabled in `config` and returns a handle to them.
    ///
    /// Services are set up in this order:
    /// 1. The metrics server task (with the `metrics` feature and `metrics_addr` set).
    /// 2. The QUIC server.
    /// 3. The relay HTTP(S) server.
    ///
    /// When the relay has TLS configured, an extra plain-HTTP captive-portal listener is
    /// bound on `RelayConfig::http_bind_addr`. Everything is watched by a single supervisor
    /// task (see [`Server::task_handle`]).
    ///
    /// # Errors
    ///
    /// Returns a [`SpawnError`] in these cases:
    /// - the QUIC server cannot be spawned;
    /// - a TLS header fails to parse;
    /// - the captive-portal listener cannot be bound or has no local address;
    /// - the relay server fails to spawn.
    pub async fn spawn<EC, EA>(config: ServerConfig<EC, EA>) -> Result<Self, SpawnError>
    where
        EC: fmt::Debug + 'static,
        EA: fmt::Debug + 'static,
    {
        // Side tasks (metrics server, ACME event loop, captive portal) that the
        // supervisor watches alongside the relay and QUIC servers.
        let mut tasks = JoinSet::new();
        let metrics = RelayMetrics::default();
        #[cfg(feature = "metrics")]
        if let Some(addr) = config.metrics_addr {
            debug!("Starting metrics server");
            let mut registry = iroh_metrics::Registry::default();
            registry.register_all(&metrics);
            tasks.spawn(
                async move {
                    iroh_metrics::service::start_metrics_server(addr, Arc::new(registry))
                        .await
                        .map_err(|err| e!(SupervisorError::Metrics, err))
                }
                .instrument(info_span!("metrics-server")),
            );
        }
        // Only manually configured certificates are exposed via `Server::certificates`.
        let certificates = config.relay.as_ref().and_then(|relay| {
            relay.tls.as_ref().and_then(|tls| match tls.cert {
                CertConfig::LetsEncrypt { .. } => None,
                CertConfig::Manual { ref certs, .. } => Some(certs.clone()),
                CertConfig::Reloading => None,
            })
        });
        let quic_server = match config.quic {
            Some(quic_config) => {
                debug!("Starting QUIC server {}", quic_config.bind_addr);
                Some(QuicServer::spawn(quic_config).map_err(|err| e!(SpawnError::QuicSpawn, err))?)
            }
            None => None,
        };
        let quic_addr = quic_server.as_ref().map(|srv| srv.bind_addr());
        let quic_handle = quic_server.as_ref().map(|srv| srv.handle());
        let (relay_server, http_addr) = match config.relay {
            Some(relay_config) => {
                debug!("Starting Relay server");
                // Security headers handed to the relay server builder. They are inserted
                // whether or not TLS is configured.
                let mut headers = HeaderMap::new();
                for (name, value) in TLS_HEADERS.iter() {
                    headers.insert(
                        *name,
                        value
                            .parse()
                            .map_err(|err| e!(SpawnError::TlsHeaderParse, err))?,
                    );
                }
                // With TLS the relay listens on the HTTPS address, and `http_bind_addr`
                // is used for the captive-portal listener further down.
                let relay_bind_addr = match relay_config.tls {
                    Some(ref tls) => tls.https_bind_addr,
                    None => relay_config.http_bind_addr,
                };
                let key_cache_capacity = relay_config
                    .key_cache_capacity
                    .unwrap_or(DEFAULT_KEY_CACHE_CAPACITY);
                let mut builder = http_server::ServerBuilder::new(relay_bind_addr)
                    .metrics(metrics.server.clone())
                    .headers(headers)
                    .key_cache_capacity(key_cache_capacity)
                    .access(relay_config.access)
                    .request_handler(Method::GET, "/", Box::new(root_handler))
                    .request_handler(Method::GET, "/index.html", Box::new(root_handler))
                    .request_handler(Method::GET, RELAY_PROBE_PATH, Box::new(probe_handler))
                    .request_handler(Method::GET, "/robots.txt", Box::new(robots_handler))
                    .request_handler(Method::GET, "/healthz", Box::new(healthz_handler));
                if let Some(cfg) = relay_config.limits.client_rx {
                    builder = builder.client_rx_ratelimit(cfg);
                }
                let http_addr = match relay_config.tls {
                    Some(tls_config) => {
                        let server_tls_config = match tls_config.cert {
                            CertConfig::LetsEncrypt { mut state } => {
                                let acceptor =
                                    http_server::TlsAcceptor::LetsEncrypt(state.acceptor());
                                // Drive the ACME state machine for as long as the server
                                // runs. If the event stream ever ends, the task returns an
                                // error, and the supervisor then shuts everything down.
                                tasks.spawn(
                                    async move {
                                        while let Some(event) = state.next().await {
                                            match event {
                                                Ok(ok) => debug!("acme event: {ok:?}"),
                                                Err(err) => error!("error: {err:?}"),
                                            }
                                        }
                                        Err(e!(SupervisorError::AcmeEventStreamFinished))
                                    }
                                    .instrument(info_span!("acme")),
                                );
                                Some(http_server::TlsConfig {
                                    config: Arc::new(tls_config.server_config),
                                    acceptor,
                                })
                            }
                            // TLS is handled entirely by `server_config`. The `Manual`
                            // certificates are not passed to the acceptor.
                            CertConfig::Manual { .. } | CertConfig::Reloading => {
                                let server_config = Arc::new(tls_config.server_config);
                                let acceptor =
                                    tokio_rustls::TlsAcceptor::from(server_config.clone());
                                let acceptor = http_server::TlsAcceptor::Manual(acceptor);
                                Some(http_server::TlsConfig {
                                    config: server_config,
                                    acceptor,
                                })
                            }
                        };
                        builder = builder.tls_config(server_tls_config);
                        // With TLS enabled, keep a plain-HTTP listener that only answers
                        // captive-portal checks (`/generate_204`).
                        let http_listener = TcpListener::bind(&relay_config.http_bind_addr)
                            .await
                            .map_err(|err| e!(SpawnError::BindTlsListener, err))?;
                        let http_addr = http_listener
                            .local_addr()
                            .map_err(|err| e!(SpawnError::NoLocalAddr, err))?;
                        tasks.spawn(
                            async move {
                                run_captive_portal_service(http_listener).await;
                                Ok(())
                            }
                            .instrument(info_span!("http-service", addr = %http_addr)),
                        );
                        Some(http_addr)
                    }
                    None => {
                        // Without TLS, the relay server answers captive-portal checks itself.
                        builder = builder.request_handler(
                            Method::GET,
                            "/generate_204",
                            Box::new(serve_no_content_handler),
                        );
                        None
                    }
                };
                let relay_server = builder.spawn().await?;
                (Some(relay_server), http_addr)
            }
            None => (None, None),
        };
        let relay_addr = relay_server.as_ref().map(|srv| srv.addr());
        let relay_handle = relay_server.as_ref().map(|srv| srv.handle());
        let task = tokio::spawn(relay_supervisor(tasks, relay_server, quic_server));
        // With TLS, `http_addr` is the captive-portal listener and the relay address is
        // the HTTPS address. Without TLS, `http_addr` is `None`, so the relay address is
        // reported as the HTTP address and there is no HTTPS address.
        Ok(Self {
            http_addr: http_addr.or(relay_addr),
            https_addr: http_addr.and(relay_addr),
            quic_addr,
            relay_handle,
            quic_handle,
            supervisor: AbortOnDropHandle::new(task),
            certificates,
            metrics,
        })
    }

    /// Asks the relay and QUIC servers to shut down, then waits for the supervisor task.
    ///
    /// # Errors
    ///
    /// Returns the error the supervisor exited with. Returns [`SupervisorError::JoinError`]
    /// if the supervisor task itself could not be joined.
    pub async fn shutdown(self) -> Result<(), SupervisorError> {
        if let Some(handle) = self.relay_handle {
            handle.shutdown();
        }
        if let Some(handle) = self.quic_handle {
            handle.shutdown();
        }
        self.supervisor.await?
    }

    /// Mutable access to the supervisor task handle.
    ///
    /// Awaiting the handle yields the supervisor's result once the first supervised
    /// service stops.
    pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<Result<(), SupervisorError>> {
        &mut self.supervisor
    }

    /// HTTPS address of the relay server; `Some` only when TLS is configured.
    pub fn https_addr(&self) -> Option<SocketAddr> {
        self.https_addr
    }

    /// Plain-HTTP address of the server.
    ///
    /// This is the relay's own address without TLS, or the captive-portal listener's
    /// address with TLS. It is `None` when the relay server is disabled.
    pub fn http_addr(&self) -> Option<SocketAddr> {
        self.http_addr
    }

    /// Address of the QUIC server, if enabled.
    pub fn quic_addr(&self) -> Option<SocketAddr> {
        self.quic_addr
    }

    /// A clone of the manually configured certificates (`CertConfig::Manual`).
    ///
    /// Returns `None` without TLS or for other certificate sources.
    pub fn certificates(&self) -> Option<Vec<rustls::pki_types::CertificateDer<'static>>> {
        self.certificates.clone()
    }

    /// `https://` relay URL built from [`Server::https_addr`] (test utilities only).
    ///
    /// # Panics
    ///
    /// Panics if the formatted address is not a valid URL, which is not expected for a
    /// socket address.
    #[cfg(feature = "test-utils")]
    pub fn https_url(&self) -> Option<RelayUrl> {
        self.https_addr.map(|addr| {
            url::Url::parse(&format!("https://{addr}"))
                .expect("valid url")
                .into()
        })
    }

    /// `http://` relay URL built from [`Server::http_addr`] (test utilities only).
    ///
    /// # Panics
    ///
    /// Panics if the formatted address is not a valid URL, which is not expected for a
    /// socket address.
    #[cfg(feature = "test-utils")]
    pub fn http_url(&self) -> Option<RelayUrl> {
        self.http_addr.map(|addr| {
            url::Url::parse(&format!("http://{addr}"))
                .expect("valid url")
                .into()
        })
    }

    /// Metrics collected by this server.
    pub fn metrics(&self) -> &RelayMetrics {
        &self.metrics
    }
}
/// Supervises all services of a [`Server`] and tears everything down once the first one stops.
///
/// It waits, in priority order, for the first of these events:
/// 1. a task in `tasks` finishing;
/// 2. the QUIC server task finishing;
/// 3. the relay HTTP server task finishing.
///
/// If there is nothing to wait on at all, it returns
/// [`SupervisorError::NoRelayServicesEnabled`]. Afterwards the relay and QUIC servers are
/// shut down and all remaining `tasks` are aborted via `JoinSet::shutdown`.
///
/// A panic in a supervised task is re-raised here with `resume_unwind`. A cancelled task
/// yields [`SupervisorError::TaskCancelled`].
#[instrument(skip_all)]
async fn relay_supervisor(
    mut tasks: JoinSet<Result<(), SupervisorError>>,
    mut relay_http_server: Option<http_server::Server>,
    mut quic_server: Option<QuicServer>,
) -> Result<(), SupervisorError> {
    // Disabled services get a never-resolving future. The `if` guards in `select!` also
    // keep those branches out of the race entirely.
    let quic_enabled = quic_server.is_some();
    let mut quic_fut = match quic_server {
        Some(ref mut server) => n0_future::Either::Left(server.task_handle()),
        None => n0_future::Either::Right(n0_future::future::pending()),
    };
    let relay_enabled = relay_http_server.is_some();
    let mut relay_fut = match relay_http_server {
        Some(ref mut server) => n0_future::Either::Left(server.task_handle()),
        None => n0_future::Either::Right(n0_future::future::pending()),
    };
    // `biased`: branches are polled top to bottom. `else` only fires when every branch is
    // disabled, i.e. `tasks` is empty (`join_next` yields `None`) and neither server is
    // enabled.
    let res = tokio::select! {
        biased;
        Some(ret) = tasks.join_next() => ret,
        ret = &mut quic_fut, if quic_enabled => ret.map(Ok),
        ret = &mut relay_fut, if relay_enabled => ret.map(Ok),
        else => Ok(Err(e!(SupervisorError::NoRelayServicesEnabled))),
    };
    let ret = match res {
        Ok(Ok(())) => {
            debug!("Task exited");
            Ok(())
        }
        Ok(Err(err)) => {
            error!(%err, "Task failed");
            Err(err)
        }
        Err(err) => {
            if let Ok(panic) = err.try_into_panic() {
                error!("Task panicked");
                std::panic::resume_unwind(panic);
            }
            debug!("Task cancelled");
            Err(e!(SupervisorError::TaskCancelled))
        }
    };
    // Whatever stopped first, bring down the remaining services.
    if let Some(server) = relay_http_server {
        server.shutdown();
    }
    if let Some(server) = quic_server {
        server.shutdown().await;
    }
    tasks.shutdown().await;
    ret
}
fn root_handler(
_r: Request<Incoming>,
response: ResponseBuilder,
) -> HyperResult<Response<BytesBody>> {
response
.status(StatusCode::OK)
.header("Content-Type", "text/html; charset=utf-8")
.body(INDEX.into())
.map_err(|err| Box::new(err) as HyperError)
}
/// Answers relay probes with an empty 200 response that any origin may read (CORS `*`).
///
/// # Errors
///
/// Returns an error if the response cannot be built.
fn probe_handler(
    _r: Request<Incoming>,
    response: ResponseBuilder,
) -> HyperResult<Response<BytesBody>> {
    let with_cors = response
        .status(StatusCode::OK)
        .header("Access-Control-Allow-Origin", "*");
    with_cors
        .body(body_empty())
        .map_err(|err| -> HyperError { Box::new(err) })
}
/// Serves `/robots.txt`, which disallows crawling of the whole site.
///
/// The file is sent as UTF-8 `text/plain`, the media type the Robots Exclusion Protocol
/// (RFC 9309) requires. This also matches the other handlers here, which always set
/// `Content-Type`.
///
/// # Errors
///
/// Returns an error if the response cannot be built.
fn robots_handler(
    _r: Request<Incoming>,
    response: ResponseBuilder,
) -> HyperResult<Response<BytesBody>> {
    response
        .status(StatusCode::OK)
        .header("Content-Type", "text/plain; charset=utf-8")
        .body(ROBOTS_TXT.into())
        .map_err(|err| Box::new(err) as HyperError)
}
/// Captive-portal check: always answers `204 No Content`.
///
/// If the request carries a valid `X-Iroh-Challenge` header, the challenge is echoed back
/// as `X-Iroh-Response: response <challenge>`. A valid challenge is 1 to 63 bytes, each
/// accepted by `is_challenge_char`. A missing or invalid challenge is ignored.
///
/// # Errors
///
/// Returns an error if the response cannot be built.
fn serve_no_content_handler<B: hyper::body::Body>(
    r: Request<B>,
    mut response: ResponseBuilder,
) -> HyperResult<Response<BytesBody>> {
    let valid_challenge = r
        .headers()
        .get(NO_CONTENT_CHALLENGE_HEADER)
        .filter(|value| {
            let bytes = value.as_bytes();
            !bytes.is_empty()
                && bytes.len() < 64
                && bytes.iter().all(|&b| is_challenge_char(char::from(b)))
        });
    if let Some(challenge) = valid_challenge {
        // Only ASCII passes the filter above, so `to_str` cannot fail here.
        let echoed = format!("response {}", challenge.to_str()?);
        response = response.header(NO_CONTENT_RESPONSE_HEADER, echoed);
    }
    response
        .status(StatusCode::NO_CONTENT)
        .body(body_empty())
        .map_err(|err| -> HyperError { Box::new(err) })
}
/// Returns whether `c` may appear in a captive-portal challenge token.
///
/// Allowed characters are ASCII letters, ASCII digits, `.`, `-` and `_`.
fn is_challenge_char(c: char) -> bool {
    c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_')
}
/// JSON body returned by the `/healthz` endpoint.
#[derive(Serialize)]
struct Health {
    /// Health status string; `healthz_handler` always sets `"ok"`.
    status: &'static str,
    /// Crate version (`CARGO_PKG_VERSION`).
    version: &'static str,
    /// Git commit hash from `VERGEN_GIT_SHA` at build time, or `"unknown"`.
    git_hash: &'static str,
}
fn healthz_handler(
_r: Request<Incoming>,
response: ResponseBuilder,
) -> HyperResult<Response<BytesBody>> {
let health = Health {
status: "ok",
version: env!("CARGO_PKG_VERSION"),
git_hash: option_env!("VERGEN_GIT_SHA").unwrap_or("unknown"),
};
let body = serde_json::to_string(&health).unwrap_or_else(|_| r#"{"status":"error"}"#.into());
response
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body.into())
.map_err(|err| Box::new(err) as HyperError)
}
/// Serves captive-portal checks on a plain-HTTP listener. This function never returns;
/// it runs until its task is aborted.
///
/// Each accepted connection is served over HTTP/1 by [`CaptivePortalService`] on its own
/// task. A panic in a connection task is propagated by panicking here. Failed accepts and
/// failed connections are only logged.
// NOTE(review): if `accept` fails persistently (e.g. file-descriptor exhaustion), the
// loop retries immediately and logs on every iteration. Consider a short back-off.
async fn run_captive_portal_service(http_listener: TcpListener) {
    info!("serving");
    let mut tasks = JoinSet::new();
    loop {
        tokio::select! {
            biased;
            // Reap finished connection tasks first so the set does not grow unbounded.
            Some(res) = tasks.join_next() => {
                if let Err(err) = res
                    && err.is_panic()
                {
                    panic!("task panicked: {err:#?}");
                }
            }
            res = http_listener.accept() => {
                match res {
                    Ok((stream, peer_addr)) => {
                        debug!(%peer_addr, "Connection opened",);
                        let handler = CaptivePortalService;
                        tasks.spawn(async move {
                            let stream = crate::server::streams::MaybeTlsStream::Plain(stream);
                            let stream = hyper_util::rt::TokioIo::new(stream);
                            if let Err(err) = hyper::server::conn::http1::Builder::new()
                                .serve_connection(stream, handler)
                                .with_upgrades()
                                .await
                            {
                                error!("Failed to serve connection: {err:?}");
                            }
                        });
                    }
                    Err(err) => {
                        error!(
                            "[CaptivePortalService] failed to accept connection: {:#?}",
                            err
                        );
                    }
                }
            }
        }
    }
}
#[derive(Clone)]
struct CaptivePortalService;
impl hyper::service::Service<Request<Incoming>> for CaptivePortalService {
type Response = Response<BytesBody>;
type Error = HyperError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
match (req.method(), req.uri().path()) {
(&Method::GET, "/generate_204") => {
Box::pin(async move { serve_no_content_handler(req, Response::builder()) })
}
_ => {
let r = Response::builder()
.status(StatusCode::NOT_FOUND)
.body(NOTFOUND.into())
.map_err(|err| Box::new(err) as HyperError);
Box::pin(async move { r })
}
}
}
}
#[cfg(test)]
mod tests {
    //! End-to-end tests that spawn a local relay `Server` and talk to it over HTTP.
    use std::{net::Ipv4Addr, time::Duration};

    use http::StatusCode;
    use iroh_base::{EndpointId, RelayUrl, SecretKey};
    use n0_error::Result;
    use n0_future::{FutureExt, SinkExt, StreamExt};
    use n0_tracing_test::traced_test;
    use rand::SeedableRng;
    use tracing::{info, instrument};

    use super::{
        Access, AccessConfig, NO_CONTENT_CHALLENGE_HEADER, NO_CONTENT_RESPONSE_HEADER, RelayConfig,
        Server, ServerConfig, SpawnError,
    };
    use crate::{
        client::{ClientBuilder, ConnectError},
        dns::DnsResolver,
        protos::{
            handshake,
            relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg},
        },
    };

    /// Spawns a plain-HTTP relay on an ephemeral localhost port with default limits and
    /// open access.
    async fn spawn_local_relay() -> std::result::Result<Server, SpawnError> {
        Server::spawn(ServerConfig::<(), ()> {
            relay: Some(RelayConfig::<(), ()> {
                http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(),
                tls: None,
                limits: Default::default(),
                key_cache_capacity: Some(1024),
                access: AccessConfig::Everyone,
            }),
            quic: None,
            metrics_addr: None,
        })
        .await
    }

    /// Sends `msg` from `client_a` to `b_key` and returns the next message `client_b`
    /// receives.
    ///
    /// It makes up to 10 attempts, each waiting 500ms for a reply, and panics if all fail.
    #[instrument]
    async fn try_send_recv(
        client_a: &mut crate::client::Client,
        client_b: &mut crate::client::Client,
        b_key: EndpointId,
        msg: Datagrams,
    ) -> Result<RelayToClientMsg> {
        // Retry to tolerate the window before the relay has registered `client_b`.
        for _ in 0..10 {
            client_a
                .send(ClientToRelayMsg::Datagrams {
                    dst_endpoint_id: b_key,
                    datagrams: msg.clone(),
                })
                .await?;
            let Ok(res) = tokio::time::timeout(Duration::from_millis(500), client_b.next()).await
            else {
                continue;
            };
            let res = res.expect("stream finished")?;
            return Ok(res);
        }
        panic!("failed to send and recv message");
    }

    /// Default DNS resolver for test clients.
    fn dns_resolver() -> DnsResolver {
        DnsResolver::new()
    }

    /// With no services configured the supervisor has nothing to wait on and must exit
    /// with an error.
    #[tokio::test]
    #[traced_test]
    async fn test_no_services() {
        let mut server = Server::spawn(ServerConfig::<(), ()>::default())
            .await
            .unwrap();
        let res = tokio::time::timeout(Duration::from_secs(5), server.task_handle())
            .await
            .expect("timeout, server not finished")
            .expect("server task JoinError");
        assert!(res.is_err());
    }

    /// Relay and metrics server share the same port. The metrics task should fail to bind,
    /// so the supervisor exits with an error.
    // NOTE(review): this presumes the relay binds first (inside `Server::spawn`), so only
    // the metrics task fails. The fixed port 1234 may also already be in use on the host,
    // in which case the `.unwrap()` on spawn panics instead.
    #[tokio::test]
    #[traced_test]
    async fn test_conflicting_bind() {
        let mut server = Server::spawn(ServerConfig::<(), ()> {
            relay: Some(RelayConfig {
                http_bind_addr: (Ipv4Addr::LOCALHOST, 1234).into(),
                tls: None,
                limits: Default::default(),
                key_cache_capacity: Some(1024),
                access: AccessConfig::Everyone,
            }),
            quic: None,
            metrics_addr: Some((Ipv4Addr::LOCALHOST, 1234).into()),
        })
        .await
        .unwrap();
        let res = tokio::time::timeout(Duration::from_secs(5), server.task_handle())
            .await
            .expect("timeout, server not finished")
            .expect("server task JoinError");
        assert!(res.is_err());
    }

    /// `/` serves the HTML landing page.
    #[tokio::test]
    #[traced_test]
    async fn test_root_handler() {
        let server = spawn_local_relay().await.unwrap();
        let url = format!("http://{}", server.http_addr().unwrap());
        let client = reqwest::Client::builder().use_rustls_tls().build().unwrap();
        let response = client.get(&url).send().await.unwrap();
        assert_eq!(response.status(), 200);
        let body = response.text().await.unwrap();
        assert!(body.contains("iroh.computer"));
    }

    /// `/generate_204` answers 204 and echoes a valid challenge.
    #[tokio::test]
    #[traced_test]
    async fn test_captive_portal_service() {
        let server = spawn_local_relay().await.unwrap();
        let url = format!("http://{}/generate_204", server.http_addr().unwrap());
        let challenge = "123az__.";
        let client = reqwest::Client::builder().use_rustls_tls().build().unwrap();
        let response = client
            .get(&url)
            .header(NO_CONTENT_CHALLENGE_HEADER, challenge)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::NO_CONTENT);
        let header = response.headers().get(NO_CONTENT_RESPONSE_HEADER).unwrap();
        assert_eq!(header.to_str().unwrap(), format!("response {challenge}"));
        let body = response.text().await.unwrap();
        assert!(body.is_empty());
    }

    /// Two clients exchange datagrams in both directions through the relay.
    #[tokio::test]
    #[traced_test]
    async fn test_relay_clients() -> Result<()> {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let server = spawn_local_relay().await?;
        let relay_url = format!("http://{}", server.http_addr().unwrap());
        let relay_url: RelayUrl = relay_url.parse()?;
        let a_secret_key = SecretKey::generate(&mut rng);
        let a_key = a_secret_key.public();
        let resolver = dns_resolver();
        info!("client a build & connect");
        let mut client_a = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver.clone())
            .connect()
            .await?;
        let b_secret_key = SecretKey::generate(&mut rng);
        let b_key = b_secret_key.public();
        info!("client b build & connect");
        let mut client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver.clone())
            .connect()
            .await?;
        info!("sending a -> b");
        let msg = Datagrams::from("hello, b");
        let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?;
        let RelayToClientMsg::Datagrams {
            remote_endpoint_id,
            datagrams,
        } = res
        else {
            panic!("client_b received unexpected message {res:?}");
        };
        assert_eq!(a_key, remote_endpoint_id);
        assert_eq!(msg, datagrams);
        info!("sending b -> a");
        let msg = Datagrams::from("howdy, a");
        let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?;
        let RelayToClientMsg::Datagrams {
            remote_endpoint_id,
            datagrams,
        } = res
        else {
            panic!("client_a received unexpected message {res:?}");
        };
        assert_eq!(b_key, remote_endpoint_id);
        assert_eq!(msg, datagrams);
        Ok(())
    }

    /// A restricted relay denies client A during the handshake but still relays between
    /// clients B and C.
    #[tokio::test]
    #[traced_test]
    async fn test_relay_access_control() -> Result<()> {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        // NOTE(review): an `Entered` span guard held across `.await` points yields
        // misleading traces (see `tracing::Span::enter` docs); consider `.instrument(..)`.
        let current_span = tracing::info_span!("this is a test");
        let _guard = current_span.enter();
        let a_secret_key = SecretKey::generate(&mut rng);
        let a_key = a_secret_key.public();
        let server = Server::spawn(ServerConfig::<(), ()> {
            relay: Some(RelayConfig::<(), ()> {
                http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(),
                tls: None,
                limits: Default::default(),
                key_cache_capacity: Some(1024),
                access: AccessConfig::Restricted(Box::new(move |endpoint_id| {
                    async move {
                        info!("checking {}", endpoint_id);
                        if endpoint_id == a_key {
                            Access::Deny
                        } else {
                            Access::Allow
                        }
                    }
                    .boxed()
                })),
            }),
            quic: None,
            metrics_addr: None,
        })
        .await?;
        let relay_url = format!("http://{}", server.http_addr().unwrap());
        let relay_url: RelayUrl = relay_url.parse()?;
        let resolver = dns_resolver();
        let result = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver)
            .connect()
            .await;
        assert!(
            matches!(result, Err(ConnectError::Handshake { source: handshake::Error::ServerDeniedAuth { reason, .. }, .. }) if reason == "not authorized")
        );
        let b_secret_key = SecretKey::generate(&mut rng);
        let b_key = b_secret_key.public();
        let resolver = dns_resolver();
        let mut client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver)
            .connect()
            .await?;
        let c_secret_key = SecretKey::generate(&mut rng);
        let c_key = c_secret_key.public();
        let resolver = dns_resolver();
        let mut client_c = ClientBuilder::new(relay_url.clone(), c_secret_key, resolver)
            .connect()
            .await?;
        let msg = Datagrams::from("hello, c");
        let res = try_send_recv(&mut client_b, &mut client_c, c_key, msg.clone()).await?;
        if let RelayToClientMsg::Datagrams {
            remote_endpoint_id,
            datagrams,
        } = res
        {
            assert_eq!(b_key, remote_endpoint_id);
            assert_eq!(msg, datagrams);
        } else {
            panic!("client_c received unexpected message {res:?}");
        }
        Ok(())
    }

    /// Sending 1000 datagrams to a client that never reads must not make the sender fail.
    #[tokio::test]
    #[traced_test]
    async fn test_relay_clients_full() -> Result<()> {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let server = spawn_local_relay().await.unwrap();
        let relay_url = format!("http://{}", server.http_addr().unwrap());
        let relay_url: RelayUrl = relay_url.parse().unwrap();
        let a_secret_key = SecretKey::generate(&mut rng);
        let resolver = dns_resolver();
        let mut client_a = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver.clone())
            .connect()
            .await?;
        let b_secret_key = SecretKey::generate(&mut rng);
        let b_key = b_secret_key.public();
        let _client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver.clone())
            .connect()
            .await?;
        let msg = Datagrams::from("hello, b");
        for _i in 0..1000 {
            client_a
                .send(ClientToRelayMsg::Datagrams {
                    dst_endpoint_id: b_key,
                    datagrams: msg.clone(),
                })
                .await?;
        }
        Ok(())
    }
}