use std::{collections::HashSet, sync::Arc, time::Duration};
use iroh_base::EndpointId;
use n0_error::{e, stack_error};
use n0_future::{SinkExt, StreamExt};
use rand::Rng;
use time::{Date, OffsetDateTime};
use tokio::{
sync::mpsc::{self, error::TrySendError},
time::MissedTickBehavior,
};
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
use tracing::{Instrument, debug, trace, warn};
use crate::{
PingTracker,
protos::{
relay::{ClientToRelayMsg, Datagrams, PING_INTERVAL, RelayToClientMsg},
streams::BytesStreamSink,
},
server::{
clients::Clients,
metrics::Metrics,
streams::{RecvError as RelayRecvError, RelayedStream, SendError as RelaySendError},
},
};
/// Datagrams queued for delivery to this client, tagged with the endpoint that sent them.
#[derive(Clone, Debug)]
pub(super) struct Packet {
    /// Endpoint the datagrams originate from; reported to the client as the remote endpoint.
    src: EndpointId,
    /// Payload to relay.
    data: Datagrams,
}
/// Parameters needed to create a [`Client`] for one relayed connection.
#[derive(Debug)]
pub struct Config<S> {
    /// Identity of the remote endpoint on this connection.
    pub endpoint_id: EndpointId,
    /// Framed relay stream of this connection.
    pub stream: RelayedStream<S>,
    /// Maximum time a single frame write may take before it fails with a timeout.
    pub write_timeout: Duration,
    /// Capacity of each of the two outgoing queues (datagrams and control messages).
    pub channel_capacity: usize,
}
/// Handle to the actor task serving one client connection.
///
/// Dropping the handle aborts the actor task.
#[derive(Debug)]
pub struct Client {
    /// Remote endpoint served by this connection.
    endpoint_id: EndpointId,
    /// Identifier of this connection, as passed to `Client::new`.
    connection_id: u64,
    /// Cancelled to ask the actor for a graceful shutdown.
    done: CancellationToken,
    /// The spawned actor task; aborted when dropped.
    handle: AbortOnDropHandle<()>,
    /// Sending half of the datagram queue drained by the actor.
    packet_queue: mpsc::Sender<Packet>,
    /// Sending half of the control-message queue drained by the actor.
    message_queue: mpsc::Sender<RelayToClientMsg>,
}
impl Client {
    /// Creates a client for one relayed connection and spawns its actor task.
    ///
    /// The actor owns `stream` and the receiving halves of two bounded queues, each with
    /// `channel_capacity` slots: one for forwarded datagrams and one for control messages.
    /// The task runs inside a `client-connection-actor` tracing span and is aborted when
    /// the returned [`Client`] is dropped.
    pub(super) fn new<S>(
        config: Config<S>,
        connection_id: u64,
        clients: &Clients,
        metrics: Arc<Metrics>,
    ) -> Client
    where
        S: BytesStreamSink + Send + 'static,
    {
        let Config {
            endpoint_id,
            stream,
            write_timeout,
            channel_capacity,
        } = config;

        let (packet_tx, packet_rx) = mpsc::channel(channel_capacity);
        let (message_tx, message_rx) = mpsc::channel(channel_capacity);
        let done = CancellationToken::new();

        let actor = Actor {
            stream,
            timeout: write_timeout,
            packet_send_queue: packet_rx,
            message_send_queue: message_rx,
            endpoint_id,
            connection_id,
            clients: clients.clone(),
            client_counter: ClientCounter::default(),
            ping_tracker: PingTracker::default(),
            metrics,
        };

        let span = tracing::info_span!(
            "client-connection-actor",
            remote_endpoint = %endpoint_id.fmt_short(),
            connection_id = connection_id
        );
        let actor_task = actor.run(done.clone()).instrument(span);
        let handle = AbortOnDropHandle::new(tokio::task::spawn(actor_task));

        Client {
            endpoint_id,
            connection_id,
            handle,
            done,
            packet_queue: packet_tx,
            message_queue: message_tx,
        }
    }

    /// Identifier of this connection.
    pub(super) fn connection_id(&self) -> u64 {
        self.connection_id
    }

    /// Requests a graceful stop and waits for the actor task to finish.
    ///
    /// A join error of the task is logged, not returned.
    pub(super) async fn shutdown(self) {
        self.start_shutdown();
        match self.handle.await {
            Ok(()) => {}
            Err(e) => {
                warn!(
                    remote_endpoint = %self.endpoint_id.fmt_short(),
                    "error closing actor loop: {e:#?}",
                );
            }
        }
    }

    /// Signals the actor to stop without waiting for it to exit.
    pub(super) fn start_shutdown(&self) {
        self.done.cancel();
    }

    /// Queues datagrams from `src` for this client without waiting.
    ///
    /// Fails if the queue is full or the actor no longer receives.
    pub(super) fn try_send_packet(
        &self,
        src: EndpointId,
        data: Datagrams,
    ) -> Result<(), TrySendError<Packet>> {
        let packet = Packet { src, data };
        self.packet_queue.try_send(packet)
    }

    /// Queues an [`RelayToClientMsg::EndpointGone`] notice for `key` without waiting.
    pub(super) fn try_send_peer_gone(
        &self,
        key: EndpointId,
    ) -> Result<(), TrySendError<RelayToClientMsg>> {
        let notice = RelayToClientMsg::EndpointGone(key);
        self.message_queue.try_send(notice)
    }

    /// Queues a [`RelayToClientMsg::Health`] message describing `problem` without waiting.
    pub(super) fn try_send_health(
        &self,
        problem: String,
    ) -> Result<(), TrySendError<RelayToClientMsg>> {
        let notice = RelayToClientMsg::Health { problem };
        self.message_queue.try_send(notice)
    }
}
/// Errors raised while processing one item read from the client stream.
#[stack_error(derive, add_meta, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum HandleFrameError {
// NOTE(review): `Actor::handle_frame` in this file only logs forwarding failures, so this
// variant is not produced there — confirm whether it is still needed.
#[error(transparent)]
ForwardPacket { source: ForwardPacketError },
/// The client stream ended (yielded `None`).
#[error("Stream terminated")]
StreamTerminated {},
/// Reading/decoding a frame from the client stream failed.
#[error(transparent)]
Recv { source: RelayRecvError },
/// Writing a reply frame (e.g. a pong) back to the client failed.
#[error(transparent)]
Send { source: WriteFrameError },
}
/// Errors raised while writing a single frame to the client.
#[stack_error(derive, add_meta, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum WriteFrameError {
/// The underlying stream rejected the frame.
#[error(transparent)]
Stream { source: RelaySendError },
/// The write did not complete within the actor's configured write timeout.
#[error(transparent)]
Timeout {
#[error(std_err)]
source: tokio::time::error::Elapsed,
},
}
/// Reasons the client actor's event loop terminated with an error.
#[stack_error(derive, add_meta)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum RunError {
/// Convertible from a [`ForwardPacketError`]; not constructed directly in this file.
#[error(transparent)]
ForwardPacket {
#[error(from)]
source: ForwardPacketError,
},
/// The final flush after cancellation failed.
#[error("Flush")]
Flush {},
/// Processing an incoming frame failed (including the stream ending).
#[error(transparent)]
HandleFrame {
#[error(from)]
source: HandleFrameError,
},
/// Writing a forwarded datagram packet to the client failed.
#[error("Failed to send packet")]
PacketSend { source: WriteFrameError },
/// An outgoing queue was closed: its `recv` returned `None` because all senders are gone.
#[error("Handle was dropped")]
HandleDropped {},
/// Writing a control message or keep-alive ping failed.
#[error("Writing a frame failed")]
WriteFrame { source: WriteFrameError },
/// The flush at the end of a loop iteration failed.
#[error("Tick flush")]
TickFlush {},
}
/// State owned by the task that serves a single client connection.
#[derive(Debug)]
struct Actor<S> {
    /// Framed connection to the client.
    stream: RelayedStream<S>,
    /// Upper bound on the duration of a single frame write.
    timeout: Duration,
    /// Datagrams from other endpoints waiting to be written to this client.
    packet_send_queue: mpsc::Receiver<Packet>,
    /// Control messages (endpoint-gone, health) waiting to be written to this client.
    message_send_queue: mpsc::Receiver<RelayToClientMsg>,
    /// Remote endpoint of this connection.
    endpoint_id: EndpointId,
    /// Identifier of this connection, used when unregistering.
    connection_id: u64,
    /// Registry of connected clients, used to forward packets and to unregister on exit.
    clients: Clients,
    /// Tracks client keys seen on the current UTC day.
    client_counter: ClientCounter,
    /// Issues keep-alive pings and detects missing pongs.
    ping_tracker: PingTracker,
    /// Relay server metrics.
    metrics: Arc<Metrics>,
}
impl<S> Actor<S>
where
S: BytesStreamSink,
{
async fn run(mut self, done: CancellationToken) {
self.metrics.accepts.inc();
if self.client_counter.update(self.endpoint_id) {
self.metrics.unique_client_keys.inc();
}
match self.run_inner(done).await {
Err(e) => {
warn!("actor errored {e:#}, exiting");
}
Ok(()) => {
debug!("actor finished, exiting");
}
}
self.clients
.unregister(self.connection_id, self.endpoint_id);
self.metrics.disconnects.inc();
}
async fn run_inner(&mut self, done: CancellationToken) -> Result<(), RunError> {
let next_interval = || {
let random_secs = rand::rng().random_range(1..=5);
Duration::from_secs(random_secs) + PING_INTERVAL
};
let mut ping_interval = tokio::time::interval(next_interval());
ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
ping_interval.tick().await;
loop {
tokio::select! {
biased;
_ = done.cancelled() => {
trace!("actor loop cancelled, exiting");
self.stream.flush().await.map_err(|_| e!(RunError::Flush))?;
break;
}
maybe_frame = self.stream.next() => {
self
.handle_frame(maybe_frame)
.await?;
ping_interval.reset();
}
packet = self.packet_send_queue.recv() => {
let packet = packet.ok_or_else(|| e!(RunError::HandleDropped))?;
self.send_packet(packet)
.await
.map_err(|err| e!(RunError::PacketSend, err))?;
}
message = self.message_send_queue.recv() => {
let message = message .ok_or_else(|| e!(RunError::HandleDropped))?;
trace!("send {message:?}");
self.write_frame(message)
.await
.map_err(|err| e!(RunError::WriteFrame, err))?;
}
_ = self.ping_tracker.timeout() => {
trace!("pong timed out");
break;
}
_ = ping_interval.tick() => {
trace!("keep alive ping");
ping_interval.reset_after(next_interval());
let data = self.ping_tracker.new_ping();
self.write_frame(RelayToClientMsg::Ping(data))
.await
.map_err(|err| e!(RunError::WriteFrame, err))?;
}
}
self.stream
.flush()
.await
.map_err(|_| e!(RunError::TickFlush))?;
}
Ok(())
}
async fn write_frame(&mut self, frame: RelayToClientMsg) -> Result<(), WriteFrameError> {
tokio::time::timeout(self.timeout, self.stream.send(frame)).await??;
Ok(())
}
async fn send_raw(&mut self, packet: Packet) -> Result<(), WriteFrameError> {
let remote_endpoint_id = packet.src;
let datagrams = packet.data;
if let Ok(len) = datagrams.contents.len().try_into() {
self.metrics.bytes_sent.inc_by(len);
}
self.write_frame(RelayToClientMsg::Datagrams {
remote_endpoint_id,
datagrams,
})
.await
}
async fn send_packet(&mut self, packet: Packet) -> Result<(), WriteFrameError> {
trace!("send packet");
match self.send_raw(packet).await {
Ok(()) => {
self.metrics.send_packets_sent.inc();
Ok(())
}
Err(err) => {
self.metrics.send_packets_dropped.inc();
Err(err)
}
}
}
async fn handle_frame(
&mut self,
maybe_frame: Option<Result<ClientToRelayMsg, RelayRecvError>>,
) -> Result<(), HandleFrameError> {
trace!(?maybe_frame, "handle incoming frame");
let frame = match maybe_frame {
Some(frame) => frame?,
None => return Err(e!(HandleFrameError::StreamTerminated)),
};
match frame {
ClientToRelayMsg::Datagrams {
dst_endpoint_id: dst_key,
datagrams,
} => {
let packet_len = datagrams.contents.len();
if let Err(err @ ForwardPacketError { .. }) =
self.handle_frame_send_packet(dst_key, datagrams)
{
warn!("failed to handle send packet frame: {err:#}");
}
self.metrics.bytes_recv.inc_by(packet_len as u64);
}
ClientToRelayMsg::Ping(data) => {
self.metrics.got_ping.inc();
self.write_frame(RelayToClientMsg::Pong(data)).await?;
self.metrics.sent_pong.inc();
}
ClientToRelayMsg::Pong(data) => {
self.ping_tracker.pong_received(data);
}
}
Ok(())
}
fn handle_frame_send_packet(
&self,
dst: EndpointId,
data: Datagrams,
) -> Result<(), ForwardPacketError> {
self.metrics.send_packets_recv.inc();
self.clients
.send_packet(dst, data, self.endpoint_id, &self.metrics)?;
Ok(())
}
}
/// Why forwarding a packet to another client failed.
///
/// The variants presumably mirror the two `try_send` failure modes of a bounded channel
/// (`Full` vs. `Closed`) — confirm against the forwarding code in `clients`.
///
/// Derives `Copy`/`PartialEq` so reasons can be cheaply passed around and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SendError {
    /// Presumably: the destination's queue had no free capacity.
    Full,
    /// Presumably: the destination's queue was closed.
    Closed,
}
/// Failure to forward a client's datagrams to their destination endpoint.
#[stack_error(derive, add_meta)]
#[error("failed to forward packet: {reason:?}")]
pub struct ForwardPacketError {
/// Why the packet could not be forwarded.
reason: SendError,
}
/// Set of client keys seen since the last (UTC) day change.
#[derive(Debug)]
struct ClientCounter {
    /// Keys recorded since `last_clear_date`.
    clients: HashSet<EndpointId>,
    /// UTC date on which `clients` was last reset.
    last_clear_date: Date,
}
impl Default for ClientCounter {
    /// Starts with no recorded keys and today's UTC date as the last reset.
    fn default() -> Self {
        let last_clear_date = OffsetDateTime::now_utc().date();
        Self {
            clients: HashSet::default(),
            last_clear_date,
        }
    }
}
impl ClientCounter {
    /// Forgets all recorded keys if the UTC date changed since the last reset.
    fn check_and_clear(&mut self) {
        let today = OffsetDateTime::now_utc().date();
        if self.last_clear_date == today {
            return;
        }
        self.clients.clear();
        self.last_clear_date = today;
    }

    /// Records `client` for the current UTC day.
    ///
    /// Returns `true` if `client` was not yet recorded since the last daily reset.
    fn update(&mut self, client: EndpointId) -> bool {
        self.check_and_clear();
        self.clients.insert(client)
    }
}
// Unit tests for the client actor and for rate limiting of the relayed stream.
#[cfg(test)]
mod tests {
use iroh_base::SecretKey;
use n0_error::{Result, StdResultExt, bail_any};
use n0_future::Stream;
use n0_tracing_test::traced_test;
use rand::SeedableRng;
use tracing::info;
use super::*;
use crate::{client::conn::Conn, protos::common::FrameType};
// Reads the next frame from `stream` and fails unless it is of type `frame_type`.
// Also fails on a stream error or end of stream.
async fn recv_frame<
E: std::error::Error + Sync + Send + 'static,
S: Stream<Item = Result<RelayToClientMsg, E>> + Unpin,
>(
frame_type: FrameType,
mut stream: S,
) -> Result<RelayToClientMsg> {
match stream.next().await {
Some(Ok(frame)) => {
if frame_type != frame.typ() {
bail_any!(
"Unexpected frame, got {:?}, but expected {:?}",
frame.typ(),
frame_type
);
}
Ok(frame)
}
Some(Err(err)) => Err(err).anyerr(),
None => bail_any!("Unexpected EOF, expected frame {frame_type:?}"),
}
}
// Drives an `Actor` directly over an in-memory duplex pipe. `io_rw` plays the client side.
// The test checks, in order:
// - a queued packet is written as a Datagrams frame
// - a queued EndpointGone message is written through
// - a client Ping is answered with a Pong
// - a client datagram to an unknown target is accepted
// Finally it cancels the actor and joins it.
#[tokio::test]
#[traced_test]
async fn test_client_actor_basic() -> Result {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let (send_queue_s, send_queue_r) = mpsc::channel(10);
let (message_s, message_r) = mpsc::channel(10);
let endpoint_id = SecretKey::generate(&mut rng).public();
let (io, io_rw) = tokio::io::duplex(1024);
let mut io_rw = Conn::test(io_rw);
let stream = RelayedStream::test(io);
let clients = Clients::default();
let metrics = Arc::new(Metrics::default());
let actor = Actor {
stream,
timeout: Duration::from_secs(1),
packet_send_queue: send_queue_r,
message_send_queue: message_r,
connection_id: 0,
endpoint_id,
clients: clients.clone(),
client_counter: ClientCounter::default(),
ping_tracker: PingTracker::default(),
metrics,
};
let done = CancellationToken::new();
let io_done = done.clone();
let handle = tokio::task::spawn(async move { actor.run(io_done).await });
// Relay -> client direction: feed the actor's queues and read what it writes.
println!("-- write");
let data = b"hello world!";
println!(" send packet");
let packet = Packet {
src: endpoint_id,
data: Datagrams::from(&data[..]),
};
send_queue_s
.send(packet.clone())
.await
.std_context("send")?;
let frame = recv_frame(FrameType::RelayToClientDatagram, &mut io_rw)
.await
.anyerr()?;
assert_eq!(
frame,
RelayToClientMsg::Datagrams {
remote_endpoint_id: endpoint_id,
datagrams: data.to_vec().into()
}
);
println!("send peer gone");
message_s
.send(RelayToClientMsg::EndpointGone(endpoint_id))
.await
.std_context("send")?;
let frame = recv_frame(FrameType::EndpointGone, &mut io_rw)
.await
.anyerr()?;
assert_eq!(frame, RelayToClientMsg::EndpointGone(endpoint_id));
// Client -> relay direction: write frames as the client and observe the actor's replies.
println!("--read");
let data = b"pingpong";
io_rw.send(ClientToRelayMsg::Ping(*data)).await?;
println!(" recv pong");
let frame = recv_frame(FrameType::Pong, &mut io_rw).await?;
assert_eq!(frame, RelayToClientMsg::Pong(*data));
// `target` is not registered in `clients`. The test asserts no outcome for this send;
// forwarding failures are only logged by the actor.
let target = SecretKey::generate(&mut rng).public();
println!(" send packet");
let data = b"hello world!";
io_rw
.send(ClientToRelayMsg::Datagrams {
dst_endpoint_id: target,
datagrams: Datagrams::from(data),
})
.await
.std_context("send")?;
done.cancel();
handle.await.std_context("join")?;
Ok(())
}
// Checks that a rate-limited `RelayedStream` delays a second frame until time advances.
// Time is paused (`start_paused = true`), so `sleep`/`timeout` advance virtual time.
#[tokio::test(start_paused = true)]
#[traced_test]
async fn test_rate_limit() -> Result {
const LIMIT: u32 = 50;
const MAX_FRAMES: u32 = 100;
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _);
let mut frame_writer = Conn::test(io_write);
// NOTE(review): the meaning of the two numeric arguments (presumably rate/burst in bytes)
// is defined by `RelayedStream::test_limited` — confirm there.
let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT)?;
let data = Datagrams::from(b"hello world!!!!!");
let target = SecretKey::generate(&mut rng).public();
let frame = ClientToRelayMsg::Datagrams {
dst_endpoint_id: target,
datagrams: data.clone(),
};
let frame_len = frame.to_bytes().len();
// The payload is chosen so that one encoded frame is exactly LIMIT bytes.
assert_eq!(frame_len, LIMIT as usize);
info!("-- send packet");
frame_writer.send(frame.clone()).await.std_context("send")?;
frame_writer.flush().await.std_context("flush")?;
// The first frame fits within the limit and is received promptly.
let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
.await
.expect("timeout")
.expect("option")
.expect("ok");
assert_eq!(recv_frame, frame);
info!("-- send packet");
frame_writer.send(frame.clone()).await.std_context("send")?;
frame_writer.flush().await.std_context("flush")?;
// The second frame is held back by the limiter.
let res = tokio::time::timeout(Duration::from_millis(100), stream.next()).await;
assert!(res.is_err(), "expecting a timeout");
info!("-- timeout happened");
info!("-- sleep");
tokio::time::sleep(Duration::from_secs(1)).await;
// After (virtual) time has advanced, the held-back frame is delivered.
let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
.await
.expect("timeout")
.expect("option")
.expect("ok");
assert_eq!(recv_frame, frame);
Ok(())
}
}