use std::{
collections::{hash_map::Entry, BTreeSet, HashMap, HashSet, VecDeque},
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use bytes::Bytes;
use futures_concurrency::stream::{stream_group, StreamGroup};
use futures_util::FutureExt as _;
use iroh::{
endpoint::Connection,
protocol::{AcceptError, ProtocolHandler},
Endpoint, EndpointAddr, EndpointId, PublicKey, RelayUrl, Watcher,
};
use irpc::WithChannels;
use n0_error::{e, stack_error};
use n0_future::{
task::{self, AbortOnDropHandle, JoinSet},
time::Instant,
Stream, StreamExt as _,
};
use rand::{rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, error_span, trace, warn, Instrument};
use self::{
address_lookup::GossipAddressLookup,
util::{RecvLoop, SendLoop, Timers},
};
use crate::{
api::{self, Command, Event, GossipApi, RpcMessage},
metrics::Metrics,
proto::{self, HyparviewConfig, PeerData, PlumtreeConfig, Scope, TopicId},
};
mod address_lookup;
mod util;
pub const GOSSIP_ALPN: &[u8] = b"/iroh-gossip/1";
const SEND_QUEUE_CAP: usize = 64;
const TO_ACTOR_CAP: usize = 64;
const IN_EVENT_CAP: usize = 1024;
const TOPIC_EVENT_CAP: usize = 256;
pub type ProtoEvent = proto::Event<PublicKey>;
pub type ProtoCommand = proto::Command<PublicKey>;
type InEvent = proto::InEvent<PublicKey>;
type OutEvent = proto::OutEvent<PublicKey>;
type Timer = proto::Timer<PublicKey>;
type ProtoMessage = proto::Message<PublicKey>;
#[derive(Debug, Clone)]
pub struct Gossip {
pub(crate) inner: Arc<Inner>,
}
impl std::ops::Deref for Gossip {
type Target = GossipApi;
fn deref(&self) -> &Self::Target {
&self.inner.api
}
}
#[derive(Debug)]
enum LocalActorMessage {
HandleConnection(Connection),
Shutdown { reply: oneshot::Sender<()> },
}
#[allow(missing_docs)]
#[stack_error(derive, add_meta)]
#[non_exhaustive]
pub enum Error {
ActorDropped {},
}
impl<T> From<mpsc::error::SendError<T>> for Error {
fn from(_value: mpsc::error::SendError<T>) -> Self {
e!(Error::ActorDropped)
}
}
impl From<oneshot::error::RecvError> for Error {
fn from(_value: oneshot::error::RecvError) -> Self {
e!(Error::ActorDropped)
}
}
#[derive(Debug)]
pub(crate) struct Inner {
api: GossipApi,
local_tx: mpsc::Sender<LocalActorMessage>,
_actor_handle: AbortOnDropHandle<()>,
max_message_size: usize,
metrics: Arc<Metrics>,
}
impl ProtocolHandler for Gossip {
async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
self.handle_connection(connection)
.await
.map_err(AcceptError::from_err)?;
Ok(())
}
async fn shutdown(&self) {
if let Err(err) = self.shutdown().await {
warn!("error while shutting down gossip: {err:#}");
}
}
}
#[derive(Debug, Clone)]
pub struct Builder {
config: proto::Config,
alpn: Option<Bytes>,
}
impl Builder {
pub fn max_message_size(mut self, size: usize) -> Self {
self.config.max_message_size = size;
self
}
pub fn membership_config(mut self, config: HyparviewConfig) -> Self {
self.config.membership = config;
self
}
pub fn broadcast_config(mut self, config: PlumtreeConfig) -> Self {
self.config.broadcast = config;
self
}
pub fn alpn(mut self, alpn: impl AsRef<[u8]>) -> Self {
self.alpn = Some(alpn.as_ref().to_vec().into());
self
}
pub fn spawn(self, endpoint: Endpoint) -> Gossip {
let metrics = Arc::new(Metrics::default());
let address_lookup = GossipAddressLookup::default();
endpoint.address_lookup().expect("endpoint closed").add(address_lookup.clone());
let (actor, rpc_tx, local_tx) = Actor::new(
endpoint,
self.config,
metrics.clone(),
self.alpn,
address_lookup,
);
let me = actor.endpoint.id().fmt_short();
let max_message_size = actor.state.max_message_size();
let actor_handle = task::spawn(actor.run().instrument(error_span!("gossip", %me)));
let api = GossipApi::local(rpc_tx);
Gossip {
inner: Inner {
api,
local_tx,
_actor_handle: AbortOnDropHandle::new(actor_handle),
max_message_size,
metrics,
}
.into(),
}
}
}
impl Gossip {
pub fn builder() -> Builder {
Builder {
config: Default::default(),
alpn: None,
}
}
#[cfg(feature = "rpc")]
pub async fn listen(self, endpoint: quinn::Endpoint) {
self.inner.api.listen(endpoint).await
}
pub fn max_message_size(&self) -> usize {
self.inner.max_message_size
}
pub async fn handle_connection(&self, conn: Connection) -> Result<(), Error> {
self.inner
.local_tx
.send(LocalActorMessage::HandleConnection(conn))
.await?;
Ok(())
}
pub async fn shutdown(&self) -> Result<(), Error> {
let (reply, reply_rx) = oneshot::channel();
self.inner
.local_tx
.send(LocalActorMessage::Shutdown { reply })
.await?;
reply_rx.await?;
Ok(())
}
pub fn metrics(&self) -> &Arc<Metrics> {
&self.inner.metrics
}
}
struct Actor {
alpn: Bytes,
state: proto::State<PublicKey, StdRng>,
endpoint: Endpoint,
dialer: Dialer,
rpc_rx: mpsc::Receiver<RpcMessage>,
local_rx: mpsc::Receiver<LocalActorMessage>,
in_event_tx: mpsc::Sender<InEvent>,
in_event_rx: mpsc::Receiver<InEvent>,
timers: Timers<Timer>,
topics: HashMap<TopicId, TopicState>,
peers: HashMap<EndpointId, PeerState>,
command_rx: stream_group::Keyed<TopicCommandStream>,
quit_queue: VecDeque<TopicId>,
connection_tasks: JoinSet<(EndpointId, Connection, Result<(), ConnectionLoopError>)>,
metrics: Arc<Metrics>,
topic_event_forwarders: JoinSet<TopicId>,
address_lookup: GossipAddressLookup,
}
impl Actor {
fn new(
endpoint: Endpoint,
config: proto::Config,
metrics: Arc<Metrics>,
alpn: Option<Bytes>,
address_lookup: GossipAddressLookup,
) -> (
Self,
mpsc::Sender<RpcMessage>,
mpsc::Sender<LocalActorMessage>,
) {
let peer_id = endpoint.id();
let dialer = Dialer::new(endpoint.clone());
let state = proto::State::new(
peer_id,
Default::default(),
config,
rand::rngs::StdRng::from_rng(&mut rand::rng()),
);
let (rpc_tx, rpc_rx) = mpsc::channel(TO_ACTOR_CAP);
let (local_tx, local_rx) = mpsc::channel(16);
let (in_event_tx, in_event_rx) = mpsc::channel(IN_EVENT_CAP);
let actor = Actor {
alpn: alpn.unwrap_or_else(|| GOSSIP_ALPN.to_vec().into()),
endpoint,
state,
dialer,
rpc_rx,
in_event_rx,
in_event_tx,
timers: Timers::new(),
command_rx: StreamGroup::new().keyed(),
peers: Default::default(),
topics: Default::default(),
quit_queue: Default::default(),
connection_tasks: Default::default(),
metrics,
local_rx,
topic_event_forwarders: Default::default(),
address_lookup,
};
(actor, rpc_tx, local_tx)
}
pub async fn run(mut self) {
let mut addr_update_stream = self.setup().await;
let mut i = 0;
while self.event_loop(&mut addr_update_stream, i).await {
i += 1;
}
}
async fn setup(&mut self) -> impl Stream<Item = EndpointAddr> + Send + Unpin + use<> {
let addr_update_stream = self.endpoint.watch_addr().stream();
let initial_addr = self.endpoint.addr();
self.handle_addr_update(initial_addr).await;
addr_update_stream
}
async fn event_loop(
&mut self,
addr_updates: &mut (impl Stream<Item = EndpointAddr> + Send + Unpin),
i: usize,
) -> bool {
self.metrics.actor_tick_main.inc();
tokio::select! {
biased;
conn = self.local_rx.recv() => {
match conn {
Some(LocalActorMessage::Shutdown { reply }) => {
debug!("received shutdown message, quit all topics");
self.quit_queue.extend(self.topics.keys().copied());
self.process_quit_queue().await;
debug!("all topics quit, stop gossip actor");
reply.send(()).ok();
return false;
},
Some(LocalActorMessage::HandleConnection(conn)) => {
self.handle_connection(conn.remote_id(), ConnOrigin::Accept, conn);
}
None => {
debug!("all gossip handles dropped, stop gossip actor");
return false;
}
}
}
msg = self.rpc_rx.recv() => {
trace!(?i, "tick: to_actor_rx");
self.metrics.actor_tick_rx.inc();
match msg {
Some(msg) => {
self.handle_rpc_msg(msg, Instant::now()).await;
}
None => {
debug!("all gossip handles dropped, stop gossip actor");
return false;
}
}
},
Some((key, (topic, command))) = self.command_rx.next(), if !self.command_rx.is_empty() => {
trace!(?i, "tick: command_rx");
self.handle_command(topic, key, command).await;
},
Some(new_address) = addr_updates.next() => {
trace!(?i, "tick: new_address");
self.metrics.actor_tick_endpoint.inc();
self.handle_addr_update(new_address).await;
}
(peer_id, res) = self.dialer.next_conn() => {
trace!(?i, "tick: dialer");
self.metrics.actor_tick_dialer.inc();
match res {
Some(Ok(conn)) => {
debug!(peer = %peer_id.fmt_short(), "dial successful");
self.metrics.actor_tick_dialer_success.inc();
self.handle_connection(peer_id, ConnOrigin::Dial, conn);
}
Some(Err(err)) => {
warn!(peer = %peer_id.fmt_short(), "dial failed: {err}");
self.metrics.actor_tick_dialer_failure.inc();
let peer_state = self.peers.get(&peer_id);
let is_active = matches!(peer_state, Some(PeerState::Active { .. }));
if !is_active {
self.handle_in_event(InEvent::PeerDisconnected(peer_id), Instant::now())
.await;
}
}
None => {
warn!(peer = %peer_id.fmt_short(), "dial disconnected");
self.metrics.actor_tick_dialer_failure.inc();
}
}
}
event = self.in_event_rx.recv() => {
trace!(?i, "tick: in_event_rx");
self.metrics.actor_tick_in_event_rx.inc();
let event = event.expect("unreachable: in_event_tx is never dropped before receiver");
self.handle_in_event(event, Instant::now()).await;
}
_ = self.timers.wait_next() => {
trace!(?i, "tick: timers");
self.metrics.actor_tick_timers.inc();
let now = Instant::now();
while let Some((_instant, timer)) = self.timers.pop_before(now) {
self.handle_in_event(InEvent::TimerExpired(timer), now).await;
}
}
Some(res) = self.connection_tasks.join_next(), if !self.connection_tasks.is_empty() => {
trace!(?i, "tick: connection_tasks");
let (peer_id, conn, result) = res.expect("connection task panicked");
self.handle_connection_task_finished(peer_id, conn, result).await;
}
Some(res) = self.topic_event_forwarders.join_next(), if !self.topic_event_forwarders.is_empty() => {
let topic_id = res.expect("topic event forwarder panicked");
if let Some(state) = self.topics.get_mut(&topic_id) {
if !state.still_needed() {
self.quit_queue.push_back(topic_id);
self.process_quit_queue().await;
}
}
}
}
true
}
async fn handle_addr_update(&mut self, endpoint_addr: EndpointAddr) {
let peer_data = encode_peer_data(&endpoint_addr.into());
self.handle_in_event(InEvent::UpdatePeerData(peer_data), Instant::now())
.await
}
async fn handle_command(
&mut self,
topic: TopicId,
key: stream_group::Key,
command: Option<Command>,
) {
debug!(?topic, ?key, ?command, "handle command");
let Some(state) = self.topics.get_mut(&topic) else {
warn!("received command for unknown topic");
return;
};
match command {
Some(command) => {
let command = match command {
Command::Broadcast(message) => ProtoCommand::Broadcast(message, Scope::Swarm),
Command::BroadcastNeighbors(message) => {
ProtoCommand::Broadcast(message, Scope::Neighbors)
}
Command::JoinPeers(peers) => ProtoCommand::Join(peers),
};
self.handle_in_event(proto::InEvent::Command(topic, command), Instant::now())
.await;
}
None => {
state.command_rx_keys.remove(&key);
if !state.still_needed() {
self.quit_queue.push_back(topic);
self.process_quit_queue().await;
}
}
}
}
fn handle_connection(&mut self, peer_id: EndpointId, origin: ConnOrigin, conn: Connection) {
let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP);
let conn_id = conn.stable_id();
let queue = match self.peers.entry(peer_id) {
Entry::Occupied(mut entry) => entry.get_mut().accept_conn(send_tx, conn_id),
Entry::Vacant(entry) => {
entry.insert(PeerState::Active {
active_send_tx: send_tx,
active_conn_id: conn_id,
other_conns: Vec::new(),
});
Vec::new()
}
};
let max_message_size = self.state.max_message_size();
let in_event_tx = self.in_event_tx.clone();
self.connection_tasks.spawn(
async move {
let res = connection_loop(
peer_id,
conn.clone(),
origin,
send_rx,
in_event_tx,
max_message_size,
queue,
)
.await;
(peer_id, conn, res)
}
.instrument(error_span!("conn", peer = %peer_id.fmt_short())),
);
}
#[tracing::instrument(name = "conn", skip_all, fields(peer = %peer_id.fmt_short()))]
async fn handle_connection_task_finished(
&mut self,
peer_id: EndpointId,
conn: Connection,
task_result: Result<(), ConnectionLoopError>,
) {
if conn.close_reason().is_none() {
conn.close(0u32.into(), b"close from disconnect");
}
let reason = conn.close_reason().expect("just closed");
let error = task_result.err();
debug!(%reason, ?error, "connection closed");
if let Some(PeerState::Active {
active_conn_id,
other_conns,
..
}) = self.peers.get_mut(&peer_id)
{
if conn.stable_id() == *active_conn_id {
debug!("active send connection closed, mark peer as disconnected");
self.handle_in_event(InEvent::PeerDisconnected(peer_id), Instant::now())
.await;
} else {
other_conns.retain(|x| *x != conn.stable_id());
debug!("remaining {} other connections", other_conns.len() + 1);
}
} else {
debug!("peer already marked as disconnected");
}
}
async fn handle_rpc_msg(&mut self, msg: RpcMessage, now: Instant) {
trace!("handle to_actor {msg:?}");
match msg {
RpcMessage::Join(msg) => {
let WithChannels {
inner,
rx,
tx,
span: _,
} = msg;
let api::JoinRequest {
topic_id,
bootstrap,
} = inner;
let TopicState {
neighbors,
event_sender,
command_rx_keys,
} = self.topics.entry(topic_id).or_default();
let mut sender_dead = false;
if !neighbors.is_empty() {
for neighbor in neighbors.iter() {
if let Err(_err) = tx.try_send(Event::NeighborUp(*neighbor)).await {
sender_dead = true;
break;
}
}
}
if !sender_dead {
let fut =
topic_subscriber_loop(tx, event_sender.subscribe()).map(move |_| topic_id);
self.topic_event_forwarders
.spawn(fut.instrument(tracing::Span::current()));
}
let command_rx = TopicCommandStream::new(topic_id, Box::pin(rx.into_stream()));
let key = self.command_rx.insert(command_rx);
command_rx_keys.insert(key);
self.handle_in_event(
InEvent::Command(
topic_id,
ProtoCommand::Join(bootstrap.into_iter().collect()),
),
now,
)
.await;
}
}
}
async fn handle_in_event(&mut self, event: InEvent, now: Instant) {
self.handle_in_event_inner(event, now).await;
self.process_quit_queue().await;
}
async fn process_quit_queue(&mut self) {
while let Some(topic_id) = self.quit_queue.pop_front() {
self.handle_in_event_inner(
InEvent::Command(topic_id, ProtoCommand::Quit),
Instant::now(),
)
.await;
if self.topics.remove(&topic_id).is_some() {
tracing::debug!(%topic_id, "publishers and subscribers gone; unsubscribing");
}
}
}
async fn handle_in_event_inner(&mut self, event: InEvent, now: Instant) {
if matches!(event, InEvent::TimerExpired(_)) {
trace!(?event, "handle in_event");
} else {
debug!(?event, "handle in_event");
};
let out = self.state.handle(event, now, Some(&self.metrics));
for event in out {
if matches!(event, OutEvent::ScheduleTimer(_, _)) {
trace!(?event, "handle out_event");
} else {
debug!(?event, "handle out_event");
};
match event {
OutEvent::SendMessage(peer_id, message) => {
let state = self.peers.entry(peer_id).or_default();
match state {
PeerState::Active { active_send_tx, .. } => {
if let Err(_err) = active_send_tx.send(message).await {
warn!(
peer = %peer_id.fmt_short(),
"failed to send: connection task send loop terminated",
);
}
}
PeerState::Pending { queue } => {
if queue.is_empty() {
debug!(peer = %peer_id.fmt_short(), "start to dial");
self.dialer.queue_dial(peer_id, self.alpn.clone());
}
queue.push(message);
}
}
}
OutEvent::EmitEvent(topic_id, event) => {
let Some(state) = self.topics.get_mut(&topic_id) else {
warn!(?topic_id, "gossip state emitted event for unknown topic");
continue;
};
let TopicState {
neighbors,
event_sender,
..
} = state;
match &event {
ProtoEvent::NeighborUp(neighbor) => {
neighbors.insert(*neighbor);
}
ProtoEvent::NeighborDown(neighbor) => {
neighbors.remove(neighbor);
}
_ => {}
}
event_sender.send(event).ok();
if !state.still_needed() {
self.quit_queue.push_back(topic_id);
}
}
OutEvent::ScheduleTimer(delay, timer) => {
self.timers.insert(now + delay, timer);
}
OutEvent::DisconnectPeer(peer_id) => {
debug!(peer=%peer_id.fmt_short(), "gossip state indicates disconnect: drop peer");
self.peers.remove(&peer_id);
}
OutEvent::PeerData(endpoint_id, data) => match decode_peer_data(&data) {
Err(err) => warn!("Failed to decode {data:?} from {endpoint_id}: {err}"),
Ok(info) => {
debug!(peer = ?endpoint_id, "add known addrs: {info:?}");
let mut endpoint_addr = EndpointAddr::new(endpoint_id);
for addr in info.direct_addresses {
endpoint_addr = endpoint_addr.with_ip_addr(addr);
}
if let Some(relay_url) = info.relay_url {
endpoint_addr = endpoint_addr.with_relay_url(relay_url);
}
self.address_lookup.add(endpoint_addr);
}
},
}
}
}
}
type ConnId = usize;
#[derive(Debug)]
enum PeerState {
Pending {
queue: Vec<ProtoMessage>,
},
Active {
active_send_tx: mpsc::Sender<ProtoMessage>,
active_conn_id: ConnId,
other_conns: Vec<ConnId>,
},
}
impl PeerState {
fn accept_conn(
&mut self,
send_tx: mpsc::Sender<ProtoMessage>,
conn_id: ConnId,
) -> Vec<ProtoMessage> {
match self {
PeerState::Pending { queue } => {
let queue = std::mem::take(queue);
*self = PeerState::Active {
active_send_tx: send_tx,
active_conn_id: conn_id,
other_conns: Vec::new(),
};
queue
}
PeerState::Active {
active_send_tx,
active_conn_id,
other_conns,
} => {
other_conns.push(*active_conn_id);
*active_send_tx = send_tx;
*active_conn_id = conn_id;
Vec::new()
}
}
}
}
impl Default for PeerState {
fn default() -> Self {
PeerState::Pending { queue: Vec::new() }
}
}
#[derive(Debug)]
struct TopicState {
neighbors: BTreeSet<EndpointId>,
event_sender: broadcast::Sender<ProtoEvent>,
command_rx_keys: HashSet<stream_group::Key>,
}
impl Default for TopicState {
fn default() -> Self {
let (event_sender, _) = broadcast::channel(TOPIC_EVENT_CAP);
Self {
neighbors: Default::default(),
command_rx_keys: Default::default(),
event_sender,
}
}
}
impl TopicState {
fn still_needed(&self) -> bool {
!self.command_rx_keys.is_empty() || self.event_sender.receiver_count() > 0
}
#[cfg(test)]
fn joined(&self) -> bool {
!self.neighbors.is_empty()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnOrigin {
Accept,
Dial,
}
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources, std_sources)]
#[non_exhaustive]
enum ConnectionLoopError {
#[error(transparent)]
Write {
source: self::util::WriteError,
},
#[error(transparent)]
Read {
source: self::util::ReadError,
},
#[error(transparent)]
Connection {
#[error(std_err)]
source: iroh::endpoint::ConnectionError,
},
ActorDropped {},
}
impl<T> From<mpsc::error::SendError<T>> for ConnectionLoopError {
fn from(_value: mpsc::error::SendError<T>) -> Self {
e!(ConnectionLoopError::ActorDropped)
}
}
async fn connection_loop(
from: PublicKey,
conn: Connection,
origin: ConnOrigin,
send_rx: mpsc::Receiver<ProtoMessage>,
in_event_tx: mpsc::Sender<InEvent>,
max_message_size: usize,
queue: Vec<ProtoMessage>,
) -> Result<(), ConnectionLoopError> {
debug!(?origin, "connection established");
let mut send_loop = SendLoop::new(conn.clone(), send_rx, max_message_size);
let mut recv_loop = RecvLoop::new(from, conn, in_event_tx, max_message_size);
let send_fut = send_loop.run(queue).instrument(error_span!("send"));
let recv_fut = recv_loop.run().instrument(error_span!("recv"));
let (send_res, recv_res) = tokio::join!(send_fut, recv_fut);
send_res?;
recv_res?;
Ok(())
}
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
struct AddrInfo {
relay_url: Option<RelayUrl>,
direct_addresses: BTreeSet<SocketAddr>,
}
impl From<EndpointAddr> for AddrInfo {
fn from(endpoint_addr: EndpointAddr) -> Self {
Self {
relay_url: endpoint_addr.relay_urls().next().cloned(),
direct_addresses: endpoint_addr.ip_addrs().cloned().collect(),
}
}
}
fn encode_peer_data(info: &AddrInfo) -> PeerData {
let bytes = postcard::to_stdvec(info).expect("serializing AddrInfo may not fail");
PeerData::new(bytes)
}
fn decode_peer_data(peer_data: &PeerData) -> Result<AddrInfo, postcard::Error> {
let bytes = peer_data.as_bytes();
if bytes.is_empty() {
return Ok(AddrInfo::default());
}
let info = postcard::from_bytes(bytes)?;
Ok(info)
}
async fn topic_subscriber_loop(
sender: irpc::channel::mpsc::Sender<Event>,
mut topic_events: broadcast::Receiver<ProtoEvent>,
) {
loop {
tokio::select! {
biased;
msg = topic_events.recv() => {
let event = match msg {
Err(broadcast::error::RecvError::Closed) => break,
Err(broadcast::error::RecvError::Lagged(_)) => Event::Lagged,
Ok(event) => event.into(),
};
if sender.send(event).await.is_err() {
break;
}
}
_ = sender.closed() => break,
}
}
}
type BoxedCommandReceiver =
n0_future::stream::Boxed<Result<Command, irpc::channel::mpsc::RecvError>>;
#[derive(derive_more::Debug)]
struct TopicCommandStream {
topic_id: TopicId,
#[debug("CommandStream")]
stream: BoxedCommandReceiver,
closed: bool,
}
impl TopicCommandStream {
fn new(topic_id: TopicId, stream: BoxedCommandReceiver) -> Self {
Self {
topic_id,
stream,
closed: false,
}
}
}
impl Stream for TopicCommandStream {
type Item = (TopicId, Option<Command>);
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.closed {
return Poll::Ready(None);
}
match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Ready(Some(Ok(item))) => Poll::Ready(Some((self.topic_id, Some(item)))),
Poll::Ready(None) | Poll::Ready(Some(Err(_))) => {
self.closed = true;
Poll::Ready(Some((self.topic_id, None)))
}
Poll::Pending => Poll::Pending,
}
}
}
#[derive(Debug)]
struct Dialer {
endpoint: Endpoint,
pending: JoinSet<(
EndpointId,
Option<Result<Connection, iroh::endpoint::ConnectError>>,
)>,
pending_dials: HashMap<EndpointId, CancellationToken>,
}
impl Dialer {
fn new(endpoint: Endpoint) -> Self {
Self {
endpoint,
pending: Default::default(),
pending_dials: Default::default(),
}
}
fn queue_dial(&mut self, endpoint_id: EndpointId, alpn: Bytes) {
if self.is_pending(endpoint_id) {
return;
}
let cancel = CancellationToken::new();
self.pending_dials.insert(endpoint_id, cancel.clone());
let endpoint = self.endpoint.clone();
self.pending.spawn(
async move {
let res = tokio::select! {
biased;
_ = cancel.cancelled() => None,
res = endpoint.connect(endpoint_id, &alpn) => Some(res),
};
(endpoint_id, res)
}
.instrument(tracing::Span::current()),
);
}
fn is_pending(&self, endpoint: EndpointId) -> bool {
self.pending_dials.contains_key(&endpoint)
}
async fn next_conn(
&mut self,
) -> (
EndpointId,
Option<Result<Connection, iroh::endpoint::ConnectError>>,
) {
match self.pending_dials.is_empty() {
false => {
let (endpoint_id, res) = loop {
match self.pending.join_next().await {
Some(Ok((endpoint_id, res))) => {
self.pending_dials.remove(&endpoint_id);
break (endpoint_id, res);
}
Some(Err(e)) => {
error!("next conn error: {:?}", e);
}
None => {
error!("no more pending conns available");
std::future::pending().await
}
}
};
(endpoint_id, res)
}
true => std::future::pending().await,
}
}
}
#[cfg(test)]
pub(crate) mod test {
use std::time::Duration;
use bytes::Bytes;
use futures_concurrency::future::TryJoin;
use iroh::{
address_lookup::memory::MemoryLookup, endpoint::BindError, protocol::Router, RelayMap,
RelayMode, SecretKey,
};
use n0_error::{AnyError, Result, StdResultExt};
use n0_tracing_test::traced_test;
use rand::{CryptoRng, Rng};
use tokio::{spawn, time::timeout};
use tokio_util::sync::CancellationToken;
use tracing::{info, instrument};
use super::*;
use crate::api::{ApiError, GossipReceiver, GossipSender};
struct ManualActorLoop {
actor: Actor,
step: usize,
}
impl std::ops::Deref for ManualActorLoop {
type Target = Actor;
fn deref(&self) -> &Self::Target {
&self.actor
}
}
impl std::ops::DerefMut for ManualActorLoop {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.actor
}
}
type EndpointHandle = tokio::task::JoinHandle<Result<()>>;
impl ManualActorLoop {
#[instrument(skip_all, fields(me = %actor.endpoint.id().fmt_short()))]
async fn new(mut actor: Actor) -> Self {
let _ = actor.setup().await;
Self { actor, step: 0 }
}
#[instrument(skip_all, fields(me = %self.endpoint.id().fmt_short()))]
async fn step(&mut self) -> bool {
let ManualActorLoop { actor, step } = self;
*step += 1;
let addr_update_stream = &mut futures_lite::stream::pending();
actor.event_loop(addr_update_stream, *step).await
}
async fn steps(&mut self, n: usize) {
for _ in 0..n {
self.step().await;
}
}
async fn finish(mut self) {
while self.step().await {}
}
}
impl Gossip {
async fn t_new_with_actor(
rng: &mut rand_chacha::ChaCha12Rng,
config: proto::Config,
relay_map: RelayMap,
cancel: &CancellationToken,
) -> Result<(Self, Actor, EndpointHandle), BindError> {
let endpoint = create_endpoint(rng, relay_map, None).await?;
let metrics = Arc::new(Metrics::default());
let address_lookup = GossipAddressLookup::default();
endpoint.address_lookup().expect("endpoint closed").add(address_lookup.clone());
let (actor, to_actor_tx, conn_tx) =
Actor::new(endpoint, config, metrics.clone(), None, address_lookup);
let max_message_size = actor.state.max_message_size();
let _actor_handle =
AbortOnDropHandle::new(task::spawn(futures_lite::future::pending()));
let gossip = Self {
inner: Inner {
api: GossipApi::local(to_actor_tx),
local_tx: conn_tx,
_actor_handle,
max_message_size,
metrics,
}
.into(),
};
let endpoint_task = task::spawn(endpoint_loop(
actor.endpoint.clone(),
gossip.clone(),
cancel.child_token(),
));
Ok((gossip, actor, endpoint_task))
}
async fn t_new(
rng: &mut rand_chacha::ChaCha12Rng,
config: proto::Config,
relay_map: RelayMap,
cancel: &CancellationToken,
) -> Result<(Self, Endpoint, EndpointHandle, impl Drop + use<>), BindError> {
let (g, actor, ep_handle) =
Gossip::t_new_with_actor(rng, config, relay_map, cancel).await?;
let ep = actor.endpoint.clone();
let me = ep.id().fmt_short();
let actor_handle =
task::spawn(actor.run().instrument(tracing::error_span!("gossip", %me)));
Ok((g, ep, ep_handle, AbortOnDropHandle::new(actor_handle)))
}
}
pub(crate) async fn create_endpoint(
rng: &mut rand_chacha::ChaCha12Rng,
relay_map: RelayMap,
memory_lookup: Option<MemoryLookup>,
) -> Result<Endpoint, BindError> {
let ep = Endpoint::empty_builder(RelayMode::Custom(relay_map))
.secret_key(SecretKey::generate(rng))
.alpns(vec![GOSSIP_ALPN.to_vec()])
.insecure_skip_relay_cert_verify(true)
.bind()
.await?;
if let Some(memory_lookup) = memory_lookup {
ep.address_lookup().expect("endpoint closed").add(memory_lookup);
}
ep.online().await;
Ok(ep)
}
async fn endpoint_loop(
endpoint: Endpoint,
gossip: Gossip,
cancel: CancellationToken,
) -> Result<()> {
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => break,
incoming = endpoint.accept() => match incoming {
None => break,
Some(incoming) => {
let connecting = match incoming.accept() {
Ok(connecting) => connecting,
Err(err) => {
warn!("incoming connection failed: {err:#}");
continue;
}
};
let connection = connecting
.await
.std_context("await incoming connection")?;
gossip.handle_connection(connection).await?
}
}
}
}
Ok(())
}
#[tokio::test]
#[traced_test]
async fn gossip_net_smoke() {
let mut rng = rand_chacha::ChaCha12Rng::seed_from_u64(1);
let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
let memory_lookup = MemoryLookup::new();
let ep1 = create_endpoint(&mut rng, relay_map.clone(), Some(memory_lookup.clone()))
.await
.unwrap();
let ep2 = create_endpoint(&mut rng, relay_map.clone(), Some(memory_lookup.clone()))
.await
.unwrap();
let ep3 = create_endpoint(&mut rng, relay_map.clone(), Some(memory_lookup.clone()))
.await
.unwrap();
let go1 = Gossip::builder().spawn(ep1.clone());
let go2 = Gossip::builder().spawn(ep2.clone());
let go3 = Gossip::builder().spawn(ep3.clone());
debug!("peer1 {:?}", ep1.id());
debug!("peer2 {:?}", ep2.id());
debug!("peer3 {:?}", ep3.id());
let pi1 = ep1.id();
let pi2 = ep2.id();
let cancel = CancellationToken::new();
let tasks = [
spawn(endpoint_loop(ep1.clone(), go1.clone(), cancel.clone())),
spawn(endpoint_loop(ep2.clone(), go2.clone(), cancel.clone())),
spawn(endpoint_loop(ep3.clone(), go3.clone(), cancel.clone())),
];
debug!("----- adding peers ----- ");
let hash_bytes: [u8; 32] = hemera::hash(b"foobar").as_bytes()[..32].try_into().unwrap();
let topic: TopicId = hash_bytes.into();
let addr1 = EndpointAddr::new(pi1).with_relay_url(relay_url.clone());
let addr2 = EndpointAddr::new(pi2).with_relay_url(relay_url);
memory_lookup.add_endpoint_info(addr1.clone());
memory_lookup.add_endpoint_info(addr2.clone());
debug!("----- joining ----- ");
let [sub1, mut sub2, mut sub3] = [
go1.subscribe_and_join(topic, vec![]),
go2.subscribe_and_join(topic, vec![pi1]),
go3.subscribe_and_join(topic, vec![pi2]),
]
.try_join()
.await
.unwrap();
let (sink1, _stream1) = sub1.split();
let len = 2;
let pub1 = spawn(async move {
for i in 0..len {
let message = format!("hi{i}");
info!("go1 broadcast: {message:?}");
sink1.broadcast(message.into_bytes().into()).await.unwrap();
tokio::time::sleep(Duration::from_micros(1)).await;
}
});
let sub2 = spawn(async move {
let mut recv = vec![];
loop {
let ev = sub2.next().await.unwrap().unwrap();
info!("go2 event: {ev:?}");
if let Event::Received(msg) = ev {
recv.push(msg.content);
}
if recv.len() == len {
return recv;
}
}
});
let sub3 = spawn(async move {
let mut recv = vec![];
loop {
let ev = sub3.next().await.unwrap().unwrap();
info!("go3 event: {ev:?}");
if let Event::Received(msg) = ev {
recv.push(msg.content);
}
if recv.len() == len {
return recv;
}
}
});
timeout(Duration::from_secs(10), pub1)
.await
.unwrap()
.unwrap();
let recv2 = timeout(Duration::from_secs(10), sub2)
.await
.unwrap()
.unwrap();
let recv3 = timeout(Duration::from_secs(10), sub3)
.await
.unwrap()
.unwrap();
let expected: HashSet<Bytes> = (0..len)
.map(|i| Bytes::from(format!("hi{i}").into_bytes()))
.collect();
assert_eq!(HashSet::from_iter(recv2), expected);
assert_eq!(HashSet::from_iter(recv3), expected);
cancel.cancel();
for t in tasks {
timeout(Duration::from_secs(10), t)
.await
.unwrap()
.unwrap()
.unwrap();
}
}
#[tokio::test]
#[traced_test]
async fn subscription_cleanup() -> Result {
let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
let ct = CancellationToken::new();
let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
let (go1, actor, ep1_handle) =
Gossip::t_new_with_actor(rng, Default::default(), relay_map.clone(), &ct).await?;
let mut actor = ManualActorLoop::new(actor).await;
let (go2, ep2, ep2_handle, _test_actor_handle) =
Gossip::t_new(rng, Default::default(), relay_map, &ct).await?;
let endpoint_id1 = actor.endpoint.id();
let endpoint_id2 = ep2.id();
tracing::info!(
endpoint_1 = %endpoint_id1.fmt_short(),
endpoint_2 = %endpoint_id2.fmt_short(),
"endpoints ready"
);
let hash_bytes: [u8; 32] = hemera::hash(b"subscription_cleanup").as_bytes()[..32].try_into().unwrap();
let topic: TopicId = hash_bytes.into();
tracing::info!(%topic, "joining");
let ct2 = ct.clone();
let go2_task = async move {
let (_pub_tx, mut sub_rx) = go2.subscribe_and_join(topic, vec![]).await?.split();
let subscribe_fut = async {
while let Some(ev) = sub_rx.try_next().await? {
match ev {
Event::Lagged => tracing::debug!("missed some messages :("),
Event::Received(_) => unreachable!("test does not send messages"),
other => tracing::debug!(?other, "gs event"),
}
}
tracing::debug!("subscribe stream ended");
Ok::<_, AnyError>(())
};
tokio::select! {
_ = ct2.cancelled() => Ok(()),
res = subscribe_fut => res,
}
}
.instrument(tracing::debug_span!("endpoint_2", %endpoint_id2));
let go2_handle = task::spawn(go2_task);
let addr2 = EndpointAddr::new(endpoint_id2).with_relay_url(relay_url);
let memory_lookup = MemoryLookup::new();
memory_lookup.add_endpoint_info(addr2);
actor.endpoint.address_lookup().expect("endpoint closed").add(memory_lookup);
let (tx, mut rx) = mpsc::channel::<()>(1);
let ct1 = ct.clone();
let go1_task = async move {
tracing::info!("subscribing the first time");
let sub_1a = go1.subscribe_and_join(topic, vec![endpoint_id2]).await?;
rx.recv().await.expect("signal for second subscribe");
tracing::info!("subscribing a second time");
let sub_1b = go1.subscribe_and_join(topic, vec![endpoint_id2]).await?;
drop(sub_1a);
rx.recv().await.expect("signal for second subscribe");
tracing::info!("dropping all handles");
drop(sub_1b);
ct1.cancelled().await;
drop(go1);
Ok::<_, AnyError>(())
}
.instrument(tracing::debug_span!("endpoint_1", %endpoint_id1));
let go1_handle = task::spawn(go1_task);
actor.steps(3).await; let state = actor.topics.get(&topic).expect("get registered topic");
assert!(state.joined());
tx.send(())
.await
.std_context("signal additional subscribe")?;
actor.steps(3).await; let state = actor.topics.get(&topic).expect("get registered topic");
assert!(state.joined());
tx.send(()).await.std_context("signal drop handles")?;
actor.steps(2).await; assert!(!actor.topics.contains_key(&topic));
ct.cancel();
let wait = Duration::from_secs(2);
timeout(wait, ep1_handle)
.await
.std_context("wait endpoint1 task")?
.std_context("join endpoint1 task")??;
timeout(wait, ep2_handle)
.await
.std_context("wait endpoint2 task")?
.std_context("join endpoint2 task")??;
timeout(wait, go1_handle)
.await
.std_context("wait gossip1 task")?
.std_context("join gossip1 task")??;
timeout(wait, go2_handle)
.await
.std_context("wait gossip2 task")?
.std_context("join gossip2 task")??;
timeout(wait, actor.finish())
.await
.std_context("wait actor finish")?;
Ok(())
}
#[tokio::test]
#[traced_test]
async fn can_reconnect() -> Result {
let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
let ct = CancellationToken::new();
let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
let (go1, ep1, ep1_handle, _test_actor_handle1) =
Gossip::t_new(rng, Default::default(), relay_map.clone(), &ct).await?;
let (go2, ep2, ep2_handle, _test_actor_handle2) =
Gossip::t_new(rng, Default::default(), relay_map, &ct).await?;
let endpoint_id1 = ep1.id();
let endpoint_id2 = ep2.id();
tracing::info!(
endpoint_1 = %endpoint_id1.fmt_short(),
endpoint_2 = %endpoint_id2.fmt_short(),
"endpoints ready"
);
let hash_bytes: [u8; 32] = hemera::hash(b"can_reconnect").as_bytes()[..32].try_into().unwrap();
let topic: TopicId = hash_bytes.into();
tracing::info!(%topic, "joining");
let ct2 = ct.child_token();
let (tx, mut rx) = mpsc::channel::<()>(1);
let addr1 = EndpointAddr::new(endpoint_id1).with_relay_url(relay_url.clone());
let memory_lookup = MemoryLookup::new();
memory_lookup.add_endpoint_info(addr1);
ep2.address_lookup().expect("endpoint closed").add(memory_lookup.clone());
let go2_task = async move {
let mut sub = go2.subscribe(topic, Vec::new()).await?;
sub.joined().await?;
rx.recv().await.expect("signal to unsubscribe");
tracing::info!("unsubscribing");
drop(sub);
rx.recv().await.expect("signal to subscribe again");
tracing::info!("resubscribing");
let mut sub = go2.subscribe(topic, vec![endpoint_id1]).await?;
sub.joined().await?;
tracing::info!("subscription successful!");
ct2.cancelled().await;
Ok::<_, ApiError>(())
}
.instrument(tracing::debug_span!("endpoint_2", %endpoint_id2));
let go2_handle = task::spawn(go2_task);
let addr2 = EndpointAddr::new(endpoint_id2).with_relay_url(relay_url);
memory_lookup.add_endpoint_info(addr2);
ep1.address_lookup().expect("endpoint closed").add(memory_lookup);
let mut sub = go1.subscribe(topic, vec![endpoint_id2]).await?;
sub.joined().await?;
tx.send(()).await.std_context("signal unsubscribe")?;
let conn_timeout = Duration::from_millis(500);
let ev = timeout(conn_timeout, sub.try_next())
.await
.std_context("wait neighbor down")??;
assert_eq!(ev, Some(Event::NeighborDown(endpoint_id2)));
tracing::info!("endpoint 2 left");
tx.send(()).await.std_context("signal resubscribe")?;
let conn_timeout = Duration::from_millis(500);
let ev = timeout(conn_timeout, sub.try_next())
.await
.std_context("wait neighbor up")??;
assert_eq!(ev, Some(Event::NeighborUp(endpoint_id2)));
tracing::info!("endpoint 2 rejoined!");
ct.cancel();
let wait = Duration::from_secs(2);
timeout(wait, ep1_handle)
.await
.std_context("wait endpoint1 task")?
.std_context("join endpoint1 task")??;
timeout(wait, ep2_handle)
.await
.std_context("wait endpoint2 task")?
.std_context("join endpoint2 task")??;
timeout(wait, go2_handle)
.await
.std_context("wait gossip2 task")?
.std_context("join gossip2 task")??;
Result::Ok(())
}
#[tokio::test]
#[traced_test]
async fn can_die_and_reconnect() -> Result {
fn run_in_thread<T: Send + 'static>(
cancel: CancellationToken,
fut: impl std::future::Future<Output = T> + Send + 'static,
) -> std::thread::JoinHandle<Option<T>> {
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async move { cancel.run_until_cancelled(fut).await })
})
}
async fn spawn_gossip(
secret_key: SecretKey,
relay_map: RelayMap,
) -> Result<(Router, Gossip), BindError> {
let ep = Endpoint::empty_builder(RelayMode::Custom(relay_map))
.secret_key(secret_key)
.insecure_skip_relay_cert_verify(true)
.bind()
.await?;
let gossip = Gossip::builder().spawn(ep.clone());
let router = Router::builder(ep)
.accept(GOSSIP_ALPN, gossip.clone())
.spawn();
Ok((router, gossip))
}
async fn broadcast_once(
secret_key: SecretKey,
relay_map: RelayMap,
bootstrap_addr: EndpointAddr,
topic_id: TopicId,
message: String,
) -> Result {
let (router, gossip) = spawn_gossip(secret_key, relay_map).await?;
info!(endpoint_id = %router.endpoint().id().fmt_short(), "broadcast endpoint spawned");
let bootstrap = vec![bootstrap_addr.id];
let memory_lookup = MemoryLookup::new();
memory_lookup.add_endpoint_info(bootstrap_addr);
router.endpoint().address_lookup().expect("endpoint closed").add(memory_lookup);
let mut topic = gossip.subscribe_and_join(topic_id, bootstrap).await?;
topic.broadcast(message.as_bytes().to_vec().into()).await?;
std::future::pending::<()>().await;
Ok(())
}
let (relay_map, _relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
let mut rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
let topic_id = TopicId::from_bytes(rng.random());
let (addr_tx, addr_rx) = tokio::sync::oneshot::channel();
let (msgs_recv_tx, mut msgs_recv_rx) = tokio::sync::mpsc::channel(3);
let recv_task = tokio::task::spawn({
let relay_map = relay_map.clone();
let secret_key = SecretKey::generate(&mut rng);
async move {
let (router, gossip) = spawn_gossip(secret_key, relay_map).await?;
router.endpoint().online().await;
let addr = router.endpoint().addr();
info!(endpoint_id = %addr.id.fmt_short(), "recv endpoint spawned");
addr_tx.send(addr).unwrap();
let mut topic = gossip.subscribe_and_join(topic_id, vec![]).await?;
while let Some(event) = topic.try_next().await.unwrap() {
if let Event::Received(message) = event {
let message = std::str::from_utf8(&message.content)
.std_context("decode broadcast message")?
.to_string();
msgs_recv_tx
.send(message)
.await
.std_context("forward received message")?;
}
}
Ok::<_, AnyError>(())
}
});
let endpoint0_addr = addr_rx.await.std_context("receive endpoint address")?;
let max_wait = Duration::from_secs(5);
let cancel = CancellationToken::new();
let secret = SecretKey::generate(&mut rng);
let join_handle_1 = run_in_thread(
cancel.clone(),
broadcast_once(
secret.clone(),
relay_map.clone(),
endpoint0_addr.clone(),
topic_id,
"msg1".to_string(),
),
);
let msg = timeout(max_wait, msgs_recv_rx.recv())
.await
.std_context("wait for first broadcast")?
.std_context("receiver dropped channel")?;
assert_eq!(&msg, "msg1");
info!("kill broadcast endpoint");
cancel.cancel();
let cancel = CancellationToken::new();
let join_handle_2 = run_in_thread(
cancel.clone(),
broadcast_once(
secret.clone(),
relay_map.clone(),
endpoint0_addr.clone(),
topic_id,
"msg2".to_string(),
),
);
let msg = timeout(max_wait, msgs_recv_rx.recv())
.await
.std_context("wait for second broadcast")?
.std_context("receiver dropped channel")?;
assert_eq!(&msg, "msg2");
info!("kill broadcast endpoint");
cancel.cancel();
info!("kill recv endpoint");
recv_task.abort();
assert!(join_handle_1.join().unwrap().is_none());
assert!(join_handle_2.join().unwrap().is_none());
Ok(())
}
#[tokio::test]
#[traced_test]
async fn gossip_change_alpn() -> n0_error::Result<()> {
let alpn = b"my-gossip-alpn";
let topic_id = TopicId::from([0u8; 32]);
let ep1 = Endpoint::empty_builder(RelayMode::Disabled).bind().await?;
let ep2 = Endpoint::empty_builder(RelayMode::Disabled).bind().await?;
let gossip1 = Gossip::builder().alpn(alpn).spawn(ep1.clone());
let gossip2 = Gossip::builder().alpn(alpn).spawn(ep2.clone());
let router1 = Router::builder(ep1).accept(alpn, gossip1.clone()).spawn();
let router2 = Router::builder(ep2).accept(alpn, gossip2.clone()).spawn();
let addr1 = router1.endpoint().addr();
let id1 = addr1.id;
let memory_lookup = MemoryLookup::new();
memory_lookup.add_endpoint_info(addr1);
router2.endpoint().address_lookup().expect("endpoint closed").add(memory_lookup);
let mut topic1 = gossip1.subscribe(topic_id, vec![]).await?;
let mut topic2 = gossip2.subscribe(topic_id, vec![id1]).await?;
timeout(Duration::from_secs(3), topic1.joined())
.await
.std_context("wait topic1 join")??;
timeout(Duration::from_secs(3), topic2.joined())
.await
.std_context("wait topic2 join")??;
router1.shutdown().await.std_context("shutdown router1")?;
router2.shutdown().await.std_context("shutdown router2")?;
Ok(())
}
#[tokio::test]
#[traced_test]
async fn gossip_rely_on_gossip_address_lookup() -> n0_error::Result<()> {
let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
async fn spawn(
rng: &mut impl CryptoRng,
) -> n0_error::Result<(EndpointId, Router, Gossip, GossipSender, GossipReceiver)> {
let topic_id = TopicId::from([0u8; 32]);
let ep = Endpoint::empty_builder(RelayMode::Disabled)
.secret_key(SecretKey::generate(rng))
.bind()
.await?;
let endpoint_id = ep.id();
let gossip = Gossip::builder().spawn(ep.clone());
let router = Router::builder(ep)
.accept(GOSSIP_ALPN, gossip.clone())
.spawn();
let topic = gossip.subscribe(topic_id, vec![]).await?;
let (sender, receiver) = topic.split();
Ok((endpoint_id, router, gossip, sender, receiver))
}
let (n1, r1, _g1, _tx1, mut rx1) = spawn(rng).await?;
let (n2, r2, _g2, tx2, mut rx2) = spawn(rng).await?;
let (n3, r3, _g3, tx3, mut rx3) = spawn(rng).await?;
println!("endpoints {:?}", [n1, n2, n3]);
let addr1 = r1.endpoint().addr();
let lookup = MemoryLookup::new();
lookup.add_endpoint_info(addr1);
r2.endpoint().address_lookup().expect("endpoint closed").add(lookup.clone());
tx2.join_peers(vec![n1]).await?;
timeout(Duration::from_secs(3), rx1.joined())
.await
.std_context("wait rx1 join")??;
timeout(Duration::from_secs(3), rx2.joined())
.await
.std_context("wait rx2 join")??;
r3.endpoint().address_lookup().expect("endpoint closed").add(lookup.clone());
tx3.join_peers(vec![n1]).await?;
let ev = timeout(Duration::from_secs(3), rx3.next())
.await
.std_context("wait rx3 first neighbor")?;
assert!(matches!(ev, Some(Ok(Event::NeighborUp(_)))));
let ev = timeout(Duration::from_secs(3), rx3.next())
.await
.std_context("wait rx3 second neighbor")?;
assert!(matches!(ev, Some(Ok(Event::NeighborUp(_)))));
assert_eq!(sorted(rx3.neighbors()), sorted([n1, n2]));
let ev = timeout(Duration::from_secs(3), rx2.next())
.await
.std_context("wait rx2 neighbor")?;
assert!(matches!(ev, Some(Ok(Event::NeighborUp(n))) if n == n3));
let ev = timeout(Duration::from_secs(3), rx1.next())
.await
.std_context("wait rx1 neighbor")?;
assert!(matches!(ev, Some(Ok(Event::NeighborUp(n))) if n == n3));
tokio::try_join!(r1.shutdown(), r2.shutdown(), r3.shutdown())
.std_context("shutdown routers")?;
Ok(())
}
fn sorted<T: Ord>(input: impl IntoIterator<Item = T>) -> Vec<T> {
let mut out: Vec<_> = input.into_iter().collect();
out.sort();
out
}
#[tokio::test]
#[traced_test]
async fn topic_stays_alive_after_sender_drop() -> n0_error::Result<()> {
let topic_id = TopicId::from([99u8; 32]);
let ep1 = Endpoint::empty_builder(RelayMode::Disabled).bind().await?;
let ep2 = Endpoint::empty_builder(RelayMode::Disabled).bind().await?;
let gossip1 = Gossip::builder().spawn(ep1.clone());
let gossip2 = Gossip::builder().spawn(ep2.clone());
let router1 = Router::builder(ep1)
.accept(crate::ALPN, gossip1.clone())
.spawn();
let router2 = Router::builder(ep2)
.accept(crate::ALPN, gossip2.clone())
.spawn();
let addr1 = router1.endpoint().addr();
let id1 = addr1.id;
let mem_lookup = MemoryLookup::new();
mem_lookup.add_endpoint_info(addr1);
router2.endpoint().address_lookup().expect("endpoint closed").add(mem_lookup);
let topic1 = gossip1.subscribe(topic_id, vec![]).await?;
let topic2 = gossip2.subscribe(topic_id, vec![id1]).await?;
let (tx1, mut rx1) = topic1.split();
let (tx2, mut rx2) = topic2.split();
timeout(Duration::from_secs(3), rx1.joined())
.await
.std_context("wait rx1 join")??;
timeout(Duration::from_secs(3), rx2.joined())
.await
.std_context("wait rx2 join")??;
drop(tx1);
tx2.broadcast(b"hello from node2".to_vec().into()).await?;
let event = timeout(Duration::from_secs(3), rx1.next())
.await
.std_context("wait for message on rx1")?;
match event {
Some(Ok(Event::Received(msg))) => {
assert_eq!(&msg.content[..], b"hello from node2");
}
other => panic!("expected Received event, got {:?}", other),
}
drop(tx2);
drop(rx1);
drop(rx2);
router1.shutdown().await.std_context("shutdown router1")?;
router2.shutdown().await.std_context("shutdown router2")?;
Ok(())
}
}