use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use anyhow::Result;
use clap::Parser;
use iroh::{
address_lookup::PkarrResolver,
endpoint::Connection,
protocol::{AcceptError, ProtocolHandler, Router},
Endpoint, EndpointId,
};
use iroh_blobs::{api::Store, store::mem::MemStore, BlobsProtocol, Hash};
mod common;
use common::{get_or_generate_secret_key, setup_logging};
// Top-level command-line arguments: just the subcommand to run.
// NOTE: plain `//` comments are used here on purpose; clap's derive turns `///`
// doc comments into help text, which would change the program's output.
#[derive(Debug, Parser)]
pub struct Cli {
// The selected subcommand; `main` dispatches on it.
#[clap(subcommand)]
command: Command,
}
// Subcommands of this example binary (plain `//` comments so clap help text is unchanged).
#[derive(Debug, Parser)]
pub enum Command {
// Handled by `listen`: stores and indexes the given strings, then serves queries until Ctrl-C.
Listen {
// Text strings to store as blobs and add to the search index.
text: Vec<String>,
},
// Handled by `query`: sends `query` to the given endpoint and prints every returned blob.
Query {
// Id of the remote endpoint to connect to.
endpoint_id: EndpointId,
// Query string sent to the remote endpoint.
query: String,
},
}
/// ALPN identifier of the text-search protocol: registered on the router in `listen`
/// and used when connecting in `query_remote`.
const ALPN: &[u8] = b"iroh-example/text-search/0";
/// Runs the listening side of the example.
///
/// Every entry of `text` is stored as a blob in an in-memory store and added to the
/// search index. A router is then spawned that accepts both the text-search protocol
/// ([`ALPN`]) and the iroh-blobs protocol, our endpoint id is printed, and the
/// function waits for Ctrl-C before shutting the router down.
///
/// # Errors
///
/// Returns an error if obtaining the secret key, binding the endpoint, inserting a
/// text, waiting for Ctrl-C, or shutting down the router fails.
async fn listen(text: Vec<String>) -> Result<()> {
    let secret_key = get_or_generate_secret_key()?;
    let store = MemStore::new();
    let endpoint = Endpoint::builder().secret_key(secret_key).bind().await?;

    // Fill the index before the router starts accepting connections.
    let search = BlobSearch::new(&store);
    for entry in text {
        search.insert_and_index(entry).await?;
    }

    // A single router serves both the search protocol and iroh-blobs.
    let blobs = BlobsProtocol::new(&store, None);
    let builder = Router::builder(endpoint);
    let builder = builder.accept(ALPN, search.clone());
    let builder = builder.accept(iroh_blobs::ALPN, blobs.clone());
    let router = builder.spawn();

    let node_id = router.endpoint().id();
    println!("our endpoint id: {node_id}");

    // Serve until the user interrupts the process.
    tokio::signal::ctrl_c().await?;
    router.shutdown().await?;
    Ok(())
}
/// Runs the querying side of the example.
///
/// Binds an endpoint (default relay mode, `PkarrResolver::n0_dns()` address lookup),
/// sends `query` to `endpoint_id` via [`query_remote`], prints every returned blob
/// with [`read_and_print`], then closes the endpoint and shuts down the in-memory store.
///
/// # Errors
///
/// Returns the first error from binding, querying, printing, or store shutdown.
/// Note that an early error return skips the `close`/`shutdown` calls at the end.
async fn query(endpoint_id: EndpointId, query: String) -> Result<()> {
    let store = MemStore::new();
    let builder = Endpoint::empty_builder(iroh::RelayMode::Default);
    let endpoint = builder
        .address_lookup(PkarrResolver::n0_dns())
        .bind()
        .await?;

    // The blobs for all returned hashes are in `store` once this call returns.
    let found = query_remote(&endpoint, &store, endpoint_id, &query).await?;
    for hash in found {
        read_and_print(&store, hash).await?;
    }

    endpoint.close().await;
    store.shutdown().await?;
    Ok(())
}
/// Entry point: sets up logging, parses the command line, and runs the chosen subcommand.
#[tokio::main]
async fn main() -> Result<()> {
    setup_logging();
    // Each arm already yields `Result<()>`, so the match is the function's result.
    match Cli::parse().command {
        Command::Listen { text } => listen(text).await,
        Command::Query {
            endpoint_id,
            query: needle,
        } => query(endpoint_id, needle).await,
    }
}
/// Handler state for the text-search protocol: a blob store handle plus an
/// in-memory index from text to the hash of the blob that holds it.
#[derive(Debug, Clone)]
struct BlobSearch {
/// Store handle that blobs are added to and read back from.
blobs: Store,
/// Maps each indexed text to its blob hash; shared between clones via `Arc`, guarded by a `Mutex`.
index: Arc<Mutex<HashMap<String, Hash>>>,
}
impl ProtocolHandler for BlobSearch {
    /// Serves one search request on an accepted connection.
    ///
    /// Wire format as implemented here: the peer opens one bi-directional stream,
    /// sends the query as UTF-8 (at most 64 bytes) and finishes its send side. The
    /// reply is the raw bytes of every matching hash written back to back, after
    /// which our send side is finished.
    ///
    /// # Errors
    ///
    /// Returns an [`AcceptError`] if accepting the stream fails, the query exceeds
    /// the size limit or is not valid UTF-8, or writing/finishing the reply fails.
    async fn accept(&self, connection: Connection) -> std::result::Result<(), AcceptError> {
        let node_id = connection.remote_id();
        println!("accepted connection from {node_id}");

        let (mut send, mut recv) = connection.accept_bi().await?;

        // The limit bounds how much data a remote peer can make us buffer for a query.
        let query_bytes = recv.read_to_end(64).await.map_err(AcceptError::from_err)?;
        let query = String::from_utf8(query_bytes).map_err(AcceptError::from_err)?;

        // `&self` is sufficient for the lookup; no per-connection clone of the handler is needed.
        let hashes = self.query_local(&query);
        println!("query: {query}, found {} results", hashes.len());

        // Fixed-size hashes are written one after another with no framing in between.
        for hash in hashes {
            send.write_all(hash.as_bytes())
                .await
                .map_err(AcceptError::from_err)?;
        }

        // Finishing the stream tells the peer that no more hashes will follow.
        send.finish()?;
        // Do not return until the remote has closed the connection
        // (presumably so the reply is not cut short by dropping it early — confirm).
        connection.closed().await;
        Ok(())
    }
}
impl BlobSearch {
    /// Creates a handler with an empty index that uses a clone of `blobs` as its store handle.
    pub fn new(blobs: &Store) -> Arc<Self> {
        let handler = Self {
            blobs: blobs.clone(),
            index: Default::default(),
        };
        Arc::new(handler)
    }

    /// Returns the hashes of all indexed texts that contain `query` as a substring.
    ///
    /// Results follow the iteration order of the underlying `HashMap`.
    ///
    /// # Panics
    ///
    /// Panics if the index mutex is poisoned.
    pub fn query_local(&self, query: &str) -> Vec<Hash> {
        let index = self.index.lock().unwrap();
        let mut matches = Vec::new();
        for (text, hash) in index.iter() {
            if text.contains(query) {
                matches.push(*hash);
            }
        }
        matches
    }

    /// Stores `text` as a blob, tries to add it to the index, and returns the blob's hash.
    ///
    /// The hash is returned even when [`Self::add_to_index`] declines to index the blob.
    ///
    /// # Errors
    ///
    /// Returns an error if adding the blob or indexing it fails.
    pub async fn insert_and_index(&self, text: String) -> Result<Hash> {
        let added = self.blobs.add_bytes(text.into_bytes()).await?;
        let hash = added.hash;
        self.add_to_index(hash).await?;
        Ok(hash)
    }

    /// Indexes the blob `hash` by its text content.
    ///
    /// Returns `Ok(true)` if the blob was indexed, and `Ok(false)` if it was skipped
    /// because it is incomplete, larger than 1 MiB, or not valid UTF-8.
    ///
    /// # Errors
    ///
    /// Returns an error if observing or reading the blob fails.
    ///
    /// # Panics
    ///
    /// Panics if the index mutex is poisoned.
    async fn add_to_index(&self, hash: Hash) -> Result<bool> {
        let bitfield = self.blobs.observe(hash).await?;
        // Only complete blobs of at most 1 MiB are indexed.
        let indexable = bitfield.is_complete() && bitfield.size() <= 1024 * 1024;
        if !indexable {
            return Ok(false);
        }

        let data = self.blobs.get_bytes(hash).await?;
        // Non-UTF-8 content is skipped rather than reported as an error.
        let Ok(text) = String::from_utf8(data.to_vec()) else {
            return Ok(false);
        };
        self.index.lock().unwrap().insert(text, hash);
        Ok(true)
    }
}
/// Sends `query` to the endpoint `endpoint_id` and downloads every matching blob into `store`.
///
/// Two connections are opened: one on [`ALPN`] for the search request and one on
/// `iroh_blobs::ALPN` for fetching blobs. The query is written to a bi-directional
/// stream and the send side is finished; the reply is then read as a sequence of
/// fixed-size hashes until the remote finishes the stream. Each hash is fetched
/// over the blobs connection before it is added to the result.
///
/// Returns the hashes in the order they were received.
///
/// # Errors
///
/// Returns an error if connecting, opening the stream, sending the query, reading
/// the reply, or fetching a blob fails.
pub async fn query_remote(
    endpoint: &Endpoint,
    store: &Store,
    endpoint_id: EndpointId,
    query: &str,
) -> Result<Vec<Hash>> {
    // The responder writes each hash as raw bytes via `Hash::as_bytes`; a hash is
    // 32 bytes long, so the read buffer must be exactly that size.
    const HASH_LEN: usize = 32;

    let conn = endpoint.connect(endpoint_id, ALPN).await?;
    let blobs_conn = endpoint.connect(endpoint_id, iroh_blobs::ALPN).await?;

    let (mut send, mut recv) = conn.open_bi().await?;
    send.write_all(query.as_bytes()).await?;
    // Finishing the send side marks the end of the query for the remote.
    send.finish()?;

    let mut out = vec![];
    let mut hash_bytes = [0u8; HASH_LEN];
    loop {
        match recv.read_exact(&mut hash_bytes).await {
            // The remote finished the stream: no more hashes. A stream that ends in
            // the middle of a hash is treated the same way as a clean end.
            Err(iroh::endpoint::ReadExactError::FinishedEarly(_)) => break,
            Err(err) => return Err(err.into()),
            Ok(_) => {}
        };
        let hash = Hash::from_bytes(hash_bytes);
        // Download the blob's content so the caller can read it from `store`.
        store.remote().fetch(blobs_conn.clone(), hash).await?;
        out.push(hash);
    }

    conn.close(0u32.into(), b"done");
    blobs_conn.close(0u32.into(), b"done");
    Ok(out)
}
/// Reads the blob `hash` from `store` and prints it as `<short hash>: <text>`.
///
/// # Errors
///
/// Returns an error if the blob cannot be read or its content is not valid UTF-8.
async fn read_and_print(store: &Store, hash: Hash) -> Result<()> {
    let raw = store.get_bytes(hash).await?.to_vec();
    let message = String::from_utf8(raw)?;
    println!("{}: {message}", hash.fmt_short());
    Ok(())
}