use std::{
fmt,
future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
};
use hickory_resolver::{
TokioResolver,
config::{ResolverConfig, ResolverOpts},
name_server::TokioConnectionProvider,
};
use iroh_base::EndpointId;
use n0_error::{StackError, e, stack_error};
use n0_future::{
StreamExt,
boxed::BoxFuture,
time::{self, Duration},
};
use tokio::sync::RwLock;
use tracing::debug;
use url::Url;
use crate::{
defaults::timeouts::DNS_TIMEOUT,
endpoint_info::{self, EndpointInfo, ParseError},
};
/// Origin domain of n0's production DNS server for endpoint records.
pub const N0_DNS_ENDPOINT_ORIGIN_PROD: &str = "dns.iroh.link";
/// Origin domain of n0's staging DNS server for endpoint records.
pub const N0_DNS_ENDPOINT_ORIGIN_STAGING: &str = "staging-dns.iroh.link";
/// Maximum jitter applied by `add_jitter`, as a percentage of the delay in each direction:
/// a jittered delay lies within roughly ±`MAX_JITTER_PERCENT`% of the requested delay.
const MAX_JITTER_PERCENT: u64 = 20;
/// A DNS resolver backend that can be plugged into [`DnsResolver`] via [`DnsResolver::custom`].
///
/// Each lookup returns a boxed future that yields a boxed iterator over the results.
/// [`DnsResolver`] wraps every lookup call in its own timeout.
pub trait Resolver: fmt::Debug + Send + Sync + 'static {
    /// Looks up the IPv4 addresses of `host`.
    fn lookup_ipv4(&self, host: String) -> BoxFuture<Result<BoxIter<Ipv4Addr>, DnsError>>;
    /// Looks up the IPv6 addresses of `host`.
    fn lookup_ipv6(&self, host: String) -> BoxFuture<Result<BoxIter<Ipv6Addr>, DnsError>>;
    /// Looks up the TXT records of `host`.
    fn lookup_txt(&self, host: String) -> BoxFuture<Result<BoxIter<TxtRecordData>, DnsError>>;
    /// Clears any cached lookup results.
    ///
    /// Called by [`DnsResolver::clear_cache`] while holding a shared (read) lock.
    fn clear_cache(&self);
    /// Resets the resolver's state.
    ///
    /// Called by [`DnsResolver::reset`] while holding an exclusive (write) lock.
    fn reset(&mut self);
}
/// A boxed, `Send` iterator, as returned by [`Resolver`] lookups.
pub type BoxIter<T> = Box<dyn Iterator<Item = T> + Send + 'static>;
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources, std_sources)]
#[non_exhaustive]
pub enum DnsError {
#[error(transparent)]
Timeout { source: tokio::time::error::Elapsed },
#[error("No response")]
NoResponse {},
#[error("Resolve failed ipv4: {ipv4}, ipv6 {ipv6}")]
ResolveBoth {
ipv4: Box<DnsError>,
ipv6: Box<DnsError>,
},
#[error("missing host")]
MissingHost {},
#[error(transparent)]
Resolve {
source: hickory_resolver::ResolveError,
},
#[error("invalid DNS response: not a query for _iroh.z32encodedpubkey")]
InvalidResponse {},
}
/// Errors from looking up an [`EndpointInfo`] through DNS TXT records.
#[cfg(not(wasm_browser))]
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources)]
#[non_exhaustive]
pub enum LookupError {
    /// The TXT records were retrieved but could not be parsed into an [`EndpointInfo`].
    #[error("Malformed txt from lookup")]
    ParseError { source: ParseError },
    /// The TXT lookup itself failed (including timeouts).
    #[error("Failed to resolve TXT record")]
    LookupFailed { source: DnsError },
}
#[stack_error(derive, add_meta)]
#[error("no calls succeeded: [{}]", errors.iter().map(|e| e.to_string()).collect::<Vec<_>>().join(""))]
pub struct StaggeredError<E: n0_error::StackError + 'static> {
errors: Vec<E>,
}
impl<E: StackError + 'static> StaggeredError<E> {
pub fn iter(&self) -> impl Iterator<Item = &E> {
self.errors.iter()
}
}
/// Builder for a hickory-backed [`DnsResolver`].
///
/// By default, neither the system DNS configuration nor any nameservers are used. See
/// [`Builder::with_system_defaults`] and [`Builder::with_nameserver`].
#[derive(Debug, Clone, Default)]
pub struct Builder {
    /// Whether to start from the operating system's DNS configuration.
    use_system_defaults: bool,
    /// Extra nameservers (address and protocol), added on top of the base configuration.
    nameservers: Vec<(SocketAddr, DnsProtocol)>,
}
/// Transport protocol used to talk to a nameserver.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum DnsProtocol {
    /// DNS over UDP (the default).
    #[default]
    Udp,
    /// DNS over TCP.
    Tcp,
    /// DNS over TLS.
    Tls,
    /// DNS over HTTPS.
    Https,
}
impl DnsProtocol {
    /// Maps this protocol onto hickory's equivalent `Protocol` value.
    fn to_hickory(self) -> hickory_resolver::proto::xfer::Protocol {
        use hickory_resolver::proto::xfer::Protocol as Hickory;
        match self {
            Self::Udp => Hickory::Udp,
            Self::Tcp => Hickory::Tcp,
            Self::Tls => Hickory::Tls,
            Self::Https => Hickory::Https,
        }
    }
}
impl Builder {
    /// Starts from the operating system's DNS configuration.
    ///
    /// If the system configuration cannot be read, Google's public DNS servers are used
    /// instead. Nameservers added with [`Builder::with_nameserver`] are used in addition.
    pub fn with_system_defaults(self) -> Self {
        Self {
            use_system_defaults: true,
            ..self
        }
    }

    /// Adds a single nameserver, reachable at `addr` over `protocol`.
    pub fn with_nameserver(self, addr: SocketAddr, protocol: DnsProtocol) -> Self {
        self.with_nameservers([(addr, protocol)])
    }

    /// Adds several nameservers (address and protocol pairs) at once, in order.
    pub fn with_nameservers(
        mut self,
        nameservers: impl IntoIterator<Item = (SocketAddr, DnsProtocol)>,
    ) -> Self {
        self.nameservers.extend(nameservers);
        self
    }

    /// Builds the [`DnsResolver`] from this configuration.
    pub fn build(self) -> DnsResolver {
        let hickory = Arc::new(RwLock::new(HickoryResolver::new(self)));
        DnsResolver(DnsResolverInner::Hickory(hickory))
    }
}
/// A DNS resolver, backed either by hickory or by a custom [`Resolver`].
///
/// Cloning is cheap; clones share the same underlying resolver.
#[derive(Debug, Clone)]
pub struct DnsResolver(DnsResolverInner);
impl DnsResolver {
    /// Creates a resolver that uses the system's DNS configuration.
    ///
    /// If the system configuration cannot be read, Google's public DNS servers are used instead.
    pub fn new() -> Self {
        Builder::default().with_system_defaults().build()
    }

    /// Creates a resolver that only queries `nameserver`, over UDP.
    pub fn with_nameserver(nameserver: SocketAddr) -> Self {
        Builder::default()
            .with_nameserver(nameserver, DnsProtocol::Udp)
            .build()
    }

    /// Returns a [`Builder`] for configuring a new resolver.
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Creates a resolver that delegates every lookup to a custom [`Resolver`].
    pub fn custom(resolver: impl Resolver) -> Self {
        Self(DnsResolverInner::Custom(Arc::new(RwLock::new(resolver))))
    }

    /// Clears the resolver's cache. This affects all clones of this resolver.
    pub async fn clear_cache(&self) {
        self.0.clear_cache().await
    }

    /// Resets the resolver. This affects all clones of this resolver.
    ///
    /// The built-in resolver is rebuilt from its original configuration. A custom resolver has
    /// [`Resolver::reset`] called on it.
    pub async fn reset(&self) {
        self.0.reset().await
    }

    /// Looks up the TXT records of `host`.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError`] if the lookup fails or does not finish within `timeout`.
    pub async fn lookup_txt<T: ToString>(
        &self,
        host: T,
        timeout: Duration,
    ) -> Result<impl Iterator<Item = TxtRecordData> + use<T>, DnsError> {
        let host = host.to_string();
        let res = time::timeout(timeout, self.0.lookup_txt(host)).await??;
        Ok(res)
    }

    /// Looks up the IPv4 addresses of `host`.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError`] if the lookup fails or does not finish within `timeout`.
    pub async fn lookup_ipv4<T: ToString>(
        &self,
        host: T,
        timeout: Duration,
    ) -> Result<impl Iterator<Item = IpAddr> + use<T>, DnsError> {
        let host = host.to_string();
        let addrs = time::timeout(timeout, self.0.lookup_ipv4(host)).await??;
        Ok(addrs.into_iter().map(IpAddr::V4))
    }

    /// Looks up the IPv6 addresses of `host`.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError`] if the lookup fails or does not finish within `timeout`.
    pub async fn lookup_ipv6<T: ToString>(
        &self,
        host: T,
        timeout: Duration,
    ) -> Result<impl Iterator<Item = IpAddr> + use<T>, DnsError> {
        let host = host.to_string();
        let addrs = time::timeout(timeout, self.0.lookup_ipv6(host)).await??;
        Ok(addrs.into_iter().map(IpAddr::V6))
    }

    /// Looks up the IPv4 and IPv6 addresses of `host` concurrently.
    ///
    /// Returns the addresses of every lookup that succeeded, with IPv4 addresses first when both
    /// succeeded. The error of a single failed lookup is discarded. `timeout` applies to each of
    /// the two lookups separately.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::ResolveBoth`] if both lookups fail.
    pub async fn lookup_ipv4_ipv6<T: ToString>(
        &self,
        host: T,
        timeout: Duration,
    ) -> Result<impl Iterator<Item = IpAddr> + use<T>, DnsError> {
        let host = host.to_string();
        let res = tokio::join!(
            self.lookup_ipv4(host.clone(), timeout),
            self.lookup_ipv6(host, timeout)
        );

        match res {
            (Ok(ipv4), Ok(ipv6)) => Ok(LookupIter::Both(ipv4.chain(ipv6))),
            (Ok(ipv4), Err(_)) => Ok(LookupIter::Ipv4(ipv4)),
            (Err(_), Ok(ipv6)) => Ok(LookupIter::Ipv6(ipv6)),
            (Err(ipv4_err), Err(ipv6_err)) => Err(e!(DnsError::ResolveBoth {
                ipv4: Box::new(ipv4_err),
                ipv6: Box::new(ipv6_err)
            })),
        }
    }

    /// Resolves the host of `url` to a single IP address.
    ///
    /// An IP-literal host is returned directly. For a domain, the IPv4 and IPv6 lookups run
    /// concurrently, and the first address of the preferred family (IPv6 if `prefer_ipv6`,
    /// otherwise IPv4) is returned. If that family has no address, the first address of the
    /// other family is used.
    ///
    /// # Errors
    ///
    /// - [`DnsError::MissingHost`] if `url` has no host.
    /// - [`DnsError::ResolveBoth`] if both lookups fail.
    /// - [`DnsError::NoResponse`] if no address was returned.
    pub async fn resolve_host(
        &self,
        url: &Url,
        prefer_ipv6: bool,
        timeout: Duration,
    ) -> Result<IpAddr, DnsError> {
        let host = url.host().ok_or_else(|| e!(DnsError::MissingHost))?;
        match host {
            url::Host::Domain(domain) => {
                // Need to do an actual lookup.
                let lookup = tokio::join!(
                    self.lookup_ipv4(domain, timeout),
                    self.lookup_ipv6(domain, timeout)
                );
                let (v4, v6) = match lookup {
                    (Err(ipv4_err), Err(ipv6_err)) => {
                        return Err(e!(DnsError::ResolveBoth {
                            ipv4: Box::new(ipv4_err),
                            ipv6: Box::new(ipv6_err)
                        }));
                    }
                    (Err(_), Ok(mut v6)) => (None, v6.next()),
                    (Ok(mut v4), Err(_)) => (v4.next(), None),
                    (Ok(mut v4), Ok(mut v6)) => (v4.next(), v6.next()),
                };
                if prefer_ipv6 {
                    v6.or(v4).ok_or_else(|| e!(DnsError::NoResponse))
                } else {
                    v4.or(v6).ok_or_else(|| e!(DnsError::NoResponse))
                }
            }
            url::Host::Ipv4(ip) => Ok(IpAddr::V4(ip)),
            url::Host::Ipv6(ip) => Ok(IpAddr::V6(ip)),
        }
    }

    /// Looks up the IPv4 addresses of `host` with staggered retries.
    ///
    /// One attempt starts immediately, and one more starts after each delay in `delays_ms`
    /// (jittered). All attempts run concurrently and the first success wins. Each attempt uses
    /// `timeout`.
    ///
    /// # Errors
    ///
    /// Returns a [`StaggeredError`] holding every attempt's error if all attempts fail.
    pub async fn lookup_ipv4_staggered(
        &self,
        host: impl ToString,
        timeout: Duration,
        delays_ms: &[u64],
    ) -> Result<impl Iterator<Item = IpAddr>, StaggeredError<DnsError>> {
        let host = host.to_string();
        let f = || self.lookup_ipv4(host.clone(), timeout);
        stagger_call(f, delays_ms).await
    }

    /// Looks up the IPv6 addresses of `host` with staggered retries.
    ///
    /// Staggering works as in [`DnsResolver::lookup_ipv4_staggered`].
    ///
    /// # Errors
    ///
    /// Returns a [`StaggeredError`] holding every attempt's error if all attempts fail.
    pub async fn lookup_ipv6_staggered(
        &self,
        host: impl ToString,
        timeout: Duration,
        delays_ms: &[u64],
    ) -> Result<impl Iterator<Item = IpAddr>, StaggeredError<DnsError>> {
        let host = host.to_string();
        let f = || self.lookup_ipv6(host.clone(), timeout);
        stagger_call(f, delays_ms).await
    }

    /// Runs [`DnsResolver::lookup_ipv4_ipv6`] with staggered retries.
    ///
    /// Staggering works as in [`DnsResolver::lookup_ipv4_staggered`].
    ///
    /// # Errors
    ///
    /// Returns a [`StaggeredError`] holding every attempt's error if all attempts fail.
    pub async fn lookup_ipv4_ipv6_staggered(
        &self,
        host: impl ToString,
        timeout: Duration,
        delays_ms: &[u64],
    ) -> Result<impl Iterator<Item = IpAddr>, StaggeredError<DnsError>> {
        let host = host.to_string();
        let f = || self.lookup_ipv4_ipv6(host.clone(), timeout);
        stagger_call(f, delays_ms).await
    }

    /// Looks up the [`EndpointInfo`] of `endpoint_id` published under `origin`.
    ///
    /// The queried name is built with `endpoint_info::endpoint_domain` and
    /// `endpoint_info::ensure_iroh_txt_label`. Its TXT records are fetched with
    /// [`DNS_TIMEOUT`] and then parsed.
    ///
    /// # Errors
    ///
    /// Returns [`LookupError`] if the TXT lookup fails or its records cannot be parsed.
    pub async fn lookup_endpoint_by_id(
        &self,
        endpoint_id: &EndpointId,
        origin: &str,
    ) -> Result<EndpointInfo, LookupError> {
        let name = endpoint_info::endpoint_domain(endpoint_id, origin);
        let name = endpoint_info::ensure_iroh_txt_label(name);
        let lookup = self.lookup_txt(name.clone(), DNS_TIMEOUT).await?;
        let info = EndpointInfo::from_txt_lookup(name, lookup)?;
        Ok(info)
    }

    /// Looks up the [`EndpointInfo`] published at the domain `name`.
    ///
    /// `name` is passed through `endpoint_info::ensure_iroh_txt_label`. Its TXT records are
    /// then fetched with [`DNS_TIMEOUT`] and parsed.
    ///
    /// # Errors
    ///
    /// Returns [`LookupError`] if the TXT lookup fails or its records cannot be parsed.
    pub async fn lookup_endpoint_by_domain_name(
        &self,
        name: &str,
    ) -> Result<EndpointInfo, LookupError> {
        let name = endpoint_info::ensure_iroh_txt_label(name.to_string());
        let lookup = self.lookup_txt(name.clone(), DNS_TIMEOUT).await?;
        let info = EndpointInfo::from_txt_lookup(name, lookup)?;
        Ok(info)
    }

    /// Runs [`DnsResolver::lookup_endpoint_by_domain_name`] with staggered retries.
    ///
    /// Staggering works as in [`DnsResolver::lookup_ipv4_staggered`].
    ///
    /// # Errors
    ///
    /// Returns a [`StaggeredError`] holding every attempt's error if all attempts fail.
    pub async fn lookup_endpoint_by_domain_name_staggered(
        &self,
        name: &str,
        delays_ms: &[u64],
    ) -> Result<EndpointInfo, StaggeredError<LookupError>> {
        let f = || self.lookup_endpoint_by_domain_name(name);
        stagger_call(f, delays_ms).await
    }

    /// Runs [`DnsResolver::lookup_endpoint_by_id`] with staggered retries.
    ///
    /// Staggering works as in [`DnsResolver::lookup_ipv4_staggered`].
    ///
    /// # Errors
    ///
    /// Returns a [`StaggeredError`] holding every attempt's error if all attempts fail.
    pub async fn lookup_endpoint_by_id_staggered(
        &self,
        endpoint_id: &EndpointId,
        origin: &str,
        delays_ms: &[u64],
    ) -> Result<EndpointInfo, StaggeredError<LookupError>> {
        let f = || self.lookup_endpoint_by_id(endpoint_id, origin);
        stagger_call(f, delays_ms).await
    }
}
impl Default for DnsResolver {
fn default() -> Self {
Self::new()
}
}
impl reqwest::dns::Resolve for DnsResolver {
    /// Resolves `name` for reqwest with [`DnsResolver::lookup_ipv4_ipv6`] and [`DNS_TIMEOUT`].
    ///
    /// Every returned socket address has port 0. NOTE(review): presumably the caller substitutes
    /// the real port; confirm against reqwest's `Resolve` contract.
    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
        // Own everything the future needs so it can outlive `&self`.
        let resolver = self.clone();
        let host = name.as_str().to_string();
        Box::pin(async move {
            match resolver.lookup_ipv4_ipv6(host, DNS_TIMEOUT).await {
                Ok(ips) => {
                    let sockets = ips.map(|ip| SocketAddr::new(ip, 0));
                    Ok(Box::new(sockets) as reqwest::dns::Addrs)
                }
                Err(err) => Err(Box::new(err) as Box<dyn std::error::Error + Send + Sync>),
            }
        })
    }
}
/// The resolver backend behind a [`DnsResolver`].
///
/// Both variants are held as `Arc<RwLock<..>>`. The `Arc` makes clones share one resolver.
/// Lookups and cache clearing take the read lock; `reset` takes the write lock.
#[derive(Debug, Clone)]
enum DnsResolverInner {
    /// The built-in hickory-based resolver.
    Hickory(Arc<RwLock<HickoryResolver>>),
    /// A user-supplied [`Resolver`] implementation.
    Custom(Arc<RwLock<dyn Resolver>>),
}
impl DnsResolverInner {
    /// Resolves the IPv4 addresses of `host` with whichever backend this wraps.
    ///
    /// The read lock is held only for the duration of the lookup.
    async fn lookup_ipv4(
        &self,
        host: String,
    ) -> Result<impl Iterator<Item = Ipv4Addr> + use<>, DnsError> {
        let addrs = match self {
            Self::Hickory(inner) => {
                let guard = inner.read().await;
                let found = guard.lookup_ipv4(host).await?;
                Either::Left(found)
            }
            Self::Custom(inner) => {
                let guard = inner.read().await;
                let found = guard.lookup_ipv4(host).await?;
                Either::Right(found)
            }
        };
        Ok(addrs)
    }

    /// Resolves the IPv6 addresses of `host` with whichever backend this wraps.
    ///
    /// The read lock is held only for the duration of the lookup.
    async fn lookup_ipv6(
        &self,
        host: String,
    ) -> Result<impl Iterator<Item = Ipv6Addr> + use<>, DnsError> {
        let addrs = match self {
            Self::Hickory(inner) => {
                let guard = inner.read().await;
                let found = guard.lookup_ipv6(host).await?;
                Either::Left(found)
            }
            Self::Custom(inner) => {
                let guard = inner.read().await;
                let found = guard.lookup_ipv6(host).await?;
                Either::Right(found)
            }
        };
        Ok(addrs)
    }

    /// Fetches the TXT records of `host` with whichever backend this wraps.
    ///
    /// The read lock is held only for the duration of the lookup.
    async fn lookup_txt(
        &self,
        host: String,
    ) -> Result<impl Iterator<Item = TxtRecordData> + use<>, DnsError> {
        let records = match self {
            Self::Hickory(inner) => {
                let guard = inner.read().await;
                let found = guard.lookup_txt(host).await?;
                Either::Left(found)
            }
            Self::Custom(inner) => {
                let guard = inner.read().await;
                let found = guard.lookup_txt(host).await?;
                Either::Right(found)
            }
        };
        Ok(records)
    }

    /// Clears the backend's cache while holding the read lock.
    async fn clear_cache(&self) {
        match self {
            Self::Hickory(inner) => {
                let guard = inner.read().await;
                guard.clear_cache();
            }
            Self::Custom(inner) => {
                let guard = inner.read().await;
                guard.clear_cache();
            }
        }
    }

    /// Resets the backend while holding the write lock.
    async fn reset(&self) {
        match self {
            Self::Hickory(inner) => {
                let mut guard = inner.write().await;
                guard.reset();
            }
            Self::Custom(inner) => {
                let mut guard = inner.write().await;
                guard.reset();
            }
        }
    }
}
/// DNS resolver backed by hickory's [`TokioResolver`].
#[derive(Debug)]
struct HickoryResolver {
    /// The active hickory resolver.
    resolver: TokioResolver,
    /// The configuration the resolver was built from, kept so `reset` can rebuild it.
    builder: Builder,
}
impl HickoryResolver {
    /// Builds a resolver from `builder` and keeps the builder so `reset` can rebuild it later.
    fn new(builder: Builder) -> Self {
        Self {
            resolver: Self::build_resolver(&builder),
            builder,
        }
    }

    /// Creates a hickory [`TokioResolver`] from `builder`.
    ///
    /// The base configuration is empty, or the system configuration when
    /// `use_system_defaults` is set. If the system configuration cannot be read, Google's
    /// servers are used instead. The builder's explicit nameservers are then appended, and the
    /// IP strategy is set to IPv4-then-IPv6.
    fn build_resolver(builder: &Builder) -> TokioResolver {
        let (mut config, mut options) = if !builder.use_system_defaults {
            (ResolverConfig::new(), ResolverOpts::default())
        } else {
            Self::system_config().unwrap_or_else(|error| {
                debug!(%error, "Failed to read the system's DNS config, using fallback DNS servers.");
                (ResolverConfig::google(), ResolverOpts::default())
            })
        };

        // Explicit nameservers are added on top of whatever the base configuration contains.
        for &(addr, protocol) in &builder.nameservers {
            config.add_name_server(hickory_resolver::config::NameServerConfig::new(
                addr,
                protocol.to_hickory(),
            ));
        }

        options.ip_strategy = hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6;

        let mut resolver_builder =
            TokioResolver::builder_with_config(config, TokioConnectionProvider::default());
        *resolver_builder.options_mut() = options;
        resolver_builder.build()
    }

    /// Reads the operating system's DNS configuration into a fresh [`ResolverConfig`].
    ///
    /// Copies the domain, the search list and the nameservers. Nameservers listed in
    /// `WINDOWS_BAD_SITE_LOCAL_DNS_SERVERS` are skipped.
    ///
    /// # Errors
    ///
    /// Returns the hickory error if the system configuration cannot be read.
    fn system_config() -> Result<(ResolverConfig, ResolverOpts), hickory_resolver::ResolveError> {
        let (system, options) = hickory_resolver::system_conf::read_system_conf()?;

        let mut config = ResolverConfig::new();
        if let Some(domain) = system.domain() {
            config.set_domain(domain.clone());
        }
        for search_name in system.search().iter().cloned() {
            config.add_search(search_name);
        }

        let usable_servers = system.name_servers().iter().filter(|server| {
            !WINDOWS_BAD_SITE_LOCAL_DNS_SERVERS.contains(&server.socket_addr.ip())
        });
        for server in usable_servers {
            config.add_name_server(server.clone());
        }

        Ok((config, options))
    }

    /// Resolves the IPv4 addresses of `host`.
    async fn lookup_ipv4(
        &self,
        host: String,
    ) -> Result<impl Iterator<Item = Ipv4Addr> + use<>, DnsError> {
        let lookup = self.resolver.ipv4_lookup(host).await?;
        Ok(lookup.into_iter().map(Ipv4Addr::from))
    }

    /// Resolves the IPv6 addresses of `host`.
    async fn lookup_ipv6(
        &self,
        host: String,
    ) -> Result<impl Iterator<Item = Ipv6Addr> + use<>, DnsError> {
        let lookup = self.resolver.ipv6_lookup(host).await?;
        Ok(lookup.into_iter().map(Ipv6Addr::from))
    }

    /// Fetches the TXT records of `host`, copying each into an owned [`TxtRecordData`].
    async fn lookup_txt(
        &self,
        host: String,
    ) -> Result<impl Iterator<Item = TxtRecordData> + use<>, DnsError> {
        let lookup = self.resolver.txt_lookup(host).await?;
        Ok(lookup
            .into_iter()
            .map(|record| record.iter().cloned().collect::<TxtRecordData>()))
    }

    /// Clears hickory's internal cache.
    fn clear_cache(&self) {
        self.resolver.clear_cache()
    }

    /// Replaces the resolver with a fresh one built from the stored configuration.
    fn reset(&mut self) {
        let fresh = Self::build_resolver(&self.builder);
        self.resolver = fresh;
    }
}
/// The strings of a single DNS TXT record, in order.
#[derive(Debug, Clone)]
pub struct TxtRecordData(Box<[Box<[u8]>]>);

impl TxtRecordData {
    /// Returns an iterator over the record's strings, in order.
    pub fn iter(&self) -> impl Iterator<Item = &[u8]> {
        self.0.iter().map(|chunk| &chunk[..])
    }
}

impl fmt::Display for TxtRecordData {
    /// Writes all strings concatenated. Each string is decoded as UTF-8 separately, and invalid
    /// sequences are replaced with U+FFFD.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.iter()
            .try_for_each(|chunk| write!(f, "{}", String::from_utf8_lossy(chunk)))
    }
}

impl FromIterator<Box<[u8]>> for TxtRecordData {
    /// Collects the record's strings in iteration order.
    fn from_iter<I: IntoIterator<Item = Box<[u8]>>>(iter: I) -> Self {
        let chunks: Box<[Box<[u8]>]> = iter.into_iter().collect();
        Self(chunks)
    }
}

impl From<Vec<Box<[u8]>>> for TxtRecordData {
    /// Takes ownership of the record's strings, keeping their order.
    fn from(value: Vec<Box<[u8]>>) -> Self {
        Self(value.into())
    }
}
/// An iterator that is one of two iterator types with the same item type.
///
/// Lets a function return a single concrete iterator type when different branches produce
/// different iterators.
enum Either<A, B> {
    /// Yields items from the first iterator type.
    Left(A),
    /// Yields items from the second iterator type.
    Right(B),
}

impl<T, A, B> Iterator for Either<A, B>
where
    A: Iterator<Item = T>,
    B: Iterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        match self {
            Self::Left(inner) => inner.next(),
            Self::Right(inner) => inner.next(),
        }
    }
}
/// Site-local IPv6 nameserver addresses (`fec0:0:0:ffff::1` through `fec0:0:0:ffff::3`).
///
/// These are skipped when importing the system DNS configuration in
/// `HickoryResolver::system_config`. The filter applies on every platform, not only Windows.
///
/// NOTE(review): per the name, these are placeholder DNS servers that Windows reports.
/// Site-local IPv6 addresses are deprecated (RFC 3879).
const WINDOWS_BAD_SITE_LOCAL_DNS_SERVERS: [IpAddr; 3] = [
    IpAddr::V6(Ipv6Addr::new(0xfec0, 0, 0, 0xffff, 0, 0, 0, 1)),
    IpAddr::V6(Ipv6Addr::new(0xfec0, 0, 0, 0xffff, 0, 0, 0, 2)),
    IpAddr::V6(Ipv6Addr::new(0xfec0, 0, 0, 0xffff, 0, 0, 0, 3)),
];
/// Iterator over the result of a combined IPv4/IPv6 lookup.
///
/// Depending on which lookups succeeded, it yields only IPv4, only IPv6, or IPv4 followed by
/// IPv6 addresses.
enum LookupIter<A, B> {
    /// Only the IPv4 lookup succeeded.
    Ipv4(A),
    /// Only the IPv6 lookup succeeded.
    Ipv6(B),
    /// Both succeeded: IPv4 addresses first, then IPv6.
    Both(std::iter::Chain<A, B>),
}

impl<A, B> Iterator for LookupIter<A, B>
where
    A: Iterator<Item = IpAddr>,
    B: Iterator<Item = IpAddr>,
{
    type Item = IpAddr;

    fn next(&mut self) -> Option<IpAddr> {
        match self {
            Self::Ipv4(v4) => v4.next(),
            Self::Ipv6(v6) => v6.next(),
            Self::Both(chained) => chained.next(),
        }
    }
}
/// Runs `f` several times with staggered start times and returns the first success.
///
/// The first call has no delay. One more call is made for each entry of `delays_ms`, delayed
/// by that many milliseconds after jitter from `add_jitter`. Delays are not cumulative: all
/// are measured from the same starting point, and all calls run concurrently. `f` is invoked
/// once per call up front, in delay order.
///
/// # Errors
///
/// If every call fails, returns a [`StaggeredError`] holding all errors in completion order.
async fn stagger_call<
    T,
    E: StackError + 'static,
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, E>>,
>(
    f: F,
    delays_ms: &[u64],
) -> Result<T, StaggeredError<E>> {
    let schedule = std::iter::once(&0u64).chain(delays_ms);
    let mut in_flight = n0_future::FuturesUnorderedBounded::new(delays_ms.len() + 1);
    for delay_ms in schedule {
        let wait = add_jitter(delay_ms);
        let attempt = f();
        in_flight.push(async move {
            time::sleep(wait).await;
            attempt.await
        });
    }

    let mut errors = Vec::new();
    loop {
        match in_flight.next().await {
            Some(Ok(value)) => return Ok(value),
            Some(Err(err)) => errors.push(err),
            None => return Err(e!(StaggeredError { errors })),
        }
    }
}
/// Returns `delay` milliseconds, randomly shifted by up to about ±`MAX_JITTER_PERCENT` percent.
///
/// With `max_jitter = delay * 2 * MAX_JITTER_PERCENT / 100` (saturating), the result lies in
/// `[delay - max_jitter / 2, delay - max_jitter / 2 + max_jitter)`. When `max_jitter` rounds
/// down to zero, `delay` is returned unchanged. That covers a zero delay and delays too short to
/// jitter by a whole millisecond.
fn add_jitter(delay: &u64) -> Duration {
    let delay = *delay;
    let max_jitter = delay.saturating_mul(MAX_JITTER_PERCENT * 2) / 100;
    // For tiny delays (0, 1 or 2 ms) the jitter window is 0ms, and `% 0` below would panic
    // with a division by zero, so no jitter is applied.
    if max_jitter == 0 {
        return Duration::from_millis(delay);
    }
    let jitter = rand::random::<u64>() % max_jitter;
    Duration::from_millis(delay.saturating_sub(max_jitter / 2).saturating_add(jitter))
}
// Unit tests for staggered calls, delay jitter and custom resolver plumbing.
#[cfg(test)]
pub(crate) mod tests {
    use std::sync::atomic::AtomicUsize;

    use n0_tracing_test::traced_test;

    use super::*;

    // `stagger_call` invokes `f` once per delay, in order (0, 1000, 15). The shared counter
    // therefore maps call #0 -> Err(2) with no delay, call #1 -> Ok(3) after ~1000ms and
    // call #2 -> Ok(5) after ~15ms. The ~15ms call is the first to succeed, so the result is 5.
    #[tokio::test]
    #[traced_test]
    async fn stagger_basic() {
        const CALL_RESULTS: &[Result<u8, u8>] = &[Err(2), Ok(3), Ok(5), Ok(7)];
        static DONE_CALL: AtomicUsize = AtomicUsize::new(0);
        let f = || {
            // Assign each invocation the next slot of CALL_RESULTS.
            let r_pos = DONE_CALL.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            async move {
                tracing::info!(r_pos, "call");
                CALL_RESULTS[r_pos].map_err(|_| e!(DnsError::InvalidResponse))
            }
        };

        let delays = [1000, 15];
        let result = stagger_call(f, &delays).await.unwrap();
        assert_eq!(result, 5)
    }

    // A zero delay must stay exactly zero.
    #[test]
    #[traced_test]
    fn jitter_test_zero() {
        let jittered_delay = add_jitter(&0);
        assert_eq!(jittered_delay, Duration::from_secs(0));
    }

    // The jittered delay never drops below 80% of the requested delay.
    #[test]
    #[traced_test]
    fn jitter_test_nonzero_lower_bound() {
        let delay: u64 = 300;
        for _ in 0..100 {
            assert!(add_jitter(&delay) >= Duration::from_millis(delay * 8 / 10));
        }
    }

    // The jittered delay always stays strictly below 120% of the requested delay.
    #[test]
    #[traced_test]
    fn jitter_test_nonzero_upper_bound() {
        let delay: u64 = 300;
        for _ in 0..100 {
            assert!(add_jitter(&delay) < Duration::from_millis(delay * 12 / 10));
        }
    }

    // A custom `Resolver` is reached through `DnsResolver::custom`. Only `lookup_ipv4` is
    // exercised; the other methods are stubs that would panic if called.
    #[tokio::test]
    #[traced_test]
    async fn custom_resolver() {
        #[derive(Debug)]
        struct MyResolver;
        impl Resolver for MyResolver {
            fn lookup_ipv4(&self, host: String) -> BoxFuture<Result<BoxIter<Ipv4Addr>, DnsError>> {
                Box::pin(async move {
                    // Only "foo.example" resolves; everything else reports NoResponse.
                    let addr = if host == "foo.example" {
                        Ipv4Addr::new(1, 1, 1, 1)
                    } else {
                        return Err(e!(DnsError::NoResponse));
                    };
                    let iter: BoxIter<Ipv4Addr> = Box::new(vec![addr].into_iter());
                    Ok(iter)
                })
            }

            fn lookup_ipv6(&self, _host: String) -> BoxFuture<Result<BoxIter<Ipv6Addr>, DnsError>> {
                todo!()
            }

            fn lookup_txt(
                &self,
                _host: String,
            ) -> BoxFuture<Result<BoxIter<TxtRecordData>, DnsError>> {
                todo!()
            }

            fn clear_cache(&self) {
                todo!()
            }

            fn reset(&mut self) {
                todo!()
            }
        }

        let resolver = DnsResolver::custom(MyResolver);
        let mut iter = resolver
            .lookup_ipv4("foo.example", Duration::from_secs(1))
            .await
            .expect("not to fail");
        let addr = iter.next().expect("one result");
        assert_eq!(addr, "1.1.1.1".parse::<IpAddr>().unwrap());

        // The custom resolver's error must be passed through unchanged.
        let res = resolver
            .lookup_ipv4("bar.example", Duration::from_secs(1))
            .await;
        assert!(matches!(res, Err(DnsError::NoResponse { .. })))
    }
}