use std::{collections::BTreeSet, sync::Arc};
use clap::Parser;
use iroh::{
Endpoint, EndpointId,
endpoint::Connection,
protocol::{AcceptError, ProtocolHandler, Router},
};
use n0_error::{Result, StdResultExt};
use tokio::sync::Mutex;
use tracing_subscriber::{EnvFilter, prelude::*};
/// Share texts over iroh and count substring matches on a remote endpoint.
#[derive(Debug, Parser)]
pub struct Cli {
    // The subcommand to run (listen or query).
    #[clap(subcommand)]
    command: Command,
}
/// Subcommands of the text-search example.
#[derive(Debug, Parser)]
pub enum Command {
    /// Store the given texts and answer search queries until Ctrl-C is pressed.
    Listen {
        /// Texts to make searchable (duplicates are stored once).
        text: Vec<String>,
    },
    /// Ask a remote endpoint how many of its stored texts contain a query.
    Query {
        /// Endpoint id printed by the remote `listen` command.
        endpoint_id: EndpointId,
        /// Substring to search for.
        query: String,
    },
}
/// ALPN identifier under which the text-search protocol is registered and dialed.
/// The trailing `0` presumably versions the wire format — confirm before changing it.
const ALPN: &[u8] = "iroh-example/text-search/0".as_bytes();
#[tokio::main]
async fn main() -> Result<()> {
setup_logging();
let args = Cli::parse();
let endpoint = Endpoint::bind().await?;
let proto = BlobSearch::new(endpoint.clone());
let builder = Router::builder(endpoint);
let router = builder.accept(ALPN, proto.clone()).spawn();
match args.command {
Command::Listen { text } => {
let endpoint_id = router.endpoint().id();
println!("our endpoint id: {endpoint_id}");
for text in text.into_iter() {
proto.insert(text).await?;
}
tokio::signal::ctrl_c().await.anyerr()?;
}
Command::Query { endpoint_id, query } => {
let num_matches = proto.query_remote(endpoint_id, &query).await?;
println!("Found {num_matches} matches");
}
}
router.shutdown().await.anyerr()?;
Ok(())
}
/// Text-search protocol handler: keeps a local set of strings, answers
/// substring-count queries from peers, and can send such queries to peers.
///
/// Clones share the same set of texts, because it is held behind an `Arc`.
#[derive(Debug, Clone)]
struct BlobSearch {
    // Endpoint used to dial remote peers in `query_remote`.
    endpoint: Endpoint,
    // Searchable texts; being a set, inserting a duplicate has no effect.
    blobs: Arc<Mutex<BTreeSet<String>>>,
}
impl ProtocolHandler for BlobSearch {
    /// Serves one query on an incoming connection.
    ///
    /// Accepts the first bi-directional stream the peer opens and reads the whole
    /// query from it, using a 64-byte size limit. It decodes the query as UTF-8,
    /// replies with the local match count as 8 little-endian bytes, then waits for
    /// the connection to close.
    ///
    /// # Errors
    ///
    /// Returns an [`AcceptError`] if accepting the stream, reading or decoding the
    /// query, or writing/finishing the reply fails.
    async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
        let endpoint_id = connection.remote_id();
        println!("accepted connection from {endpoint_id}");

        let (mut outgoing, mut incoming) = connection.accept_bi().await?;
        let raw_query = incoming
            .read_to_end(64)
            .await
            .map_err(AcceptError::from_err)?;
        let query = String::from_utf8(raw_query).map_err(AcceptError::from_err)?;

        let reply = self.query_local(&query).await.to_le_bytes();
        outgoing
            .write_all(&reply)
            .await
            .map_err(AcceptError::from_err)?;
        // Tell the peer no more reply bytes follow.
        outgoing.finish()?;

        // Don't close the connection ourselves. Wait until the peer closes it after
        // reading the reply, so closing early cannot cut off the response.
        connection.closed().await;
        Ok(())
    }
}
impl BlobSearch {
pub fn new(endpoint: Endpoint) -> Self {
Self {
endpoint,
blobs: Default::default(),
}
}
pub async fn query_remote(&self, endpoint_id: EndpointId, query: &str) -> Result<u64> {
let conn = self.endpoint.connect(endpoint_id, ALPN).await?;
let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
send.write_all(query.as_bytes()).await.anyerr()?;
send.finish().anyerr()?;
let mut num_matches = [0u8; 8];
recv.read_exact(&mut num_matches).await.anyerr()?;
let num_matches = u64::from_le_bytes(num_matches);
Ok(num_matches)
}
pub async fn query_local(&self, query: &str) -> u64 {
let guard = self.blobs.lock().await;
let count: usize = guard.iter().filter(|text| text.contains(query)).count();
count as u64
}
pub async fn insert(&self, text: String) -> Result<()> {
let mut guard = self.blobs.lock().await;
guard.insert(text);
Ok(())
}
}
/// Installs a global tracing subscriber that writes formatted events to stderr,
/// filtered according to the default `EnvFilter` environment variable.
///
/// Initialization errors, such as a subscriber already being installed, are
/// deliberately ignored.
fn setup_logging() {
    let stderr_layer = tracing_subscriber::fmt::layer().with_writer(std::io::stderr);
    let env_filter = EnvFilter::from_default_env();
    let _ = tracing_subscriber::registry()
        .with(stderr_layer)
        .with(env_filter)
        .try_init();
}