use std::io::{self, Write};
use smallvec::SmallVec;
use hemera::OUTPUT_BYTES;
use crate::hash::{HashBackend, Poseidon2Backend};
use crate::io::error::EncodeError;
use crate::io::traits::{Outboard, ReadAt};
use crate::tree::{ChunkNum, CHUNK_SIZE};
use crate::ChunkRangesRef;
/// Size in bytes of a serialized (left, right) hash pair as written to the encoded stream.
const PAIR_SIZE: usize = OUTPUT_BYTES * 2;
/// Encode the requested `ranges` of `data` into `encoded`, validating every
/// parent pair and leaf against the `outboard` hashes as it goes.
///
/// The traversal is pre-order: each parent writes its (left, right) hash pair,
/// each leaf writes its raw bytes. A stack of expected hashes, seeded with the
/// root, is popped for every visited chunk; any mismatch aborts with a
/// `ParentHashMismatch` / `LeafHashMismatch` error before the offending bytes
/// are written.
pub fn encode_ranges_validated<D: ReadAt, O: Outboard<Hash = hemera::Hash>, W: Write>(
    data: D,
    outboard: O,
    ranges: &ChunkRangesRef,
    mut encoded: W,
) -> Result<(), EncodeError> {
    // Nothing requested -> nothing to encode.
    if ranges.is_empty() {
        return Ok(());
    }
    let backend = Poseidon2Backend;
    let tree = outboard.tree();
    let block_size = tree.block_size();
    // Expected-hash stack, seeded with the root hash.
    let mut expected_stack = SmallVec::<[hemera::Hash; 10]>::new();
    expected_stack.push(outboard.root());
    for item in &tree.pre_order_chunks_filtered(ranges) {
        match item {
            crate::tree::BaoChunk::Parent { node, is_root, left, right } => {
                let (lh, rh) = outboard.load(*node)?.unwrap();
                let expected = expected_stack.pop().unwrap();
                let computed = backend.parent_hash(&lh, &rh, *is_root);
                if computed != expected {
                    return Err(EncodeError::ParentHashMismatch(*node));
                }
                // Push right before left so the left subtree is validated first.
                if *right {
                    expected_stack.push(rh.clone());
                }
                if *left {
                    expected_stack.push(lh.clone());
                }
                encoded.write_all(&combine_hash_pair(&lh, &rh))?;
            }
            crate::tree::BaoChunk::Leaf {
                start_chunk,
                size,
                is_root,
            } => {
                let offset = *start_chunk * CHUNK_SIZE as u64;
                let mut chunk_data = vec![0u8; *size];
                data.read_exact_at(offset, &mut chunk_data)?;
                let expected = expected_stack.pop().unwrap();
                let computed = super::hash_block(
                    &backend,
                    &chunk_data,
                    *start_chunk,
                    *is_root,
                    block_size.bytes(),
                );
                if computed != expected {
                    return Err(EncodeError::LeafHashMismatch(ChunkNum(*start_chunk)));
                }
                encoded.write_all(&chunk_data)?;
            }
        }
    }
    Ok(())
}
/// Lazily walk the tree and yield the chunk ranges of `data` that hash to the
/// values recorded in `outboard`, restricted to the requested `ranges`.
///
/// The result is a generator-backed iterator; a traversal failure surfaces as
/// a final `Err` item instead of aborting the iteration type.
pub fn valid_ranges<'a, O, D>(
    outboard: O,
    data: D,
    ranges: &'a ChunkRangesRef,
) -> impl IntoIterator<Item = io::Result<std::ops::Range<ChunkNum>>> + 'a
where
    O: Outboard<Hash = hemera::Hash> + 'a,
    D: ReadAt + 'a,
{
    genawaiter::sync::Gen::new(move |co| async move {
        match validate_ranges_impl(outboard, data, ranges, &co).await {
            Ok(()) => {}
            // Report the error to the consumer as the last yielded item.
            Err(cause) => co.yield_(Err(cause)).await,
        }
    })
}
/// Clip `ranges` for a tree holding `size` bytes by keeping only a leading
/// prefix of its boundaries (see [`truncated_len`] for the exact rule).
pub fn truncate_ranges(ranges: &ChunkRangesRef, size: u64) -> &ChunkRangesRef {
    let keep = truncated_len(ranges, size);
    ChunkRangesRef::new_unchecked(&ranges.boundaries()[..keep])
}
/// Number of leading boundary entries of `ranges` to keep for a tree holding
/// `size` bytes, so that anything at or past the last chunk collapses into a
/// single trailing open range.
///
/// Boundary layout (inferred from the parity checks below — TODO confirm
/// against `ChunkRangesRef`): entries at even indices open a range, entries at
/// odd indices close one; an odd kept-length therefore ends in an open range.
fn truncated_len(ranges: &ChunkRangesRef, size: u64) -> usize {
    // Index of the last chunk that actually exists (0 when size == 0).
    let end = ChunkNum::chunks(size);
    let lc = ChunkNum(end.0.saturating_sub(1));
    let bs = ranges.boundaries();
    match bs.binary_search(&lc) {
        // Last chunk is exactly a range start: keep that start and drop the
        // matching end, leaving the final range open-ended.
        Ok(i) if (i & 1) == 0 => i + 1,
        // Last chunk is exactly a range end (so it is NOT inside that range).
        Ok(i) => {
            if bs.len() == i + 1 {
                // It is the final boundary: keep everything as-is.
                i + 1
            } else {
                // More ranges lie behind it: drop this end boundary so the
                // kept prefix terminates in an open range covering the rest.
                i
            }
        }
        // Last chunk falls outside any range (insertion point at a start).
        Err(ip) if (ip & 1) == 0 => {
            if bs.len() == ip {
                // Nothing at or beyond the last chunk: keep all boundaries.
                ip
            } else {
                // A later range starts past the end: keep its start boundary
                // only, turning it into an open trailing range (presumably so
                // an out-of-bounds request still elicits a size proof — TODO
                // confirm intended semantics).
                ip + 1
            }
        }
        // Last chunk falls inside a range: keep its start, drop its end, so
        // the containing range becomes open-ended.
        Err(ip) => ip,
    }
}
/// Serialize a (left, right) hash pair into one contiguous byte array, left
/// hash first.
fn combine_hash_pair(l: &hemera::Hash, r: &hemera::Hash) -> [u8; PAIR_SIZE] {
    let mut out = [0u8; PAIR_SIZE];
    let (left_half, right_half) = out.split_at_mut(OUTPUT_BYTES);
    left_half.copy_from_slice(l.as_bytes());
    right_half.copy_from_slice(r.as_bytes());
    out
}
/// Hash `data` as a subtree whose first chunk has index `start_chunk`.
///
/// Input that fits in one chunk is hashed directly as a chunk. Larger input is
/// split into chunks, each hashed as a non-root leaf with its own chunk
/// counter, and the hashes are combined pairwise bottom-up; an unpaired
/// trailing hash is carried up unchanged, producing the left-leaning shape.
///
/// Bug fix: `is_root` was previously dropped on the multi-chunk path — every
/// `parent_hash` call passed `false`, so a root subtree spanning more than one
/// chunk (e.g. a single-block tree with block size > 1 chunk) never matched
/// the real root hash. The topmost combine (the one reducing the final two
/// hashes to one) now carries `is_root`, mirroring the single-chunk path,
/// which already honored the flag.
fn hash_subtree(
    backend: &Poseidon2Backend,
    start_chunk: u64,
    data: &[u8],
    is_root: bool,
) -> hemera::Hash {
    const CHUNK_LEN: usize = CHUNK_SIZE;
    if data.len() <= CHUNK_LEN {
        return backend.chunk_hash(data, start_chunk, is_root);
    }
    // Leaf level: one hash per chunk; never root here since the subtree has
    // more than one chunk.
    let mut level: Vec<hemera::Hash> = data
        .chunks(CHUNK_LEN)
        .enumerate()
        .map(|(i, chunk)| backend.chunk_hash(chunk, start_chunk + i as u64, false))
        .collect();
    // Reduce pairwise until one hash remains. A level of exactly two hashes
    // performs the final combine, which is the subtree root.
    while level.len() > 1 {
        let root_combine = is_root && level.len() == 2;
        let mut next = Vec::with_capacity(level.len().div_ceil(2));
        for pair in level.chunks(2) {
            if let [l, r] = pair {
                next.push(backend.parent_hash(l, r, root_combine));
            } else {
                // Odd element out: carry it up unchanged.
                next.push(pair[0].clone());
            }
        }
        level = next;
    }
    level.into_iter().next().unwrap()
}
/// Walk the tree and yield (via `co`) each chunk range whose data matches the
/// outboard hashes. Unreadable or mismatching regions yield nothing; only a
/// traversal error is returned as `Err`.
async fn validate_ranges_impl<O, D>(
    outboard: O,
    data: D,
    ranges: &ChunkRangesRef,
    co: &genawaiter::sync::Co<io::Result<std::ops::Range<ChunkNum>>>,
) -> io::Result<()>
where
    O: Outboard<Hash = hemera::Hash>,
    D: ReadAt,
{
    let backend = Poseidon2Backend;
    let tree = outboard.tree();
    // Empty tree: nothing can be valid.
    if tree.blocks() == 0 {
        return Ok(());
    }
    // Single-block tree: the root hash covers the whole blob, so verify it in
    // one shot instead of recursing.
    if tree.blocks() == 1 {
        let sz: usize = tree.size().try_into().unwrap();
        let mut tmp = vec![0u8; sz];
        // A short or failing read means the data cannot be validated; report
        // nothing rather than erroring.
        if data.read_exact_at(0, &mut tmp).is_err() {
            return Ok(());
        }
        let actual = hash_subtree(&backend, 0, &tmp, true);
        if actual == outboard.root() {
            // NOTE(review): `ranges` is ignored on this path — the full range
            // is yielded even if the caller asked for a subset; confirm this
            // is intended.
            co.yield_(Ok(ChunkNum(0)..tree.chunks())).await;
        }
        return Ok(());
    }
    // Clip the requested ranges to the actual tree size before recursing from
    // the root with the root flag set.
    let ranges = truncate_ranges(ranges, tree.size());
    validate_node_rec(
        &outboard, &data, &backend, &tree, tree.root(), outboard.root(),
        true, ranges, co,
    ).await;
    Ok(())
}
/// Recursively validate `node` against `expected_hash`, yielding (via `co`)
/// each leaf chunk range whose data hashes to the expected value.
///
/// Subtrees disjoint from `ranges` are skipped; invalid or unreadable regions
/// are skipped silently (nothing is yielded for them).
async fn validate_node_rec<O, D>(
    outboard: &O,
    data: &D,
    backend: &Poseidon2Backend,
    tree: &crate::tree::BaoTree,
    node: crate::tree::TreeNode,
    expected_hash: hemera::Hash,
    is_root: bool,
    ranges: &ChunkRangesRef,
    co: &genawaiter::sync::Co<io::Result<std::ops::Range<ChunkNum>>>,
)
where
    O: Outboard<Hash = hemera::Hash>,
    D: ReadAt,
{
    use range_collections::RangeSet2;
    // Prune subtrees the caller did not ask about.
    let actual_range = tree.node_actual_chunk_range(node);
    let node_chunks = RangeSet2::from(actual_range.start..actual_range.end);
    if node_chunks.is_disjoint(ranges) {
        return;
    }
    let level = node.level();
    if level == 0 {
        // Leaf block. Level-0 nodes appear to have even indices, so node.0 / 2
        // is the block index — assumes in-order node numbering; TODO confirm
        // against TreeNode.
        let block_idx = node.0 / 2;
        let block_bytes = tree.block_size().bytes() as u64;
        let start_byte = block_idx * block_bytes;
        let end_byte = ((block_idx + 1) * block_bytes).min(tree.size());
        // A block starting at or past the end of the data is empty.
        let size = if start_byte >= tree.size() {
            0
        } else {
            (end_byte - start_byte) as usize
        };
        let start_chunk = block_idx * (1u64 << tree.block_size().0);
        let mut buf = vec![0u8; size];
        // Unreadable data: treat as invalid, yield nothing.
        if data.read_exact_at(start_byte, &mut buf).is_err() {
            return;
        }
        let actual = hash_subtree(backend, start_chunk, &buf, is_root);
        if actual == expected_hash {
            let chunks_per_block = 1u64 << tree.block_size().chunk_log();
            // Bug fix: clamp the reported end to the real chunk count so a
            // valid partial final block does not claim nonexistent chunks
            // (matches the single-block path, which yields ..tree.chunks()).
            let leaf_end_chunk = (start_chunk + chunks_per_block).min(tree.chunks().0);
            co.yield_(Ok(ChunkNum(start_chunk)..ChunkNum(leaf_end_chunk))).await;
        }
        return;
    }
    // An interior node whose right child lies entirely past the end of the
    // data has no stored parent pair: its hash equals its left child's, so
    // descend left with the same expected hash and root flag.
    // NOTE(review): the `/ 2` assumes chunk_range().start maps to a leaf node
    // index rather than a chunk count — confirm against TreeNode::chunk_range.
    let right_exists = node.right_child().is_some_and(|rc| {
        let right_block_start = rc.chunk_range().start.0 / 2;
        right_block_start < tree.blocks()
    });
    if !right_exists {
        if let Some(left) = node.left_child() {
            Box::pin(validate_node_rec(
                outboard, data, backend, tree, left, expected_hash,
                is_root, ranges, co,
            )).await;
        }
        return;
    }
    // Load the stored (left, right) pair; a missing pair or load error makes
    // the whole subtree unverifiable.
    let (l_hash, r_hash) = match outboard.load(node) {
        Ok(Some(pair)) => pair,
        _ => return,
    };
    // A mismatching parent invalidates both children.
    if backend.parent_hash(&l_hash, &r_hash, is_root) != expected_hash {
        return;
    }
    if let Some(left) = node.left_child() {
        Box::pin(validate_node_rec(
            outboard, data, backend, tree, left, l_hash,
            false, ranges, co,
        )).await;
    }
    if let Some(right) = node.right_child() {
        Box::pin(validate_node_rec(
            outboard, data, backend, tree, right, r_hash,
            false, ranges, co,
        )).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::pre_order::PreOrderMemOutboard;
    use crate::tree::{BlockSize, ChunkNum};
    use crate::ChunkRanges;

    /// Two whole chunks with chunk-sized blocks: the encoding is the 8-byte
    /// size prefix, one parent hash pair, and the raw chunk data.
    #[test]
    fn encode_ranges_validated_full() {
        let data = vec![0x42u8; CHUNK_SIZE * 2];
        let outboard = PreOrderMemOutboard::create(&data, BlockSize::ZERO);
        let size = data.len() as u64;
        let mut encoded = Vec::new();
        encoded.extend_from_slice(&size.to_le_bytes());
        encode_ranges_validated(&data[..], &outboard, &ChunkRanges::all(), &mut encoded)
            .expect("encode should succeed");
        let expected_len = 8 + PAIR_SIZE + CHUNK_SIZE * 2;
        assert_eq!(encoded.len(), expected_len);
    }

    /// Pristine data: every chunk must be reported valid.
    #[test]
    fn valid_ranges_all_valid() {
        let data = vec![0x42u8; CHUNK_SIZE * 2];
        let outboard = PreOrderMemOutboard::create(&data, BlockSize::ZERO);
        let mut valid = ChunkRanges::empty();
        valid_ranges(&outboard, &data[..], &ChunkRanges::all())
            .into_iter()
            .flatten()
            .for_each(|r| valid |= ChunkRanges::from(r));
        assert_eq!(valid, ChunkRanges::from(ChunkNum(0)..ChunkNum(2)));
    }

    /// An empty blob has no valid chunks.
    #[test]
    fn valid_ranges_empty_data() {
        let data: Vec<u8> = vec![];
        let outboard = PreOrderMemOutboard::create(&data, BlockSize::ZERO);
        let mut valid = ChunkRanges::empty();
        valid_ranges(&outboard, &data[..], &ChunkRanges::all())
            .into_iter()
            .flatten()
            .for_each(|r| valid |= ChunkRanges::from(r));
        assert_eq!(valid, ChunkRanges::empty());
    }

    /// Four pristine chunks: the union of valid ranges covers all of them.
    #[test]
    fn encode_then_valid_ranges_roundtrip() {
        let data = vec![0xABu8; CHUNK_SIZE * 4];
        let outboard = PreOrderMemOutboard::create(&data, BlockSize::ZERO);
        let mut valid = ChunkRanges::empty();
        valid_ranges(&outboard, &data[..], &ChunkRanges::all())
            .into_iter()
            .flatten()
            .for_each(|r| valid |= ChunkRanges::from(r));
        assert_eq!(valid, ChunkRanges::from(ChunkNum(0)..ChunkNum(4)));
    }

    /// 8 chunks at chunk_log = 1 -> 4 blocks -> 3 parent pairs plus raw data.
    #[test]
    fn encode_ranges_block_size_nonzero() {
        let bs = BlockSize::from_chunk_log(1);
        let data = vec![0x42u8; CHUNK_SIZE * 8];
        let outboard = PreOrderMemOutboard::create(&data, bs);
        let mut encoded = Vec::new();
        encode_ranges_validated(&data[..], &outboard, &ChunkRanges::all(), &mut encoded)
            .expect("encode should succeed");
        assert_eq!(encoded.len(), 3 * PAIR_SIZE + CHUNK_SIZE * 8);
    }

    /// A partial range inside a large-block tree encodes without error.
    #[test]
    fn encode_ranges_partial_large_block() {
        let bs = BlockSize::from_chunk_log(4);
        let data: Vec<u8> = (0..100000u64).map(|i| (i % 251) as u8).collect();
        let outboard = PreOrderMemOutboard::create(&data, bs);
        let ranges = ChunkRanges::from(ChunkNum(16)..ChunkNum(32));
        let mut encoded = Vec::new();
        encode_ranges_validated(&data[..], &outboard, &ranges, &mut encoded)
            .expect("encode should succeed for partial ranges");
        assert!(!encoded.is_empty());
    }

    /// Validation also covers every chunk with a non-trivial block size.
    #[test]
    fn valid_ranges_block_size_nonzero() {
        let bs = BlockSize::from_chunk_log(1);
        let data = vec![0x42u8; CHUNK_SIZE * 8];
        let outboard = PreOrderMemOutboard::create(&data, bs);
        let mut valid = ChunkRanges::empty();
        valid_ranges(&outboard, &data[..], &ChunkRanges::all())
            .into_iter()
            .flatten()
            .for_each(|r| valid |= ChunkRanges::from(r));
        assert_eq!(valid, ChunkRanges::from(ChunkNum(0)..ChunkNum(8)));
    }
}