use std::{
borrow::Cow,
io,
path::{Path, PathBuf},
sync::{Arc, OnceLock},
};
use axum_server::{
accept::Accept,
tls_rustls::{RustlsAcceptor, RustlsConfig},
};
use n0_error::{Result, StackResultExt, StdResultExt, bail_any};
use n0_future::{FutureExt, future::Boxed as BoxFuture};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls_acme::{AcmeConfig, axum::AxumAcceptor, caches::DirCache};
use tokio_stream::StreamExt;
use tracing::{Instrument, debug, error, info_span};
/// Selects how the server obtains the TLS certificate it presents.
///
/// With `rename_all = "snake_case"` the variants (de)serialize as `manual`,
/// `lets_encrypt` and `self_signed`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, strum::Display)]
#[serde(rename_all = "snake_case")]
pub enum CertMode {
/// Certificate and private key are loaded from files in the certificate directory.
Manual,
/// Certificates are obtained through ACME using a Let's Encrypt directory.
LetsEncrypt,
/// A self-signed certificate is generated in-process for the configured domains.
SelfSigned,
}
impl CertMode {
    /// Creates the [`TlsAcceptor`] that corresponds to this certificate mode.
    ///
    /// * `domains` - host names the certificate should cover.
    /// * `cert_cache` - directory holding the manual certificate files, or used as
    ///   the ACME cache directory. Not used for [`CertMode::SelfSigned`].
    /// * `letsencrypt_contact` - contact address; only read in
    ///   [`CertMode::LetsEncrypt`] mode, where it is mandatory.
    /// * `letsencrypt_prod` - forwarded to the ACME configuration to pick the
    ///   production directory (as opposed to staging).
    ///
    /// # Errors
    ///
    /// Fails if the selected constructor fails, or if the mode is
    /// [`CertMode::LetsEncrypt`] and no contact was supplied.
    pub(crate) async fn build(
        &self,
        domains: Vec<String>,
        cert_cache: PathBuf,
        letsencrypt_contact: Option<String>,
        letsencrypt_prod: bool,
    ) -> Result<TlsAcceptor> {
        match self {
            Self::Manual => TlsAcceptor::manual(domains, cert_cache).await,
            Self::SelfSigned => TlsAcceptor::self_signed(domains).await,
            Self::LetsEncrypt => {
                let Some(contact) = letsencrypt_contact else {
                    // Route the missing value through the same `context` call so the
                    // resulting error is constructed exactly as before.
                    return None::<TlsAcceptor>
                        .context("contact is required for letsencrypt cert mode");
                };
                TlsAcceptor::letsencrypt(domains, &contact, letsencrypt_prod, cert_cache)
            }
        }
    }
}
/// TLS acceptor used by the server, wrapping one of two acceptor implementations.
#[derive(Clone)]
pub enum TlsAcceptor {
/// Acceptor whose certificates are managed through ACME.
LetsEncrypt(AxumAcceptor),
/// Acceptor backed by a fixed rustls configuration. Despite the name, this
/// variant is used for both manually supplied and self-signed certificates.
Manual(RustlsAcceptor),
}
/// Forwards connection acceptance to whichever acceptor this value wraps.
///
/// The two inner acceptors return different future types, so each one is boxed
/// to give the enum a single `Future` type.
impl<I, S> Accept<I, S> for TlsAcceptor
where
    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    S: Send + 'static,
{
    type Stream = tokio_rustls::server::TlsStream<I>;
    type Service = S;
    type Future = BoxFuture<io::Result<(Self::Stream, Self::Service)>>;

    /// Hands `stream` and `service` to the wrapped acceptor and returns its
    /// boxed future.
    fn accept(&self, stream: I, service: S) -> Self::Future {
        match self {
            Self::Manual(inner) => inner.accept(stream, service).boxed(),
            Self::LetsEncrypt(inner) => inner.accept(stream, service).boxed(),
        }
    }
}
impl TlsAcceptor {
async fn self_signed(domains: Vec<String>) -> Result<Self> {
let rcgen::CertifiedKey { cert, signing_key } =
rcgen::generate_simple_self_signed(domains).anyerr()?;
let config = RustlsConfig::from_der(vec![cert.der().to_vec()], signing_key.serialize_der())
.await
.anyerr()?;
let acceptor = RustlsAcceptor::new(config);
Ok(Self::Manual(acceptor))
}
async fn manual(domains: Vec<String>, dir: PathBuf) -> Result<Self> {
let config = rustls::ServerConfig::builder().with_no_client_auth();
if domains.len() != 1 {
bail_any!("Multiple domains in manual mode are not supported");
}
let keyname = escape_hostname(&domains[0]);
let cert_path = dir.join(format!("{keyname}.crt"));
let key_path = dir.join(format!("{keyname}.key"));
let certs = load_certs(cert_path).await?;
let secret_key = load_secret_key(key_path).await?;
let config = config.with_single_cert(certs, secret_key).anyerr()?;
let config = RustlsConfig::from_config(Arc::new(config));
let acceptor = RustlsAcceptor::new(config);
Ok(Self::Manual(acceptor))
}
fn letsencrypt(
domains: Vec<String>,
contact: &str,
is_production: bool,
dir: PathBuf,
) -> Result<Self> {
let config = rustls::ServerConfig::builder().with_no_client_auth();
let mut state = AcmeConfig::new(domains)
.contact([format!("mailto:{contact}")])
.cache_option(Some(DirCache::new(dir)))
.directory_lets_encrypt(is_production)
.state();
let config = config.with_cert_resolver(state.resolver());
let acceptor = state.acceptor();
tokio::spawn(
async move {
loop {
match state.next().await.unwrap() {
Ok(ok) => debug!("acme event: {:?}", ok),
Err(err) => error!("error: {:?}", err),
}
}
}
.instrument(info_span!("acme")),
);
let config = Arc::new(config);
let acceptor = AxumAcceptor::new(acceptor, config);
Ok(Self::LetsEncrypt(acceptor))
}
}
/// Reads the file at `filename` and parses every PEM certificate found in it.
///
/// # Errors
///
/// Fails if the file cannot be read or if any PEM section fails to parse; the
/// error context includes the offending path.
async fn load_certs(filename: impl AsRef<Path>) -> Result<Vec<CertificateDer<'static>>> {
    let path = filename.as_ref();
    let pem_bytes = tokio::fs::read(path)
        .await
        .with_std_context(|_| format!("cannot open certificate file at {}", path.display()))?;

    // Collecting into a `Result` stops at the first section that fails to parse.
    let parsed: Result<Vec<_>, _> = CertificateDer::pem_slice_iter(&pem_bytes).collect();
    parsed.with_std_context(|_| format!("cannot parse certificates from {}", path.display()))
}
/// Reads the file at `filename` and parses a PEM-encoded private key from it.
///
/// # Errors
///
/// Fails if the file cannot be read or does not yield a private key; the error
/// context includes the offending path.
async fn load_secret_key(filename: impl AsRef<Path>) -> Result<PrivateKeyDer<'static>> {
    let path = filename.as_ref();
    let pem_bytes = tokio::fs::read(path)
        .await
        .with_std_context(|_| format!("cannot open secret key file at {}", path.display()))?;

    let parsed = PrivateKeyDer::from_pem_slice(&pem_bytes);
    parsed.with_std_context(|_| format!("cannot parse secret key from {}", path.display()))
}
/// Lazily compiled pattern matching any character that is not an ASCII letter,
/// an ASCII digit, `-` or `.`.
static UNSAFE_HOSTNAME_CHARACTERS: OnceLock<regex::Regex> = OnceLock::new();

/// Removes from `hostname` every character other than ASCII letters, ASCII
/// digits, `-` and `.`.
///
/// Matching characters are deleted rather than substituted, so the result can
/// be shorter than the input.
fn escape_hostname(hostname: &str) -> Cow<'_, str> {
    // Compiled on first use and then shared for the lifetime of the process.
    let compile = || {
        let pattern = r"[^a-zA-Z0-9-\.]";
        regex::Regex::new(pattern).expect("valid regex")
    };
    UNSAFE_HOSTNAME_CHARACTERS
        .get_or_init(compile)
        .replace_all(hostname, "")
}