use std::{
io,
ops::{Bound, RangeBounds},
};
use cyber_bao::{io::round_up_to_chunks, ChunkNum};
use builder::GetRequestBuilder;
use derive_more::From;
use iroh::endpoint::VarInt;
use postcard::experimental::max_size::MaxSize;
use range_collections::{range_set::RangeSetEntry, RangeSet2};
use serde::{Deserialize, Serialize};
mod range_spec;
pub use cyber_bao::ChunkRanges;
use n0_error::stack_error;
pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec};
use crate::{api::blobs::Bitfield, util::RecvStreamExt, BlobFormat, Hash, HashAndFormat};
/// Size limit in bytes (1 MiB) used as the read bound in [`Request::read_async`].
pub const MAX_MESSAGE_SIZE: usize = 1 << 20;
/// Error code `1`; by its name, reported for permission failures.
pub const ERR_PERMISSION: VarInt = VarInt::from_u32(1);
/// Error code `2`; by its name, reported when a limit is exceeded.
pub const ERR_LIMIT: VarInt = VarInt::from_u32(2);
/// Error code `3`; by its name, reported for internal errors.
pub const ERR_INTERNAL: VarInt = VarInt::from_u32(3);
/// ALPN identifier of this protocol.
pub const ALPN: &[u8] = b"/iroh-bytes/4";
/// A request message, as decoded by [`Request::read_async`].
///
/// The variant order is part of the wire format: postcard encodes the variant
/// index first, and [`RequestType`] mirrors this enum variant-for-variant so that
/// index can be decoded on its own. Do not reorder, insert, or remove variants.
#[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Clone, From)]
pub enum Request {
    /// Variant index 0; see [`GetRequest`].
    Get(GetRequest),
    /// Variant index 1; see [`ObserveRequest`].
    Observe(ObserveRequest),
    // Indices 2-7 are reserved placeholders so that `Push` and `GetMany` keep
    // indices 8 and 9; `read_async` rejects them.
    Slot2,
    Slot3,
    Slot4,
    Slot5,
    Slot6,
    Slot7,
    /// Variant index 8; see [`PushRequest`]. `read_async` reads its body with a
    /// length prefix rather than to the end of the stream.
    Push(PushRequest),
    /// Variant index 9; see [`GetManyRequest`].
    GetMany(GetManyRequest),
}
/// Discriminant of a [`Request`], decoded on its own from the first byte of a request.
///
/// [`Request::read_async`] reads exactly one byte and decodes it as this type. The
/// variants must therefore stay in the same order as [`Request`]'s, and the postcard
/// encoding of every variant must fit in a single byte (`POSTCARD_MAX_SIZE == 1`).
#[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Clone, Copy, MaxSize)]
pub enum RequestType {
    Get,
    Observe,
    // Reserved indices 2-7, mirroring `Request::Slot2`..`Request::Slot7`.
    Slot2,
    Slot3,
    Slot4,
    Slot5,
    Slot6,
    Slot7,
    Push,
    GetMany,
}
impl Request {
    /// Reads one request from `reader`.
    ///
    /// The first byte is decoded as a postcard-encoded [`RequestType`]. Depending on that
    /// type, the body is read as follows:
    /// - `Get`, `GetMany` and `Observe` bodies are read with `read_to_end_as`.
    /// - `Push` bodies are read with `read_length_prefixed`.
    ///
    /// All body reads are bounded by [`MAX_MESSAGE_SIZE`].
    ///
    /// Returns the request together with a size:
    /// - for `Get`/`GetMany`/`Observe`, the size reported by `read_to_end_as`
    ///   (presumably the number of body bytes read — confirm in `util`);
    /// - for `Push`, the postcard-serialized size of the decoded request.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidData`] error in these cases:
    /// - the type byte is not a valid [`RequestType`];
    /// - the type byte names one of the reserved `Slot*` variants;
    /// - the size of a `Push` request cannot be computed.
    ///
    /// Errors from the underlying reads are propagated unchanged.
    pub async fn read_async<R: crate::util::RecvStream>(
        reader: &mut R,
    ) -> io::Result<(Self, usize)> {
        /// Error for a type byte that does not name a supported request.
        fn invalid_request_type() -> io::Error {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "failed to deserialize request type",
            )
        }

        let type_byte = reader.read_u8().await?;
        let request_type: RequestType = postcard::from_bytes(std::slice::from_ref(&type_byte))
            .map_err(|_| invalid_request_type())?;
        Ok(match request_type {
            RequestType::Get => {
                let (r, size) = reader
                    .read_to_end_as::<GetRequest>(MAX_MESSAGE_SIZE)
                    .await?;
                (r.into(), size)
            }
            RequestType::GetMany => {
                let (r, size) = reader
                    .read_to_end_as::<GetManyRequest>(MAX_MESSAGE_SIZE)
                    .await?;
                (r.into(), size)
            }
            RequestType::Observe => {
                let (r, size) = reader
                    .read_to_end_as::<ObserveRequest>(MAX_MESSAGE_SIZE)
                    .await?;
                (r.into(), size)
            }
            RequestType::Push => {
                let r = reader
                    .read_length_prefixed::<PushRequest>(MAX_MESSAGE_SIZE)
                    .await?;
                // Re-serializing a value that was just decoded should not fail, but this
                // runs on untrusted network input, so report an error instead of panicking.
                let size = postcard::experimental::serialized_size(&r).map_err(|_| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        "failed to compute serialized size of push request",
                    )
                })?;
                (r.into(), size)
            }
            // Listed explicitly (not `_`) so that adding a `RequestType` variant forces
            // this match to be revisited.
            RequestType::Slot2
            | RequestType::Slot3
            | RequestType::Slot4
            | RequestType::Slot5
            | RequestType::Slot6
            | RequestType::Slot7 => return Err(invalid_request_type()),
        })
    }
}
/// Request for the content of one hash, described by a sequence of chunk ranges.
///
/// Field order is part of the postcard wire format (hash first, then ranges).
#[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Clone, Hash)]
pub struct GetRequest {
    /// Hash of the requested content.
    pub hash: Hash,
    /// Chunk ranges per offset. [`GetRequest::content`] uses `is_blob()` on this to
    /// decide between a raw blob and a hash sequence.
    pub ranges: ChunkRangesSeq,
}
impl From<HashAndFormat> for GetRequest {
fn from(value: HashAndFormat) -> Self {
match value.format {
BlobFormat::Raw => Self::blob(value.hash),
BlobFormat::HashSeq => Self::all(value.hash),
}
}
}
impl GetRequest {
    /// Returns an empty [`GetRequestBuilder`].
    pub fn builder() -> GetRequestBuilder {
        Default::default()
    }

    /// The content this request addresses.
    ///
    /// The format is [`BlobFormat::Raw`] when `ranges.is_blob()` holds and
    /// [`BlobFormat::HashSeq`] otherwise.
    pub fn content(&self) -> HashAndFormat {
        let format = if self.ranges.is_blob() {
            BlobFormat::Raw
        } else {
            BlobFormat::HashSeq
        };
        HashAndFormat {
            hash: self.hash,
            format,
        }
    }

    /// Creates a request from its parts.
    pub fn new(hash: Hash, ranges: ChunkRangesSeq) -> Self {
        Self { hash, ranges }
    }

    /// Requests `ChunkRangesSeq::all()` for `hash`.
    pub fn all(hash: impl Into<Hash>) -> Self {
        Self::new(hash.into(), ChunkRangesSeq::all())
    }

    /// Requests every chunk of `hash` as a single blob (`ChunkRanges::all()` at offset 0).
    pub fn blob(hash: impl Into<Hash>) -> Self {
        Self::blob_ranges(hash.into(), ChunkRanges::all())
    }

    /// Requests only `ranges` of the single blob `hash`.
    pub fn blob_ranges(hash: Hash, ranges: ChunkRanges) -> Self {
        Self::new(hash, ChunkRangesSeq::from_ranges([ranges]))
    }
}
/// Push request: a newtype around [`GetRequest`].
///
/// `derive_more::Deref` exposes the inner request's fields and methods, and
/// `derive_more::From` converts from a `GetRequest`. Unlike [`GetRequest`], it is read
/// with a length prefix in [`Request::read_async`].
#[derive(
    Deserialize, Serialize, Debug, PartialEq, Eq, Clone, derive_more::From, derive_more::Deref,
)]
pub struct PushRequest(GetRequest);
impl PushRequest {
    /// Wraps `GetRequest::new(hash, ranges)`.
    pub fn new(hash: Hash, ranges: ChunkRangesSeq) -> Self {
        let inner = GetRequest::new(hash, ranges);
        inner.into()
    }
}
/// Request for several hashes at once.
///
/// `ranges` is indexed by position in `hashes`: offset `i` holds the ranges requested
/// for `hashes[i]`, as built by `GetManyRequestBuilder::build`. Field order is part of
/// the postcard wire format.
#[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Clone)]
pub struct GetManyRequest {
    /// Requested hashes.
    pub hashes: Vec<Hash>,
    /// Ranges per hash, by position in `hashes`.
    pub ranges: ChunkRangesSeq,
}
impl<I: Into<Hash>> FromIterator<I> for GetManyRequest {
fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
let mut res = iter.into_iter().map(Into::into).collect::<Vec<Hash>>();
res.sort();
res.dedup();
let n = res.len() as u64;
Self {
hashes: res,
ranges: ChunkRangesSeq(smallvec::smallvec![
(0, ChunkRanges::all()),
(n, ChunkRanges::empty())
]),
}
}
}
impl GetManyRequest {
    /// Creates a request from its parts.
    ///
    /// `hashes` and `ranges` are taken as-is: no sorting or deduplication.
    pub fn new(hashes: Vec<Hash>, ranges: ChunkRangesSeq) -> Self {
        Self { ranges, hashes }
    }

    /// Returns an empty [`builder::GetManyRequestBuilder`].
    pub fn builder() -> builder::GetManyRequestBuilder {
        Default::default()
    }
}
/// Request to observe `hash`.
///
/// NOTE(review): presumably answered with a stream of [`ObserveItem`]s; confirm
/// against the provider.
#[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Clone, Hash)]
pub struct ObserveRequest {
    /// Hash to observe.
    pub hash: Hash,
    /// Ranges of interest; [`ObserveRequest::new`] uses `RangeSpec::all()`.
    pub ranges: RangeSpec,
}
impl ObserveRequest {
    /// Observes `hash` over `RangeSpec::all()`.
    pub fn new(hash: Hash) -> Self {
        let ranges = RangeSpec::all();
        Self { hash, ranges }
    }
}
/// Serializable form of a [`Bitfield`]: a size plus a set of chunk ranges.
///
/// Converts to and from `&Bitfield` by copying both fields.
#[derive(Deserialize, Serialize, Debug, PartialEq, Eq)]
pub struct ObserveItem {
    /// Copied from `Bitfield::size`.
    pub size: u64,
    /// Copied from `Bitfield::ranges`.
    pub ranges: ChunkRanges,
}
impl From<&Bitfield> for ObserveItem {
fn from(value: &Bitfield) -> Self {
Self {
size: value.size,
ranges: value.ranges.clone(),
}
}
}
impl From<&ObserveItem> for Bitfield {
fn from(value: &ObserveItem) -> Self {
Self {
size: value.size,
ranges: value.ranges.clone(),
}
}
}
/// Close codes with explicit `u16` discriminants.
///
/// The discriminant is the value produced by `From<Closed> for VarInt` and accepted
/// by `TryFrom<VarInt> for Closed`, so existing values must not change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum Closed {
    /// Code 0; reason `"stream dropped"`.
    StreamDropped = 0,
    /// Code 1; reason `"provider terminating"`.
    ProviderTerminating = 1,
    /// Code 2; reason `"request received"`.
    RequestReceived = 2,
}
impl Closed {
pub fn reason(&self) -> &'static [u8] {
match self {
Closed::StreamDropped => b"stream dropped",
Closed::ProviderTerminating => b"provider terminating",
Closed::RequestReceived => b"request received",
}
}
}
impl From<Closed> for VarInt {
fn from(source: Closed) -> Self {
VarInt::from(source as u16)
}
}
/// Error returned by `TryFrom<VarInt> for Closed` when the code matches no [`Closed`]
/// discriminant.
#[stack_error(derive, add_meta)]
#[error("Unknown error_code: {code}")]
pub struct UnknownErrorCode {
    // The unrecognised code, as returned by `VarInt::into_inner`.
    code: u64,
}
impl TryFrom<VarInt> for Closed {
    type Error = UnknownErrorCode;

    /// Maps a close code back to the [`Closed`] variant with that discriminant.
    ///
    /// # Errors
    ///
    /// Returns [`UnknownErrorCode`] carrying the code when no variant matches.
    fn try_from(value: VarInt) -> std::result::Result<Self, Self::Error> {
        let closed = match value.into_inner() {
            0 => Self::StreamDropped,
            1 => Self::ProviderTerminating,
            2 => Self::RequestReceived,
            unknown => return Err(n0_error::e!(UnknownErrorCode { code: unknown })),
        };
        Ok(closed)
    }
}
/// Convenience constructors for [`ChunkRanges`].
pub trait ChunkRangesExt {
    /// The range from chunk `u64::MAX` onward.
    ///
    /// NOTE(review): presumably a stand-in for "the last chunk, whatever the blob
    /// size"; confirm against the bao-tree documentation.
    fn last_chunk() -> Self;
    /// The single chunk with index `offset`.
    fn chunk(offset: u64) -> Self;
    /// Chunks for a range of byte offsets, converted with `round_up_to_chunks`
    /// (presumably every chunk touching the byte range).
    fn bytes(ranges: impl RangeBounds<u64>) -> Self;
    /// A range of chunk indices.
    fn chunks(ranges: impl RangeBounds<u64>) -> Self;
    /// The chunk containing the byte at `offset`.
    fn offset(offset: u64) -> Self;
}
/// [`ChunkRangesExt`] for [`ChunkRanges`]. All range-based constructors go through
/// [`bounds_from_range`], which handles `u64::MAX` ends without overflowing.
impl ChunkRangesExt for ChunkRanges {
    fn last_chunk() -> Self {
        ChunkRanges::from(ChunkNum(u64::MAX)..)
    }

    fn chunk(offset: u64) -> Self {
        // `offset..=offset` rather than `offset..offset + 1`: the addition overflows at
        // `u64::MAX`, whereas `bounds_from_range` turns that inclusive end into an
        // unbounded one, which yields exactly `last_chunk()`.
        bounds_from_range(offset..=offset, ChunkNum)
    }

    fn bytes(ranges: impl RangeBounds<u64>) -> Self {
        round_up_to_chunks(&bounds_from_range(ranges, |v| v))
    }

    fn chunks(ranges: impl RangeBounds<u64>) -> Self {
        bounds_from_range(ranges, ChunkNum)
    }

    fn offset(offset: u64) -> Self {
        // Inclusive for the same reason as `chunk`: `offset + 1` overflows at `u64::MAX`.
        Self::bytes(offset..=offset)
    }
}
/// Converts any `RangeBounds<u64>` into a [`RangeSet2`], mapping each finite bound
/// through `f`.
///
/// Both ends are first normalised to a half-open `[start, end)` form, where `None`
/// means unbounded:
/// - an inclusive end of `u64::MAX` becomes unbounded;
/// - an exclusive start of `u64::MAX` yields an empty set, because nothing lies past it.
pub(crate) fn bounds_from_range<R, T, F>(range: R, f: F) -> RangeSet2<T>
where
    R: RangeBounds<u64>,
    T: RangeSetEntry,
    F: Fn(u64) -> T,
{
    let start: Option<u64> = match range.start_bound() {
        Bound::Unbounded => None,
        Bound::Included(&first) => Some(first),
        Bound::Excluded(&before) => match before.checked_add(1) {
            Some(first) => Some(first),
            None => return RangeSet2::empty(),
        },
    };
    let end: Option<u64> = match range.end_bound() {
        Bound::Unbounded => None,
        Bound::Included(&last) => last.checked_add(1),
        Bound::Excluded(&past) => Some(past),
    };
    match (start.map(&f), end.map(&f)) {
        (None, None) => RangeSet2::all(),
        (None, Some(hi)) => RangeSet2::from(..hi),
        (Some(lo), None) => RangeSet2::from(lo..),
        (Some(lo), Some(hi)) => RangeSet2::from(lo..hi),
    }
}
pub mod builder {
    //! Builders for [`GetRequest`] and [`GetManyRequest`].
    use std::collections::BTreeMap;

    use cyber_bao::ChunkRanges;

    use super::ChunkRangesSeq;
    use crate::{
        protocol::{GetManyRequest, GetRequest},
        Hash,
    };

    /// Collects [`ChunkRanges`] keyed by offset and turns them into a [`ChunkRangesSeq`].
    ///
    /// Ranges recorded for the same offset are merged with `|`.
    #[derive(Debug, Clone, Default)]
    pub struct ChunkRangesSeqBuilder {
        ranges: BTreeMap<u64, ChunkRanges>,
    }

    /// Builder for a [`GetRequest`].
    ///
    /// `root` writes offset 0 and `child(i, ..)` writes offset `i + 1`. Ranges written to
    /// the same offset are merged.
    #[derive(Debug, Clone, Default)]
    pub struct GetRequestBuilder {
        builder: ChunkRangesSeqBuilder,
    }

    impl GetRequestBuilder {
        /// Merges `ranges` into the entry at the raw `offset`.
        pub fn offset(mut self, offset: u64, ranges: impl Into<ChunkRanges>) -> Self {
            self.builder = self.builder.offset(offset, ranges);
            self
        }

        /// Merges `ranges` into the entry for child `child`, i.e. offset `child + 1`.
        ///
        /// # Panics
        ///
        /// Panics if `child` is `u64::MAX`, because its offset does not fit in a `u64`.
        pub fn child(mut self, child: u64, ranges: impl Into<ChunkRanges>) -> Self {
            // `checked_add` instead of `+`: a wrapped offset of 0 would silently merge
            // these ranges into the root's.
            let offset = child
                .checked_add(1)
                .expect("child index u64::MAX has no representable offset");
            self.builder = self.builder.offset(offset, ranges);
            self
        }

        /// Merges `ranges` into the entry at offset 0.
        pub fn root(mut self, ranges: impl Into<ChunkRanges>) -> Self {
            self.builder = self.builder.offset(0, ranges);
            self
        }

        /// Merges `ranges` at one past the largest offset recorded so far (0 if none).
        ///
        /// # Panics
        ///
        /// Panics if the largest recorded offset is `u64::MAX`.
        pub fn next(mut self, ranges: impl Into<ChunkRanges>) -> Self {
            self.builder = self.builder.next(ranges);
            self
        }

        /// Builds a request for `hash` with a finite ranges sequence
        /// (`ChunkRangesSeq::from_ranges`).
        pub fn build(self, hash: impl Into<Hash>) -> GetRequest {
            let ranges = self.builder.build();
            GetRequest::new(hash.into(), ranges)
        }

        /// Builds a request for `hash` with an open-ended ranges sequence
        /// (`ChunkRangesSeq::from_ranges_infinite`).
        ///
        /// NOTE(review): presumably the last recorded entry then applies to every later
        /// offset; confirm in `range_spec`.
        pub fn build_open(self, hash: impl Into<Hash>) -> GetRequest {
            let ranges = self.builder.build_open();
            GetRequest::new(hash.into(), ranges)
        }
    }

    impl ChunkRangesSeqBuilder {
        /// Merges `ranges` into the entry at `offset`.
        pub fn offset(self, offset: u64, ranges: impl Into<ChunkRanges>) -> Self {
            self.at_offset(offset, ranges.into())
        }

        /// Merges `ranges` at one past the largest offset recorded so far (0 if none).
        ///
        /// Entries recorded with empty ranges still count toward "largest offset" here.
        ///
        /// # Panics
        ///
        /// Panics if the largest recorded offset is `u64::MAX`.
        pub fn next(self, ranges: impl Into<ChunkRanges>) -> Self {
            let offset = self.next_offset_value();
            self.at_offset(offset, ranges.into())
        }

        /// Builds a finite sequence via `ChunkRangesSeq::from_ranges`.
        pub fn build(self) -> ChunkRangesSeq {
            ChunkRangesSeq::from_ranges(self.build0())
        }

        /// Builds an open-ended sequence via `ChunkRangesSeq::from_ranges_infinite`.
        pub fn build_open(self) -> ChunkRangesSeq {
            ChunkRangesSeq::from_ranges_infinite(self.build0())
        }

        /// Unions `ranges` into the entry at `offset`.
        ///
        /// A missing entry starts as the default (empty) ranges, so no clone is needed.
        fn at_offset(mut self, offset: u64, ranges: ChunkRanges) -> Self {
            *self.ranges.entry(offset).or_default() |= ranges;
            self
        }

        /// Yields one [`ChunkRanges`] per offset, from 0 up to and including the largest
        /// offset that holds non-empty ranges. Offsets without an entry yield empty
        /// ranges.
        ///
        /// The work done is proportional to that largest offset, not to the number of
        /// entries.
        fn build0(mut self) -> impl Iterator<Item = ChunkRanges> {
            // Drop empty entries first so they do not extend the sequence.
            self.ranges.retain(|_, v| !v.is_empty());
            let end = self.next_offset_value();
            let mut ranges = self.ranges;
            // Lazily hand entries over to the consumer instead of collecting a Vec first.
            (0..end).map(move |offset| ranges.remove(&offset).unwrap_or_default())
        }

        /// One past the largest recorded offset, or 0 when nothing has been recorded.
        ///
        /// # Panics
        ///
        /// Panics if the largest recorded offset is `u64::MAX`.
        fn next_offset_value(&self) -> u64 {
            self.ranges.last_key_value().map_or(0, |(&offset, _)| {
                // A wrapped result of 0 would make `build0` silently yield nothing.
                offset
                    .checked_add(1)
                    .expect("offset u64::MAX has no successor offset")
            })
        }
    }

    /// Builder for a [`GetManyRequest`].
    ///
    /// Hashes are kept in a `BTreeMap`, so the built request lists them in sorted order.
    /// Ranges given for the same hash are merged with `|`.
    #[derive(Debug, Clone, Default)]
    pub struct GetManyRequestBuilder {
        ranges: BTreeMap<Hash, ChunkRanges>,
    }

    impl GetManyRequestBuilder {
        /// Merges `ranges` into the entry for `hash`.
        pub fn hash(mut self, hash: impl Into<Hash>, ranges: impl Into<ChunkRanges>) -> Self {
            let ranges = ranges.into();
            let hash = hash.into();
            *self.ranges.entry(hash).or_default() |= ranges;
            self
        }

        /// Builds the request, dropping hashes whose accumulated ranges are empty.
        ///
        /// Offset `i` of the resulting ranges sequence belongs to `hashes[i]`.
        pub fn build(self) -> GetManyRequest {
            let (hashes, ranges): (Vec<Hash>, Vec<ChunkRanges>) = self
                .ranges
                .into_iter()
                .filter(|(_, v)| !v.is_empty())
                .unzip();
            let ranges = ChunkRangesSeq::from_ranges(ranges);
            GetManyRequest { hashes, ranges }
        }
    }

    #[cfg(test)]
    mod tests {
        use cyber_bao::ChunkNum;

        use super::*;
        use crate::protocol::{ChunkRangesExt, GetManyRequest};

        #[test]
        fn chunk_ranges_ext() {
            let ranges = ChunkRanges::bytes(1..2)
                | ChunkRanges::chunks(100..=200)
                | ChunkRanges::offset(1024 * 10)
                | ChunkRanges::chunk(1024)
                | ChunkRanges::last_chunk();
            assert_eq!(
                ranges,
                ChunkRanges::from(ChunkNum(0)..ChunkNum(1))
                    | ChunkRanges::from(ChunkNum(10)..ChunkNum(11))
                    | ChunkRanges::from(ChunkNum(100)..ChunkNum(201))
                    | ChunkRanges::from(ChunkNum(1024)..ChunkNum(1025))
                    | ChunkRanges::last_chunk()
            );
        }

        #[test]
        fn get_request_builder() {
            let hash = [0; 32];
            let request = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .next(ChunkRanges::bytes(0..100))
                .build(hash);
            assert_eq!(request.hash.as_bytes(), &hash);
            assert_eq!(
                request.ranges,
                ChunkRangesSeq::from_ranges([
                    ChunkRanges::all(),
                    ChunkRanges::all(),
                    ChunkRanges::from(..ChunkNum(1)),
                ])
            );

            let request = GetRequest::builder()
                .root(ChunkRanges::all())
                .child(2, ChunkRanges::bytes(0..100))
                .build(hash);
            assert_eq!(request.hash.as_bytes(), &hash);
            assert_eq!(
                request.ranges,
                ChunkRangesSeq::from_ranges([
                    ChunkRanges::all(),
                    ChunkRanges::empty(),
                    ChunkRanges::empty(),
                    ChunkRanges::from(..ChunkNum(1)),
                ])
            );

            let request = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::bytes(0..1024) | ChunkRanges::last_chunk())
                .build_open(hash);
            assert_eq!(request.hash.as_bytes(), &[0; 32]);
            assert_eq!(
                request.ranges,
                ChunkRangesSeq::from_ranges_infinite([
                    ChunkRanges::all(),
                    ChunkRanges::from(..ChunkNum(1)) | ChunkRanges::last_chunk(),
                ])
            );
        }

        #[test]
        fn ranges_at_same_offset_are_merged() {
            let request = GetRequest::builder()
                .root(ChunkRanges::chunk(0))
                .root(ChunkRanges::chunk(2))
                .build([0u8; 32]);
            assert_eq!(
                request.ranges,
                ChunkRangesSeq::from_ranges([ChunkRanges::chunk(0) | ChunkRanges::chunk(2)])
            );
        }

        #[test]
        #[should_panic(expected = "child index")]
        fn child_index_overflow_panics() {
            let _ = GetRequest::builder().child(u64::MAX, ChunkRanges::all());
        }

        #[test]
        fn get_many_request_builder() {
            let hash1 = [0; 32];
            let hash2 = [1; 32];
            let hash3 = [2; 32];
            let request = GetManyRequest::builder()
                .hash(hash1, ChunkRanges::all())
                .hash(hash2, ChunkRanges::empty())
                .hash(hash3, ChunkRanges::bytes(0..100))
                .build();
            assert_eq!(
                request.hashes,
                vec![Hash::from([0; 32]), Hash::from([2; 32])]
            );
            assert_eq!(
                request.ranges,
                ChunkRangesSeq::from_ranges([
                    ChunkRanges::all(),
                    ChunkRanges::from(..ChunkNum(1)),
                ])
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use iroh_test::{assert_eq_hex, hexdump::parse_hexdump};
    use postcard::experimental::max_size::MaxSize;

    use super::{GetRequest, Request, RequestType};
    use crate::Hash;

    /// Pins the postcard encoding of [`Request`] values against annotated hexdumps.
    #[test]
    fn request_wire_format() {
        let hash = Hash::from([0xda; 32]);
        let blob_request = Request::from(GetRequest::blob(hash));
        let all_request = Request::from(GetRequest::all(hash));
        let cases = [
            (
                blob_request,
                r"
00 # enum variant for GetRequest
dadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadada # the hash
020001000100 # the ChunkRangesSeq
",
            ),
            (
                all_request,
                r"
00 # enum variant for GetRequest
dadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadadada # the hash
01000100 # the ChunkRangesSeq
",
            ),
        ];
        for (request, annotated) in &cases {
            let want = parse_hexdump(annotated).unwrap();
            let got = postcard::to_stdvec(request).unwrap();
            assert_eq_hex!(got, want);
        }
    }

    /// `Request::read_async` reads exactly one byte for the request type.
    #[test]
    fn request_type_size() {
        let max_size = <RequestType as MaxSize>::POSTCARD_MAX_SIZE;
        assert_eq!(max_size, 1);
    }
}