use std::{
collections::{BTreeMap, BTreeSet},
fmt::Display,
future::poll_fn,
io,
net::{IpAddr, SocketAddr},
sync::{
Arc, Mutex, RwLock,
atomic::{AtomicBool, Ordering},
},
};
use iroh_base::{EndpointAddr, EndpointId, PublicKey, RelayUrl, SecretKey, TransportAddr};
use iroh_relay::{RelayConfig, RelayMap};
use n0_error::{bail, e, stack_error};
use n0_future::{
task::{self, AbortOnDropHandle},
time::{self, Duration, Instant},
};
use n0_watcher::{self, Watchable, Watcher};
#[cfg(not(wasm_browser))]
use netwatch::ip::LocalAddresses;
use netwatch::netmon;
use quinn::WeakConnectionHandle;
use rand::Rng;
use tokio::sync::{
Mutex as AsyncMutex,
mpsc::{self},
oneshot,
};
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
use tracing::{Instrument, Level, debug, event, info_span, instrument, trace, warn};
use transports::{LocalAddrsWatch, Transport, TransportConfig};
use url::Url;
use self::{
remote_map::{RemoteMap, RemoteStateMessage},
transports::{RelayActorConfig, Transports},
};
#[cfg(not(wasm_browser))]
use crate::dns::DnsResolver;
#[cfg(not(wasm_browser))]
use crate::net_report::QuicConfig;
use crate::{
address_lookup::{self, AddressLookup, EndpointData, Error as AddressLookupError, UserData},
defaults::timeouts::NET_REPORT_TIMEOUT,
endpoint::{hooks::EndpointHooksList, quic::QuicTransportConfig},
metrics::EndpointMetrics,
net_report::{self, IfStateDetails, Report},
runtime::Runtime,
socket::{
concurrent_read_map::ReadOnlyMap,
remote_map::{MappedAddrs, PathWatchable, RemoteInfo},
},
tls,
};
mod metrics;
pub(crate) mod concurrent_read_map;
pub(crate) mod mapped_addrs;
pub(crate) mod remote_map;
pub(crate) mod transports;
use self::mapped_addrs::{EndpointIdMappedAddr, MappedAddr};
pub use self::metrics::Metrics;
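/// How often keep-alive heartbeats are sent on active paths.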
pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
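/// How long a path may go without traffic before it is considered idle; slightly
/// above [`HEARTBEAT_INTERVAL`] so that heartbeats keep active paths alive.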
pub(crate) const PATH_MAX_IDLE_TIMEOUT: Duration = Duration::from_millis(6500);
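/// Maximum number of concurrent QUIC multipath paths per connection.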
pub(crate) const MAX_MULTIPATH_PATHS: u32 = 12;
#[stack_error(add_meta, derive)]
#[error("endpoint state actor stopped")]
#[derive(Clone)]
pub(crate) struct RemoteStateActorStoppedError;
impl From<mpsc::error::SendError<RemoteStateMessage>> for RemoteStateActorStoppedError {
#[track_caller]
fn from(_value: mpsc::error::SendError<RemoteStateMessage>) -> Self {
Self::new()
}
}
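/// Options for constructing the socket and its [`EndpointInner`].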
#[derive(derive_more::Debug)]
pub(crate) struct Options {
pub(crate) transports: Vec<TransportConfig>,
pub(crate) secret_key: SecretKey,
pub(crate) address_lookup_user_data: Option<UserData>,
#[cfg(not(wasm_browser))]
pub(crate) dns_resolver: DnsResolver,
pub(crate) proxy_url: Option<Url>,
pub(crate) server_config: quinn_proto::ServerConfig,
#[cfg(any(test, feature = "test-utils"))]
pub(crate) insecure_skip_relay_cert_verify: bool,
pub(crate) metrics: EndpointMetrics,
pub(crate) hooks: EndpointHooksList,
pub(crate) static_config: StaticConfig,
}
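/// The internal endpoint state: the shared [`Socket`], the quinn endpoint, the
/// async runtime, and the handle to the main actor task.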
#[derive(Debug, derive_more::Deref)]
pub(crate) struct EndpointInner {
#[deref(forward)]
sock: Arc<Socket>,
actor_task: Mutex<Option<AbortOnDropHandle<()>>>,
actor_sender: mpsc::Sender<ActorMessage>,
endpoint: quinn::Endpoint,
runtime: Arc<Runtime>,
pub(crate) static_config: StaticConfig,
}
impl Drop for EndpointInner {
fn drop(&mut self) {
if self.sock.is_closed() {
return;
}
tracing::error!(
"Endpoint dropped without calling `Endpoint::close`. Aborting ungracefully."
);
self.abort();
}
}
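/// Endpoint configuration that stays fixed for the lifetime of the endpoint.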
#[derive(Debug)]
pub(crate) struct StaticConfig {
pub(crate) tls_config: tls::TlsConfig,
pub(crate) transport_config: QuicTransportConfig,
pub(crate) keylog: bool,
}
impl StaticConfig {
pub(crate) fn create_server_config(
&self,
alpn_protocols: Vec<Vec<u8>>,
) -> quinn_proto::ServerConfig {
let quic_server_config = self
.tls_config
.make_server_config(alpn_protocols, self.keylog);
let mut inner = quinn::ServerConfig::with_crypto(Arc::new(quic_server_config));
inner.transport_config(self.transport_config.to_inner_arc());
inner
}
}
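/// Tracks the shutdown sequence: one token cancelled when close starts, one
/// cancelled once the quinn endpoint has drained, and a flag set when shutdown
/// has fully completed.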
#[derive(Debug)]
struct ShutdownState {
at_close_start: CancellationToken,
at_endpoint_closed: CancellationToken,
closed: AtomicBool,
}
impl Default for ShutdownState {
fn default() -> Self {
Self {
at_close_start: CancellationToken::new(),
at_endpoint_closed: CancellationToken::new(),
closed: AtomicBool::new(false),
}
}
}
impl ShutdownState {
fn is_closing(&self) -> bool {
self.at_close_start.is_cancelled()
}
fn is_closed(&self) -> bool {
self.closed.load(Ordering::Relaxed)
}
}
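/// State shared between the endpoint, the actor, and the transports: per-remote
/// actor channels, discovered direct addresses, the latest net_report result,
/// and address lookup publishing.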
#[derive(Debug)]
pub(crate) struct Socket {
remote_actors: ReadOnlyMap<EndpointId, mpsc::Sender<RemoteStateMessage>>,
public_key: PublicKey,
shutdown: ShutdownState,
direct_addrs: DiscoveredDirectAddrs,
net_report: Watchable<(Option<Report>, UpdateReason)>,
ipv6_reported: Arc<AtomicBool>,
mapped_addrs: MappedAddrs,
local_addrs_watch: LocalAddrsWatch,
#[cfg(not(wasm_browser))]
ip_bind_addrs: Vec<SocketAddr>,
#[cfg(not(wasm_browser))]
dns_resolver: DnsResolver,
relay_map: RelayMap,
address_lookup: address_lookup::ConcurrentAddressLookup,
address_lookup_user_data: RwLock<Option<UserData>>,
pub(crate) metrics: EndpointMetrics,
pub(crate) hooks: EndpointHooksList,
}
impl Socket {
pub(crate) async fn spawn(opts: Options) -> Result<EndpointInner, BindError> {
EndpointInner::new(opts).await
}
pub(crate) fn my_relay(&self) -> Option<RelayUrl> {
self.local_addr().into_iter().find_map(|a| {
if let transports::Addr::Relay(url, _) = a {
Some(url)
} else {
None
}
})
}
pub(crate) fn is_closed(&self) -> bool {
self.shutdown.is_closed()
}
fn is_closing(&self) -> bool {
self.shutdown.is_closing()
}
pub(crate) fn closed(&self) -> WaitForCancellationFutureOwned {
self.shutdown.at_close_start.clone().cancelled_owned()
}
pub(crate) fn local_addr(&self) -> Vec<transports::Addr> {
self.local_addrs_watch.clone().get()
}
#[cfg(not(wasm_browser))]
fn ip_bind_addrs(&self) -> &[SocketAddr] {
&self.ip_bind_addrs
}
fn ip_local_addrs(&self) -> impl Iterator<Item = SocketAddr> + use<> {
self.local_addr()
.into_iter()
.filter_map(|addr| addr.into_socket_addr())
}
pub(crate) fn try_send_remote_state_msg(
&self,
endpoint_id: EndpointId,
message: RemoteStateMessage,
) -> Result<(), RemoteStateMessage> {
let Some(sender) = self.remote_actors.get(&endpoint_id) else {
return Err(message);
};
sender.try_send(message).map_err(|err| err.into_inner())
}
pub(crate) fn ip_addrs(&self) -> n0_watcher::Direct<BTreeSet<DirectAddr>> {
self.direct_addrs.addrs.watch()
}
pub(crate) fn net_report(&self) -> impl Watcher<Value = Option<Report>> + use<> {
self.net_report.watch().map(|(r, _)| r)
}
pub(crate) fn home_relay(&self) -> impl Watcher<Value = Vec<RelayUrl>> + use<> {
self.local_addrs_watch.clone().map(|addrs| {
addrs
.into_iter()
.filter_map(|addr| {
if let transports::Addr::Relay(url, _) = addr {
Some(url)
} else {
None
}
})
.collect()
})
}
fn store_direct_addresses(&self, addrs: BTreeSet<DirectAddr>) {
let updated = self.direct_addrs.update(addrs);
if updated {
self.publish_my_addr();
}
}
#[cfg(not(wasm_browser))]
pub(crate) fn dns_resolver(&self) -> &DnsResolver {
&self.dns_resolver
}
pub(crate) fn address_lookup(&self) -> &address_lookup::ConcurrentAddressLookup {
&self.address_lookup
}
pub(crate) fn set_user_data_for_address_lookup(&self, user_data: Option<UserData>) {
let mut guard = self
.address_lookup_user_data
.write()
.expect("lock poisened");
if *guard != user_data {
*guard = user_data;
drop(guard);
self.publish_my_addr();
}
}
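/// Post-processes a batch of received datagrams: bumps the receive metrics per
/// source transport and rewrites relay sources to their locally mapped socket
/// addresses so quinn can associate the datagrams with a connection.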
fn process_datagrams(
&self,
bufs: &mut [io::IoSliceMut<'_>],
metas: &mut [quinn_udp::RecvMeta],
source_addrs: &[transports::Addr],
) {
debug_assert_eq!(bufs.len(), metas.len(), "mismatched bufs & metas");
debug_assert_eq!(
bufs.len(),
source_addrs.len(),
"mismatched bufs & source_addrs"
);
for i in 0..metas.len() {
let quinn_meta = &mut metas[i];
let source_addr = &source_addrs[i];
let datagram_count = quinn_meta.len.div_ceil(quinn_meta.stride);
self.metrics
.socket
.recv_datagrams
.inc_by(datagram_count as _);
if quinn_meta.len > quinn_meta.stride {
trace!(
src = ?source_addr,
len = quinn_meta.len,
stride = %quinn_meta.stride,
datagram_count,
"GRO datagram received",
);
self.metrics.socket.recv_gro_datagrams.inc();
} else {
trace!(src = ?source_addr, len = quinn_meta.len, "datagram received");
}
match source_addr {
transports::Addr::Ip(SocketAddr::V4(..)) => {
self.metrics
.socket
.recv_data_ipv4
.inc_by(quinn_meta.len as _);
}
transports::Addr::Ip(SocketAddr::V6(..)) => {
self.metrics
.socket
.recv_data_ipv6
.inc_by(quinn_meta.len as _);
}
transports::Addr::Relay(src_url, src_node) => {
self.metrics
.socket
.recv_data_relay
.inc_by(quinn_meta.len as _);
let mapped_addr = self
.mapped_addrs
.relay_addrs
.get(&(src_url.clone(), *src_node));
quinn_meta.addr = mapped_addr.private_socket_addr();
}
}
}
}
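/// Publishes our current direct addresses, home relay, and user data to address
/// lookup, unless nothing is known yet.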
fn publish_my_addr(&self) {
let relay_url = self.my_relay();
let mut addrs: BTreeSet<_> = self
.direct_addrs
.sockaddrs()
.map(TransportAddr::Ip)
.collect();
let user_data = self
.address_lookup_user_data
.read()
.expect("lock poisened")
.clone();
if relay_url.is_none() && addrs.is_empty() && user_data.is_none() {
return;
}
if let Some(url) = relay_url {
addrs.insert(TransportAddr::Relay(url));
}
let data = EndpointData::new(addrs).with_user_data(user_data);
self.address_lookup.publish(&data);
}
}
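/// Serializes direct-address updates: at most one net_report run is active at a
/// time (guarded by the async mutex around the net_report client), and at most
/// one pending update reason is remembered until the current run finishes.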
#[derive(Debug)]
struct DirectAddrUpdateState {
want_update: Option<UpdateReason>,
sock: Arc<Socket>,
#[cfg(not(wasm_browser))]
port_mapper: portmapper::Client,
net_reporter: Arc<AsyncMutex<net_report::Client>>,
relay_map: RelayMap,
run_done: mpsc::Sender<()>,
shutdown_token: CancellationToken,
}
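/// Why a direct-address update (a net_report run) was requested.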
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
enum UpdateReason {
#[default]
None,
Periodic,
PortmapUpdated,
LinkChangeMajor,
LinkChangeMinor,
RelayMapChange,
}
impl UpdateReason {
fn is_major(self) -> bool {
matches!(self, Self::LinkChangeMajor | Self::RelayMapChange)
}
}
impl DirectAddrUpdateState {
fn new(
sock: Arc<Socket>,
#[cfg(not(wasm_browser))] port_mapper: portmapper::Client,
net_reporter: Arc<AsyncMutex<net_report::Client>>,
relay_map: RelayMap,
run_done: mpsc::Sender<()>,
shutdown_token: CancellationToken,
) -> Self {
DirectAddrUpdateState {
want_update: Default::default(),
#[cfg(not(wasm_browser))]
port_mapper,
net_reporter,
sock,
relay_map,
run_done,
shutdown_token,
}
}
fn schedule_run(&mut self, why: UpdateReason, if_state: IfStateDetails) {
match self.net_reporter.clone().try_lock_owned() {
Ok(net_reporter) => {
self.run(why, if_state, net_reporter);
}
Err(_) => {
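// A report is already running; remember the reason (the latest one wins)
// so that try_run can pick it up once `run_done` fires.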
let _ = self.want_update.insert(why);
}
}
}
fn try_run(&mut self, if_state: IfStateDetails) {
match self.net_reporter.clone().try_lock_owned() {
Ok(net_reporter) => {
if let Some(why) = self.want_update.take() {
self.run(why, if_state, net_reporter);
}
}
Err(_) => {}
}
fn run(
&mut self,
why: UpdateReason,
if_state: IfStateDetails,
mut net_reporter: tokio::sync::OwnedMutexGuard<net_report::Client>,
) {
debug!("starting direct addr update ({:?})", why);
if self.shutdown_token.is_cancelled() {
debug!("skipping net_report, socket is shutting down");
#[cfg(not(wasm_browser))]
self.port_mapper.deactivate();
return;
}
if self.relay_map.is_empty() {
debug!("skipping net_report, empty RelayMap");
self.sock.net_report.set((None, why)).ok();
return;
}
#[cfg(not(wasm_browser))]
self.port_mapper.procure_mapping();
debug!("requesting net_report report");
let sock = self.sock.clone();
let run_done = self.run_done.clone();
let token = self.shutdown_token.child_token();
let inner_token = token.child_token();
task::spawn(
async move {
let fut = token.run_until_cancelled(time::timeout(
NET_REPORT_TIMEOUT,
net_reporter.get_report(if_state, why.is_major(), inner_token),
));
match fut.await {
Some(Ok(report)) => {
sock.net_report.set((Some(report), why)).ok();
}
Some(Err(time::Elapsed { .. })) => {
warn!("net_report report timed out");
}
None => {
trace!("net_report cancelled");
}
}
debug!("direct addr update done ({:?})", why);
run_done.send(()).await.ok();
}
.instrument(tracing::Span::current()),
);
}
}
#[allow(missing_docs)]
#[stack_error(derive, add_meta)]
#[non_exhaustive]
pub enum BindError {
#[error("Failed to bind sockets")]
Sockets { source: io::Error },
#[error("Failed to create internal QUIC endpoint")]
CreateQuicEndpoint { source: io::Error },
#[error("Failed to create netmon monitor")]
CreateNetmonMonitor { source: netmon::Error },
#[error("Invalid transport configuration")]
InvalidTransportConfig,
#[error("Failed to create an address lookup service")]
AddressLookup {
#[error(from)]
source: crate::address_lookup::IntoAddressLookupError,
},
}
impl EndpointInner {
async fn new(opts: Options) -> Result<Self, BindError> {
let Options {
secret_key,
transports: transport_configs,
address_lookup_user_data,
#[cfg(not(wasm_browser))]
dns_resolver,
proxy_url,
server_config,
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify,
metrics,
hooks,
static_config,
} = opts;
let address_lookup = address_lookup::ConcurrentAddressLookup::default();
#[cfg(not(wasm_browser))]
let port_mapper =
portmapper::Client::with_metrics(Default::default(), metrics.portmapper.clone());
let relay_transport_configs: Vec<_> = transport_configs
.iter()
.filter(|t| matches!(t, TransportConfig::Relay { .. }))
.collect();
if relay_transport_configs.len() > 1 {
bail!(BindError::InvalidTransportConfig);
}
let relay_map = relay_transport_configs
.iter()
.filter_map(|t| {
#[allow(irrefutable_let_patterns)]
if let TransportConfig::Relay { relay_map, .. } = t {
Some(relay_map.clone())
} else {
None
}
})
.next()
.unwrap_or_else(RelayMap::empty);
let my_relay = Watchable::new(None);
let ipv6_reported = Arc::new(AtomicBool::new(false));
let relay_actor_config = RelayActorConfig {
my_relay: my_relay.clone(),
secret_key: secret_key.clone(),
#[cfg(not(wasm_browser))]
dns_resolver: dns_resolver.clone(),
proxy_url: proxy_url.clone(),
ipv6_reported: ipv6_reported.clone(),
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify,
metrics: metrics.socket.clone(),
};
let shutdown_state = ShutdownState::default();
let shutdown_token = shutdown_state.at_endpoint_closed.child_token();
let transports = Transports::bind(
&transport_configs,
relay_actor_config,
&metrics,
shutdown_token.child_token(),
)
.map_err(|err| e!(BindError::Sockets, err))?;
#[cfg(not(wasm_browser))]
{
if let Some(v4_port) = transports.local_addrs().into_iter().find_map(|t| {
if let transports::Addr::Ip(SocketAddr::V4(addr)) = t {
Some(addr.port())
} else {
None
}
}) {
match v4_port.try_into() {
Ok(non_zero_port) => {
port_mapper.update_local_port(non_zero_port);
}
Err(_zero_port) => debug!("Skipping port mapping with zero local port"),
}
}
}
let (actor_sender, actor_receiver) = mpsc::channel(256);
#[cfg(not(wasm_browser))]
let has_ipv6_transport = transports
.ip_bind_addrs()
.into_iter()
.any(|addr| addr.is_ipv6());
#[cfg(not(wasm_browser))]
let has_ip_transports = !transports.ip_bind_addrs().is_empty();
let direct_addrs = DiscoveredDirectAddrs::default();
let remote_map = {
RemoteMap::new(
secret_key.public(),
metrics.socket.clone(),
direct_addrs.addrs.watch(),
address_lookup.clone(),
shutdown_token.child_token(),
)
};
let sock = Arc::new(Socket {
public_key: secret_key.public(),
remote_actors: remote_map.senders(),
shutdown: shutdown_state,
ipv6_reported,
mapped_addrs: remote_map.mapped_addrs.clone(),
address_lookup,
relay_map: relay_map.clone(),
address_lookup_user_data: RwLock::new(address_lookup_user_data),
direct_addrs,
net_report: Watchable::new((None, UpdateReason::None)),
#[cfg(not(wasm_browser))]
dns_resolver: dns_resolver.clone(),
metrics: metrics.clone(),
local_addrs_watch: transports.local_addrs_watch(),
#[cfg(not(wasm_browser))]
ip_bind_addrs: transports.ip_bind_addrs(),
hooks,
});
let mut endpoint_config = quinn::EndpointConfig::default();
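// Keep the QUIC fixed bit set (no greasing): the shared transports presumably
// rely on it to distinguish QUIC packets from other traffic on the same socket.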
endpoint_config.grease_quic_bit(false);
let local_addrs_watch = transports.local_addrs_watch();
let network_change_sender = transports.create_network_change_sender();
let runtime = Arc::new(Runtime::new(secret_key.public()));
let endpoint = quinn::Endpoint::new_with_abstract_socket(
endpoint_config,
Some(server_config),
Box::new(Transport::new(sock.clone(), transports)),
runtime.clone(),
)
.map_err(|err| e!(BindError::CreateQuicEndpoint, err))?;
let network_monitor = netmon::Monitor::new()
.await
.map_err(|err| e!(BindError::CreateNetmonMonitor, err))?;
#[cfg(any(test, feature = "test-utils"))]
let client_config = if insecure_skip_relay_cert_verify {
iroh_relay::client::make_dangerous_client_config()
} else {
default_quic_client_config()
};
#[cfg(not(any(test, feature = "test-utils")))]
let client_config = default_quic_client_config();
let net_report_config = net_report::Options::default();
#[cfg(not(wasm_browser))]
let net_report_config = {
let qad_config = has_ip_transports.then(|| QuicConfig {
ep: endpoint.clone(),
client_config,
ipv4: true,
ipv6: has_ipv6_transport,
});
net_report_config.quic_config(qad_config)
};
#[cfg(any(test, feature = "test-utils"))]
let net_report_config =
net_report_config.insecure_skip_relay_cert_verify(insecure_skip_relay_cert_verify);
let net_reporter = net_report::Client::new(
#[cfg(not(wasm_browser))]
dns_resolver,
relay_map.clone(),
net_report_config,
metrics.net_report.clone(),
);
let (direct_addr_done_tx, direct_addr_done_rx) = mpsc::channel(8);
let direct_addr_update_state = DirectAddrUpdateState::new(
sock.clone(),
#[cfg(not(wasm_browser))]
port_mapper,
Arc::new(AsyncMutex::new(net_reporter)),
relay_map,
direct_addr_done_tx,
sock.shutdown.at_close_start.child_token(),
);
let netmon_watcher = network_monitor.interface_state();
#[cfg_attr(wasm_browser, allow(unused_mut))]
let mut actor = Actor {
sock: sock.clone(),
remote_map,
periodic_re_stun_timer: new_re_stun_timer(false),
network_monitor,
netmon_watcher,
direct_addr_update_state,
network_change_sender,
direct_addr_done_rx,
};
#[cfg(not(wasm_browser))]
actor.update_direct_addresses(None);
let actor_task = task::spawn(
actor
.run(
actor_receiver,
shutdown_token.child_token(),
local_addrs_watch,
)
.instrument(info_span!("actor")),
);
let actor_task = Mutex::new(Some(AbortOnDropHandle::new(actor_task)));
Ok(EndpointInner {
sock,
actor_sender,
actor_task,
endpoint,
runtime,
static_config,
})
}
pub(crate) fn quinn_endpoint(&self) -> &quinn::Endpoint {
&self.endpoint
}
#[instrument(skip_all)]
pub(crate) async fn close(&self) {
if self.sock.is_closed() || self.sock.is_closing() {
return;
}
trace!(me = ?self.public_key, "socket closing...");
self.sock.shutdown.at_close_start.cancel();
self.sock.address_lookup().clear();
self.quinn_endpoint().close(0u16.into(), b"");
trace!("wait_idle start");
self.quinn_endpoint().wait_idle().await;
trace!("wait_idle done");
self.sock.shutdown.at_endpoint_closed.cancel();
let task = self.actor_task.lock().expect("poisoned").take();
if let Some(task) = task {
let shutdown_done = time::timeout(Duration::from_millis(100), async move {
if let Err(err) = task.await {
warn!("unexpected error in task shutdown: {:?}", err);
}
})
.await;
match shutdown_done {
Ok(_) => trace!("tasks finished in time, shutdown complete"),
Err(time::Elapsed { .. }) => {
warn!("tasks didn't finish in time, aborting");
}
}
}
self.runtime.shutdown().await;
self.sock.shutdown.closed.store(true, Ordering::SeqCst);
trace!("socket closed");
}
#[instrument(skip_all)]
pub(crate) fn abort(&self) {
if self.sock.is_closed() || self.sock.is_closing() {
return;
}
trace!(me = ?self.public_key, "aborting socket...");
self.sock.shutdown.at_close_start.cancel();
self.sock.address_lookup().clear();
self.sock.shutdown.at_endpoint_closed.cancel();
self.runtime.abort();
self.actor_task.lock().expect("poisoned").take();
self.sock.shutdown.closed.store(true, Ordering::SeqCst);
trace!("socket closed");
}
pub(crate) async fn insert_relay(
&self,
relay: RelayUrl,
endpoint: Arc<RelayConfig>,
) -> Option<Arc<RelayConfig>> {
let res = self.relay_map.insert(relay, endpoint);
self.actor_sender
.send(ActorMessage::RelayMapChange)
.await
.ok();
res
}
pub(crate) async fn remove_relay(&self, relay: &RelayUrl) -> Option<Arc<RelayConfig>> {
let res = self.relay_map.remove(relay);
self.actor_sender
.send(ActorMessage::RelayMapChange)
.await
.ok();
res
}
pub(crate) async fn network_change(&self) {
self.actor_sender
.send(ActorMessage::NetworkChange)
.await
.ok();
}
#[cfg(test)]
async fn force_network_change(&self, is_major: bool) {
self.actor_sender
.send(ActorMessage::ForceNetworkChange(is_major))
.await
.ok();
}
pub(crate) async fn resolve_remote(
&self,
addr: EndpointAddr,
) -> Result<Result<EndpointIdMappedAddr, AddressLookupError>, RemoteStateActorStoppedError>
{
let (tx, rx) = oneshot::channel();
self.actor_sender
.send(ActorMessage::ResolveRemote(addr, tx))
.await
.ok();
rx.await.map_err(|_| RemoteStateActorStoppedError::new())?
}
pub(crate) async fn remote_info(&self, id: EndpointId) -> Option<RemoteInfo> {
let (tx, rx) = oneshot::channel();
self.actor_sender
.send(ActorMessage::RemoteInfo(id, tx))
.await
.ok()?;
rx.await.ok()
}
pub(crate) fn register_connection(
&self,
remote: EndpointId,
conn: WeakConnectionHandle,
) -> impl Future<Output = Result<PathWatchable, RemoteStateActorStoppedError>> + Send + 'static
{
let (tx, rx) = oneshot::channel();
let sender = self.actor_sender.clone();
async move {
sender
.send(ActorMessage::AddConnection(remote, conn, tx))
.await
.map_err(|_| RemoteStateActorStoppedError::new())?;
rx.await.map_err(|_| RemoteStateActorStoppedError::new())
}
}
}
fn default_quic_client_config() -> rustls::ClientConfig {
let root_store =
rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
rustls::client::ClientConfig::builder_with_provider(Arc::new(
rustls::crypto::ring::default_provider(),
))
.with_safe_default_protocol_versions()
.expect("ring supports these")
.with_root_certificates(root_store)
.with_no_client_auth()
}
#[derive(derive_more::Debug)]
#[allow(clippy::enum_variant_names)]
enum ActorMessage {
NetworkChange,
RelayMapChange,
#[debug("ResolveRemote(..)")]
ResolveRemote(
EndpointAddr,
oneshot::Sender<
Result<Result<EndpointIdMappedAddr, AddressLookupError>, RemoteStateActorStoppedError>,
>,
),
#[debug("AddConnection(..)")]
AddConnection(
EndpointId,
WeakConnectionHandle,
oneshot::Sender<PathWatchable>,
),
#[debug("RemoteInfo(..)")]
RemoteInfo(EndpointId, oneshot::Sender<RemoteInfo>),
#[cfg(test)]
ForceNetworkChange(bool),
}
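/// The main socket actor: reacts to actor messages, link and portmap changes,
/// net_report results, and periodic re-STUN ticks.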
struct Actor {
sock: Arc<Socket>,
remote_map: RemoteMap,
periodic_re_stun_timer: time::Interval,
network_monitor: netmon::Monitor,
netmon_watcher: n0_watcher::Direct<netmon::State>,
network_change_sender: transports::NetworkChangeSender,
direct_addr_update_state: DirectAddrUpdateState,
direct_addr_done_rx: mpsc::Receiver<()>,
}
impl Actor {
async fn run(
mut self,
mut msg_receiver: mpsc::Receiver<ActorMessage>,
shutdown_token: CancellationToken,
mut local_addrs_watcher: impl Watcher<Value = Vec<transports::Addr>> + Send + Sync,
) {
let mut current_netmon_state = self.netmon_watcher.get();
#[cfg(not(wasm_browser))]
let mut portmap_watcher = self
.direct_addr_update_state
.port_mapper
.watch_external_address();
let mut receiver_closed = false;
#[cfg_attr(wasm_browser, allow(unused_mut))]
let mut portmap_watcher_closed = false;
let mut net_report_watcher = self.sock.net_report.watch();
self.sock.publish_my_addr();
while !shutdown_token.is_cancelled() {
self.sock.metrics.socket.actor_tick_main.inc();
#[cfg(not(wasm_browser))]
let portmap_watcher_changed = portmap_watcher.changed();
#[cfg(wasm_browser)]
let portmap_watcher_changed = n0_future::future::pending::<()>();
tokio::select! {
_ = shutdown_token.cancelled() => {
debug!("tick: shutting down");
return;
}
msg = msg_receiver.recv(), if !receiver_closed => {
let Some(msg) = msg else {
trace!("tick: socket receiver closed");
self.sock.metrics.socket.actor_tick_other.inc();
receiver_closed = true;
continue;
};
trace!(?msg, "tick: msg");
self.sock.metrics.socket.actor_tick_msg.inc();
self.handle_actor_message(msg).await;
}
tick = self.periodic_re_stun_timer.tick() => {
trace!("tick: re_stun {:?}", tick);
self.sock.metrics.socket.actor_tick_re_stun.inc();
self.re_stun(UpdateReason::Periodic);
}
new_addr = local_addrs_watcher.updated() => {
match new_addr {
Ok(addrs) => {
if !addrs.is_empty() {
trace!(?addrs, "local addrs");
self.sock.publish_my_addr();
}
}
Err(_) => {
warn!("local addr watcher stopped");
}
}
}
report = net_report_watcher.updated() => {
match report {
Ok((report, _)) => {
self.handle_net_report_report(report);
#[cfg(not(wasm_browser))]
{
self.periodic_re_stun_timer = new_re_stun_timer(true);
}
}
Err(_) => {
warn!("net report watcher stopped");
}
}
}
reason = self.direct_addr_done_rx.recv() => {
match reason {
Some(()) => {
let state = self.netmon_watcher.get();
self.direct_addr_update_state.try_run(state.into());
}
None => {
warn!("direct addr watcher died");
}
}
}
change = portmap_watcher_changed, if !portmap_watcher_closed => {
#[cfg(not(wasm_browser))]
{
if change.is_err() {
trace!("tick: portmap watcher closed");
self.sock.metrics.socket.actor_tick_other.inc();
portmap_watcher_closed = true;
continue;
}
trace!("tick: portmap changed");
self.sock.metrics.socket.actor_tick_portmap_changed.inc();
let new_external_address = *portmap_watcher.borrow();
debug!("external address updated: {new_external_address:?}");
self.re_stun(UpdateReason::PortmapUpdated);
}
#[cfg(wasm_browser)]
let _unused_in_browsers = change;
},
state = self.netmon_watcher.updated() => {
let Ok(state) = state else {
trace!("tick: link change receiver closed");
self.sock.metrics.socket.actor_tick_other.inc();
continue;
};
let is_major = state.is_major_change(&current_netmon_state);
event!(
target: "iroh::_events::link_change",
Level::DEBUG,
?state,
is_major
);
current_netmon_state = state;
self.sock.metrics.socket.actor_link_change.inc();
self.handle_network_change(is_major).await;
}
eid = poll_fn(|cx| self.remote_map.poll_cleanup(cx)) => {
trace!(%eid, "cleaned up RemoteStateActor");
}
else => {
trace!("tick: else");
}
}
}
}
async fn handle_network_change(&mut self, is_major: bool) {
debug!(is_major, "link change detected");
if is_major {
if let Err(err) = self.network_change_sender.rebind() {
warn!("failed to rebind transports: {err:?}");
}
#[cfg(not(wasm_browser))]
self.sock.dns_resolver.reset().await;
self.re_stun(UpdateReason::LinkChangeMajor);
} else {
self.re_stun(UpdateReason::LinkChangeMinor);
}
self.remote_map.on_network_change(is_major);
}
fn handle_relay_map_change(&mut self) {
self.re_stun(UpdateReason::RelayMapChange);
}
fn re_stun(&mut self, why: UpdateReason) {
let state = self.netmon_watcher.get();
self.direct_addr_update_state
.schedule_run(why, state.into());
}
async fn handle_actor_message(&mut self, msg: ActorMessage) {
match msg {
ActorMessage::NetworkChange => {
self.network_monitor.network_change().await.ok();
}
ActorMessage::RelayMapChange => {
self.handle_relay_map_change();
}
ActorMessage::ResolveRemote(addr, tx) => {
tx.send(self.remote_map.resolve_remote(addr).await).ok();
}
ActorMessage::RemoteInfo(id, tx) => {
if let Some(info) = self.remote_map.remote_info(id).await {
tx.send(info).ok();
}
}
ActorMessage::AddConnection(remote, conn, tx) => {
if let Some(watcher) = self.remote_map.add_connection(remote, conn).await {
tx.send(watcher).ok();
}
}
#[cfg(test)]
ActorMessage::ForceNetworkChange(is_major) => {
self.handle_network_change(is_major).await;
}
}
}
#[cfg(not(wasm_browser))]
fn update_direct_addresses(&mut self, net_report_report: Option<&net_report::Report>) {
let portmap_watcher = self
.direct_addr_update_state
.port_mapper
.watch_external_address();
let mut addrs: BTreeMap<SocketAddr, DirectAddrType> = BTreeMap::new();
let maybe_port_mapped = *portmap_watcher.borrow();
if let Some(portmap_ext) = maybe_port_mapped.map(SocketAddr::V4) {
addrs
.entry(portmap_ext)
.or_insert(DirectAddrType::Portmapped);
}
if let Some(net_report_report) = net_report_report {
if let Some(global_v4) = net_report_report.global_v4 {
addrs.entry(global_v4.into()).or_insert(DirectAddrType::Qad);
let port = self.sock.ip_bind_addrs().iter().find_map(|addr| {
if addr.port() != 0 {
Some(addr.port())
} else {
None
}
});
if let Some(port) = port
&& net_report_report
.mapping_varies_by_dest()
.unwrap_or_default()
{
let mut addr = global_v4;
addr.set_port(port);
addrs
.entry(addr.into())
.or_insert(DirectAddrType::Qad4LocalPort);
}
}
if let Some(global_v6) = net_report_report.global_v6 {
addrs.entry(global_v6.into()).or_insert(DirectAddrType::Qad);
}
}
self.collect_local_addresses(&mut addrs);
self.sock.store_direct_addresses(
addrs
.iter()
.map(|(addr, typ)| DirectAddr {
addr: *addr,
typ: *typ,
})
.collect(),
);
}
#[cfg(not(wasm_browser))]
fn collect_local_addresses(&mut self, addrs: &mut BTreeMap<SocketAddr, DirectAddrType>) {
let local_addrs: Vec<(SocketAddr, SocketAddr)> = self
.sock
.ip_bind_addrs()
.iter()
.copied()
.zip(self.sock.ip_local_addrs())
.collect();
let ipv4_unspecified_port = local_addrs.iter().find_map(|(_, a)| {
if a.is_ipv4() && a.ip().is_unspecified() {
Some(a.port())
} else {
None
}
});
let ipv6_unspecified_port = local_addrs.iter().find_map(|(_, a)| {
if a.is_ipv6() && a.ip().is_unspecified() {
Some(a.port())
} else {
None
}
});
if local_addrs
.iter()
.any(|(_, local)| local.ip().is_unspecified())
{
let LocalAddresses {
regular: mut ips,
loopback,
} = self.netmon_watcher.get().local_addresses;
if ips.is_empty() && addrs.is_empty() {
ips = loopback;
}
for ip in ips {
let port_if_unspecified = match ip {
IpAddr::V4(_) => ipv4_unspecified_port,
IpAddr::V6(_) => ipv6_unspecified_port,
};
if let Some(port) = port_if_unspecified {
let addr = SocketAddr::new(ip, port);
addrs.entry(addr).or_insert(DirectAddrType::Local);
}
}
}
for (bound, local) in local_addrs {
if !bound.ip().is_unspecified() {
addrs.entry(local).or_insert(DirectAddrType::Local);
}
}
}
fn handle_net_report_report(&mut self, mut report: Option<net_report::Report>) {
if let Some(ref mut r) = report {
self.sock.ipv6_reported.store(r.udp_v6, Ordering::Relaxed);
if r.preferred_relay.is_none()
&& let Some(my_relay) = self.sock.my_relay()
{
r.preferred_relay.replace(my_relay);
}
self.network_change_sender.on_network_change(r);
}
#[cfg(not(wasm_browser))]
self.update_direct_addresses(report.as_ref());
}
}
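/// Creates the periodic re-STUN timer with a randomized period between 20s and
/// 26s, optionally delaying the first tick by a full period.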
fn new_re_stun_timer(initial_delay: bool) -> time::Interval {
let mut rng = rand::rng();
let d: Duration = rng.random_range(Duration::from_secs(20)..=Duration::from_secs(26));
if initial_delay {
debug!("scheduling periodic_stun to run in {}s", d.as_secs());
time::interval_at(time::Instant::now() + d, d)
} else {
debug!(
"scheduling periodic_stun to run immediately and in {}s",
d.as_secs()
);
time::interval(d)
}
}
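/// The direct addresses discovered for this endpoint so far, with the time of
/// the last update.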
#[derive(derive_more::Debug, Clone, Default)]
struct DiscoveredDirectAddrs {
addrs: Watchable<BTreeSet<DirectAddr>>,
updated_at: Arc<RwLock<Option<Instant>>>,
}
impl DiscoveredDirectAddrs {
fn update(&self, addrs: BTreeSet<DirectAddr>) -> bool {
*self.updated_at.write().expect("poisoned") = Some(Instant::now());
let updated = self.addrs.set(addrs).is_ok();
if updated {
event!(
target: "iroh::_events::direct_addrs",
Level::DEBUG,
addrs = ?self.addrs.get(),
);
}
updated
}
fn sockaddrs(&self) -> impl Iterator<Item = SocketAddr> {
self.addrs.get().into_iter().map(|da| da.addr)
}
}
/// A socket address on which this endpoint might be directly reachable, together
/// with the source it was discovered from.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct DirectAddr {
/// The socket address.
pub addr: SocketAddr,
/// How the address was discovered.
pub typ: DirectAddrType,
}
/// The source a [`DirectAddr`] was discovered from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum DirectAddrType {
/// The source of the address is not known.
Unknown,
/// An address of a locally bound socket.
Local,
/// A globally reachable address discovered via a net_report run (QUIC address discovery).
Qad,
/// An external address obtained from the port mapper.
Portmapped,
/// The QAD-discovered global IPv4 address combined with the locally bound port,
/// used when the NAT mapping varies by destination.
Qad4LocalPort,
}
impl Display for DirectAddrType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DirectAddrType::Unknown => write!(f, "?"),
DirectAddrType::Local => write!(f, "local"),
DirectAddrType::Qad => write!(f, "qad"),
DirectAddrType::Portmapped => write!(f, "portmap"),
DirectAddrType::Qad4LocalPort => write!(f, "qad4localport"),
}
}
}
#[cfg(test)]
mod tests {
use std::{net::SocketAddrV4, sync::Arc, time::Duration};
use data_encoding::HEXLOWER;
use iroh_base::{EndpointAddr, EndpointId, TransportAddr};
use n0_error::{Result, StackResultExt, StdResultExt};
use n0_future::{MergeBounded, StreamExt, time};
use n0_tracing_test::traced_test;
use n0_watcher::Watcher;
use rand::{CryptoRng, Rng, RngCore, SeedableRng};
use tokio_util::task::AbortOnDropHandle;
use tracing::{Instrument, error, info, info_span, instrument};
use super::Options;
use crate::{
Endpoint, RelayMode, SecretKey,
address_lookup::memory::MemoryLookup,
dns::DnsResolver,
endpoint::QuicTransportConfig,
socket::{
EndpointInner, Socket, StaticConfig, TransportConfig,
mapped_addrs::{EndpointIdMappedAddr, MappedAddr},
},
tls::{self, DEFAULT_MAX_TLS_TICKETS},
};
const ALPN: &[u8] = b"n0/test/1";
fn default_options<R: CryptoRng + ?Sized>(rng: &mut R) -> Options {
let secret_key = SecretKey::generate(rng);
let static_config = StaticConfig {
tls_config: tls::TlsConfig::new(secret_key.clone(), DEFAULT_MAX_TLS_TICKETS),
transport_config: QuicTransportConfig::default(),
keylog: false,
};
let server_config = static_config.create_server_config(vec![]);
Options {
transports: vec![
TransportConfig::default_ipv4(),
TransportConfig::default_ipv6(),
],
secret_key,
proxy_url: None,
dns_resolver: DnsResolver::new(),
server_config,
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify: false,
#[cfg(any(test, feature = "test-utils"))]
address_lookup_user_data: None,
metrics: Default::default(),
hooks: Default::default(),
static_config,
}
}
#[instrument(skip_all, fields(me = %ep.id().fmt_short()))]
async fn echo_receiver(ep: Endpoint, loss: ExpectedLoss) -> Result {
info!("accepting conn");
let conn = ep.accept().await.expect("no conn");
info!("accepting");
let conn = conn.await.context("accepting")?;
info!("accepting bi");
let (mut send_bi, mut recv_bi) = conn.accept_bi().await.std_context("accept bi")?;
info!("reading");
let val = recv_bi
.read_to_end(usize::MAX)
.await
.std_context("read to end")?;
info!("replying");
for chunk in val.chunks(12) {
send_bi.write_all(chunk).await.std_context("write all")?;
}
info!("finishing");
send_bi.finish().std_context("finish")?;
send_bi.stopped().await.std_context("stopped")?;
let stats = conn.stats();
info!("stats: {:#?}", stats);
if matches!(loss, ExpectedLoss::AlmostNone) {
for info in conn.paths().get().iter() {
assert!(
info.stats().unwrap().lost_packets < 10,
"[receiver] path {:?} should not loose many packets",
info.remote_addr()
);
}
}
conn.closed().await;
info!("closed");
ep.inner()?.quinn_endpoint().wait_idle().await;
info!("idle");
Ok(())
}
#[instrument(skip_all, fields(me = %ep.id().fmt_short()))]
async fn echo_sender(
ep: Endpoint,
dest_id: EndpointId,
msg: &[u8],
loss: ExpectedLoss,
) -> Result {
info!("connecting to {}", dest_id.fmt_short());
let dest = EndpointAddr::new(dest_id);
let conn = ep.connect(dest, ALPN).await?;
info!("opening bi");
let (mut send_bi, mut recv_bi) = conn.open_bi().await.std_context("open bi")?;
info!("writing message");
send_bi.write_all(msg).await.std_context("write all")?;
info!("finishing");
send_bi.finish().std_context("finish")?;
send_bi.stopped().await.std_context("stopped")?;
info!("reading_to_end");
let val = recv_bi
.read_to_end(usize::MAX)
.await
.std_context("read to end")?;
assert_eq!(
val,
msg,
"[sender] expected {}, got {}",
HEXLOWER.encode(msg),
HEXLOWER.encode(&val)
);
let stats = conn.stats();
info!("stats: {:#?}", stats);
if matches!(loss, ExpectedLoss::AlmostNone) {
for info in conn.paths().get().iter() {
assert!(
info.stats().unwrap().lost_packets < 10,
"[sender] path {:?} should not loose many packets",
info.remote_addr()
);
}
}
conn.close(0u32.into(), b"done");
info!("closed");
ep.inner()?.quinn_endpoint().wait_idle().await;
info!("idle");
Ok(())
}
#[derive(Debug, Copy, Clone)]
enum ExpectedLoss {
AlmostNone,
YeahSure,
}
async fn run_roundtrip(
sender: Endpoint,
receiver: Endpoint,
payload: &[u8],
loss: ExpectedLoss,
) -> Result<()> {
tokio::time::timeout(Duration::from_secs(4), async move {
let send_endpoint_id = sender.id();
let recv_endpoint_id = receiver.id();
info!("\nroundtrip: {send_endpoint_id:#} -> {recv_endpoint_id:#}");
let receiver_task = AbortOnDropHandle::new(tokio::spawn(echo_receiver(receiver, loss)));
let sender_res = echo_sender(sender, recv_endpoint_id, payload, loss).await;
let sender_is_err = match sender_res {
Ok(()) => false,
Err(err) => {
error!("[sender] Error:\n{err:#?}");
true
}
};
let receiver_is_err = match receiver_task.await {
Ok(Ok(())) => false,
Ok(Err(err)) => {
error!("[receiver] Error:\n{err:#?}");
true
}
Err(joinerr) => {
if joinerr.is_panic() {
std::panic::resume_unwind(joinerr.into_panic());
} else {
error!("[receiver] Error:\n{joinerr:#?}");
}
true
}
};
if sender_is_err || receiver_is_err {
panic!("Sender or receiver errored");
}
})
.await
.std_context("timeout")?;
Ok(())
}
async fn endpoint_pair() -> (AbortOnDropHandle<()>, Endpoint, Endpoint) {
let address_lookup = MemoryLookup::new();
let ep1 = Endpoint::empty_builder(RelayMode::Disabled)
.alpns(vec![ALPN.to_vec()])
.address_lookup(address_lookup.clone())
.bind()
.await
.unwrap();
let ep2 = Endpoint::empty_builder(RelayMode::Disabled)
.alpns(vec![ALPN.to_vec()])
.address_lookup(address_lookup.clone())
.bind()
.await
.unwrap();
address_lookup.add_endpoint_info(ep1.addr());
address_lookup.add_endpoint_info(ep2.addr());
let ep1_addr_stream = ep1.watch_addr().stream();
let ep2_addr_stream = ep2.watch_addr().stream();
let mut addr_stream = MergeBounded::from_iter([ep1_addr_stream, ep2_addr_stream]);
let task = tokio::spawn(async move {
while let Some(addr) = addr_stream.next().await {
address_lookup.add_endpoint_info(addr);
}
});
(AbortOnDropHandle::new(task), ep1, ep2)
}
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_roundtrip_quinn_small() -> Result {
let (_guard, m1, m2) = endpoint_pair().await;
run_roundtrip(
m1.clone(),
m2.clone(),
b"hello m1",
ExpectedLoss::AlmostNone,
)
.await?;
run_roundtrip(
m2.clone(),
m1.clone(),
b"hello m2",
ExpectedLoss::AlmostNone,
)
.await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_roundtrip_quinn_large() -> Result {
let (_guard, m1, m2) = endpoint_pair().await;
let mut data = vec![0u8; 10 * 1024];
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
rng.fill_bytes(&mut data);
run_roundtrip(m1.clone(), m2.clone(), &data, ExpectedLoss::AlmostNone).await?;
run_roundtrip(m2.clone(), m1.clone(), &data, ExpectedLoss::AlmostNone).await?;
Ok(())
}
#[tokio::test]
#[traced_test]
async fn test_regression_network_change_rebind_wakes_connection_driver() -> Result {
let (_guard, m1, m2) = endpoint_pair().await;
println!("Net change");
m1.inner()?.force_network_change(true).await;
tokio::time::sleep(Duration::from_secs(1)).await;
let _handle = AbortOnDropHandle::new(tokio::spawn({
let endpoint = m2.clone();
async move {
while let Some(incoming) = endpoint.accept().await {
println!("Incoming first conn!");
let conn = incoming.await.anyerr()?;
conn.closed().await;
}
n0_error::Ok(())
}
}));
println!("first conn!");
let conn = m1.connect(m2.addr(), ALPN).await?;
println!("Closing first conn");
conn.close(0u32.into(), b"bye lolz");
conn.closed().await;
println!("Closed first conn");
Ok(())
}
fn offset(rng: &mut rand_chacha::ChaCha8Rng) -> Duration {
let delay = rng.random_range(1..=5);
Duration::from_millis(delay * 50)
}
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_roundtrip_network_change_only_a() -> Result {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let (_guard, m1, m2) = endpoint_pair().await;
let _network_change_guard = {
let m1 = m1.clone();
let mut rng = rng.clone();
let task = tokio::spawn(async move {
loop {
info!("[m1] network change");
m1.inner()
.expect("haven't closed the endpoint yet")
.force_network_change(true)
.await;
time::sleep(offset(&mut rng)).await;
}
});
AbortOnDropHandle::new(task)
};
let mut data = vec![0u8; 10 * 1024];
rng.fill_bytes(&mut data);
run_roundtrip(m1.clone(), m2.clone(), &data, ExpectedLoss::YeahSure).await?;
run_roundtrip(m2.clone(), m1.clone(), &data, ExpectedLoss::YeahSure).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_roundtrip_network_change_a_and_b() -> Result {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let (_guard, m1, m2) = endpoint_pair().await;
let _network_change_guard = {
let m1 = m1.clone();
let m2 = m2.clone();
let mut rng = rng.clone();
let task = tokio::spawn(async move {
loop {
info!("-- [m1] network change");
m1.inner()
.expect("haven't closed the endpoint yet")
.force_network_change(true)
.await;
info!("-- [m2] network change");
m2.inner()
.expect("haven't closed the endpoint yet")
.force_network_change(true)
.await;
time::sleep(offset(&mut rng)).await;
}
});
AbortOnDropHandle::new(task)
};
let mut data = vec![0u8; 10 * 1024];
rng.fill_bytes(&mut data);
run_roundtrip(m1.clone(), m2.clone(), &data, ExpectedLoss::YeahSure).await?;
run_roundtrip(m2.clone(), m1.clone(), &data, ExpectedLoss::YeahSure).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_setup_teardown() -> Result {
for i in 0..10 {
info!("-- round {i}");
info!("setting up stack");
let (_guard, m1, m2) = endpoint_pair().await;
info!("closing endpoints");
let sock1 = m1.inner()?;
let sock2 = m2.inner()?;
m1.close().await;
m2.close().await;
assert!(sock1.is_closed());
assert!(sock2.is_closed());
}
Ok(())
}
#[tokio::test]
#[traced_test]
async fn test_direct_addresses() {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let sock = EndpointInner::new(default_options(&mut rng)).await.unwrap();
let eps0 = sock.ip_addrs().get();
info!("{eps0:?}");
assert!(!eps0.is_empty());
let eps1 = sock.ip_addrs().get();
info!("{eps1:?}");
assert_eq!(eps0, eps1);
}
#[instrument(name = "ep", skip_all, fields(me = %secret_key.public().fmt_short()))]
async fn socket_ep(secret_key: SecretKey) -> Result<EndpointInner> {
let static_config = StaticConfig {
tls_config: tls::TlsConfig::new(secret_key.clone(), DEFAULT_MAX_TLS_TICKETS),
transport_config: QuicTransportConfig::default(),
keylog: true,
};
let server_config = static_config.create_server_config(vec![ALPN.to_vec()]);
let dns_resolver = DnsResolver::new();
let opts = Options {
transports: vec![
TransportConfig::default_ipv4(),
TransportConfig::default_ipv6(),
],
secret_key: secret_key.clone(),
address_lookup_user_data: None,
dns_resolver,
proxy_url: None,
server_config,
insecure_skip_relay_cert_verify: false,
metrics: Default::default(),
hooks: Default::default(),
static_config,
};
let sock = Socket::spawn(opts).await?;
Ok(sock)
}
#[instrument(name = "connect", skip_all, fields(me = %ep_secret_key.public().fmt_short()))]
async fn socket_connect(
ep: quinn::Endpoint,
ep_secret_key: SecretKey,
addr: EndpointIdMappedAddr,
endpoint_id: EndpointId,
) -> Result<quinn::Connection> {
let mut transport_config = quinn::TransportConfig::default();
transport_config.keep_alive_interval(Some(Duration::from_secs(1)));
socket_connect_with_transport_config(
ep,
ep_secret_key,
addr,
endpoint_id,
Arc::new(transport_config),
)
.await
}
#[instrument(name = "connect", skip_all, fields(me = %ep_secret_key.public().fmt_short()))]
async fn socket_connect_with_transport_config(
ep: quinn::Endpoint,
ep_secret_key: SecretKey,
mapped_addr: EndpointIdMappedAddr,
endpoint_id: EndpointId,
transport_config: Arc<quinn::TransportConfig>,
) -> Result<quinn::Connection> {
let alpns = vec![ALPN.to_vec()];
let quic_client_config =
tls::TlsConfig::new(ep_secret_key.clone(), DEFAULT_MAX_TLS_TICKETS)
.make_client_config(alpns, true);
let mut client_config = quinn::ClientConfig::new(Arc::new(quic_client_config));
client_config.transport_config(transport_config);
let connect = ep
.connect_with(
client_config,
mapped_addr.private_socket_addr(),
&tls::name::encode(endpoint_id),
)
.std_context("connect")?;
let connection = connect.await.anyerr()?;
Ok(connection)
}
#[tokio::test]
#[traced_test]
async fn test_try_send_no_send_addr() {
let secret_key_1 = SecretKey::from_bytes(&[1u8; 32]);
let secret_key_2 = SecretKey::from_bytes(&[2u8; 32]);
let endpoint_id_2 = secret_key_2.public();
let secret_key_missing_endpoint = SecretKey::from_bytes(&[255u8; 32]);
let endpoint_id_missing_endpoint = secret_key_missing_endpoint.public();
let sock_1 = socket_ep(secret_key_1.clone()).await.unwrap();
let bad_addr = EndpointIdMappedAddr::generate();
let res = tokio::time::timeout(
Duration::from_millis(500),
socket_connect(
sock_1.quinn_endpoint().clone(),
secret_key_1.clone(),
bad_addr,
endpoint_id_missing_endpoint,
),
)
.await;
assert!(res.is_err(), "expecting timeout");
let sock_2 = socket_ep(secret_key_2.clone()).await.unwrap();
let accept_task = tokio::spawn({
async fn accept(ep: quinn::Endpoint) -> Result<()> {
let incoming = ep.accept().await.std_context("no incoming")?;
let _conn = incoming
.accept()
.std_context("accept")?
.await
.std_context("accepting")?;
tokio::time::sleep(Duration::from_secs(10)).await;
info!("accept finished");
Ok(())
}
let ep = sock_2.quinn_endpoint().clone();
async move {
if let Err(err) = accept(ep).await {
error!("{err:#}");
}
}
.instrument(info_span!("ep2.accept, me = endpoint_id_2.fmt_short()"))
});
let _accept_task = AbortOnDropHandle::new(accept_task);
let addrs = sock_2
.ip_addrs()
.get()
.into_iter()
.map(|x| TransportAddr::Ip(x.addr));
let endpoint_addr_2 = EndpointAddr::from_parts(endpoint_id_2, addrs);
let addr = sock_1
.resolve_remote(endpoint_addr_2)
.await
.unwrap()
.unwrap();
let res = tokio::time::timeout(
Duration::from_secs(10),
socket_connect(
sock_1.quinn_endpoint().clone(),
secret_key_1.clone(),
addr,
endpoint_id_2,
),
)
.await
.expect("timeout while connecting");
res.unwrap();
}
#[tokio::test]
#[traced_test]
async fn test_try_send_no_udp_addr_or_relay_url() {
let secret_key_1 = SecretKey::from_bytes(&[1u8; 32]);
let secret_key_2 = SecretKey::from_bytes(&[2u8; 32]);
let endpoint_id_2 = secret_key_2.public();
let sock_1 = socket_ep(secret_key_1.clone()).await.unwrap();
let sock_2 = socket_ep(secret_key_2.clone()).await.unwrap();
let ep_2 = sock_2.quinn_endpoint().clone();
let accept_task = tokio::spawn({
async fn accept(ep: quinn::Endpoint) -> Result<()> {
let incoming = ep.accept().await.std_context("no incoming")?;
let conn = incoming
.accept()
.std_context("accept")?
.await
.std_context("connecting")?;
let mut stream = conn.accept_uni().await.std_context("accept uni")?;
stream
.read_to_end(1 << 16)
.await
.std_context("read to end")?;
info!("accept finished");
Ok(())
}
async move {
if let Err(err) = accept(ep_2).await {
error!("{err:#}");
}
}
.instrument(info_span!("ep2.accept", me = %endpoint_id_2.fmt_short()))
});
let _accept_task = AbortOnDropHandle::new(accept_task);
// 192.0.2.0/24 (TEST-NET-1, RFC 5737) is reserved, so this address is unreachable.
let unreachable_addr_2 = EndpointAddr::from_parts(
endpoint_id_2,
[TransportAddr::Ip(
SocketAddrV4::new([192, 0, 2, 1].into(), 12345).into(),
)],
);
let addr_2 = sock_1
.resolve_remote(unreachable_addr_2)
.await
.unwrap()
.unwrap();
let mut transport_config = quinn::TransportConfig::default();
transport_config.max_idle_timeout(Some(Duration::from_millis(200).try_into().unwrap()));
let res = socket_connect_with_transport_config(
sock_1.quinn_endpoint().clone(),
secret_key_1.clone(),
addr_2,
endpoint_id_2,
Arc::new(transport_config),
)
.await;
assert!(res.is_err(), "expected timeout");
info!("first connect timed out as expected");
let correct_addr_2 = EndpointAddr::from_parts(
endpoint_id_2,
sock_2
.ip_addrs()
.get()
.into_iter()
.map(|x| TransportAddr::Ip(x.addr)),
);
let addr_2a = sock_1
.resolve_remote(correct_addr_2)
.await
.unwrap()
.unwrap();
assert_eq!(addr_2, addr_2a);
tokio::time::timeout(Duration::from_secs(10), async move {
info!("establishing new connection");
let conn = socket_connect(
sock_1.quinn_endpoint().clone(),
secret_key_1.clone(),
addr_2,
endpoint_id_2,
)
.await
.unwrap();
info!("have connection");
let mut stream = conn.open_uni().await.unwrap();
stream.write_all(b"hello").await.unwrap();
stream.finish().unwrap();
stream.stopped().await.unwrap();
info!("finished stream");
})
.await
.expect("connection timed out");
}
}