use std::fmt;
use std::ops::Range;
use hemera::OUTPUT_BYTES;
/// Number of bytes in one chunk; `ChunkNum` values are converted to byte
/// offsets by multiplying with this constant.
pub const CHUNK_SIZE: usize = 4096;
/// Byte size of two `hemera` outputs stored back to back. `BaoTree::outboard_size`
/// charges this amount once per parent node (presumably a left/right hash pair).
const PAIR_SIZE: usize = OUTPUT_BYTES * 2;
/// Index of a node in a binary tree whose nodes are numbered in-order.
///
/// Even indices are leaves and odd indices are parents. The number of
/// trailing one bits of the index is the node's level, with leaves at level 0.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TreeNode(pub(crate) u64);

impl TreeNode {
    /// Wraps a raw in-order index.
    pub fn new(index: u64) -> Self {
        Self(index)
    }

    /// Level of the node: the count of trailing one bits of its index.
    pub fn level(self) -> u32 {
        self.0.trailing_ones()
    }

    /// Returns `true` for level-0 nodes, i.e. even indices.
    pub fn is_leaf(self) -> bool {
        (self.0 & 1) == 0
    }

    /// Chunk number at the midpoint of [`TreeNode::chunk_range`].
    pub fn mid(self) -> ChunkNum {
        ChunkNum(self.0 + 1)
    }

    /// Left child, or `None` for a leaf.
    pub fn left_child(self) -> Option<Self> {
        // A child sits half of this node's span away from it.
        let half_span = 1u64 << self.level().checked_sub(1)?;
        Some(Self(self.0 - half_span))
    }

    /// Right child, or `None` for a leaf.
    pub fn right_child(self) -> Option<Self> {
        let half_span = 1u64 << self.level().checked_sub(1)?;
        Some(Self(self.0 + half_span))
    }

    /// Parent of this node in the unbounded tree.
    ///
    /// The bit just above the trailing ones tells whether this node is the
    /// left child (bit clear, parent is to the right) or the right child
    /// (bit set, parent is to the left).
    pub fn parent(self) -> Option<Self> {
        let level = self.level();
        let span = 1u64 << level;
        let offset = self.0;
        if offset & (span << 1) == 0 {
            Some(Self(offset + span))
        } else {
            offset.checked_sub(span).map(Self)
        }
    }

    /// Range of chunk numbers covered by this node: `span` chunks on either
    /// side of `mid`, where `span = 2^level`.
    pub fn chunk_range(self) -> Range<ChunkNum> {
        let level = self.level();
        let span = 1u64 << level;
        let mid = self.0 + 1;
        ChunkNum(mid - span)..ChunkNum(mid + span)
    }

    /// [`TreeNode::chunk_range`] converted to byte offsets.
    pub fn byte_range(self) -> Range<u64> {
        let range = self.chunk_range();
        range.start.to_bytes()..range.end.to_bytes()
    }

    /// First node below the right child that is strictly less than `len`.
    ///
    /// Starts at the right child and keeps stepping to the left child while
    /// the candidate is `>= len`. Returns `None` for a leaf, or when the walk
    /// reaches a leaf that is still `>= len`.
    pub fn right_descendant(self, len: Self) -> Option<Self> {
        let mut node = self.right_child()?;
        while node >= len {
            node = node.left_child()?;
        }
        Some(node)
    }

    /// Number of nodes strictly below this node in a full subtree:
    /// `2^(level + 1) - 2`, which is 0 for a leaf.
    pub fn count_below(self) -> u64 {
        (1u64 << (self.level() + 1)) - 2
    }

    /// Zero-based position of this node in a post-order traversal of the
    /// full tree.
    ///
    /// A node is visited after everything in its own subtree
    /// ([`TreeNode::count_below`]) and after every node that lies entirely to
    /// the left of that subtree. The subtree starts at in-order index
    /// `start`, so `start` nodes have a smaller in-order index. Of those,
    /// the ancestors of this node are visited later, and there is exactly one
    /// such ancestor per set bit of `start`.
    pub fn post_order_offset(self) -> u64 {
        let span = 1u64 << self.level();
        let start = self.0 + 1 - span;
        let visited_before_subtree = start - u64::from(start.count_ones());
        visited_before_subtree + self.count_below()
    }

    /// Drops the `n` lowest levels: returns the node shifted right by `n`
    /// bits, or `None` when the node's level is below `n`.
    pub const fn add_block_size(self, n: u8) -> Option<Self> {
        let mask = (1u64 << n) - 1;
        // The low `n` bits are all ones exactly when level >= n.
        if self.0 & mask == mask {
            Some(Self(self.0 >> n))
        } else {
            None
        }
    }

    /// Inverse of [`TreeNode::add_block_size`]: shifts the index left by `n`
    /// bits and fills the vacated low bits with ones.
    pub const fn subtract_block_size(self, n: u8) -> Self {
        let shifted = !(!self.0 << n);
        Self(shifted)
    }
}
impl fmt::Debug for TreeNode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "TreeNode({}, level={})", self.0, self.level())
}
}
impl fmt::Display for TreeNode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// A count of, or index into, `CHUNK_SIZE`-byte chunks.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ChunkNum(pub u64);

impl ChunkNum {
    /// Number of chunks needed to hold `size` bytes, counting a trailing
    /// partial chunk as one. Zero bytes need zero chunks.
    pub fn chunks(size: u64) -> Self {
        let chunk = CHUNK_SIZE as u64;
        let whole = size / chunk;
        let partial = u64::from(size % chunk != 0);
        Self(whole + partial)
    }

    /// Number of completely filled chunks within `size` bytes.
    pub fn full_chunks(size: u64) -> Self {
        let chunk = CHUNK_SIZE as u64;
        Self(size / chunk)
    }

    /// Byte offset at which this chunk starts.
    pub fn to_bytes(self) -> u64 {
        let Self(count) = self;
        count * CHUNK_SIZE as u64
    }
}
impl fmt::Debug for ChunkNum {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ChunkNum({})", self.0)
}
}
impl fmt::Display for ChunkNum {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl std::ops::Add for ChunkNum {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self(self.0 + rhs.0)
}
}
impl std::ops::Add<u64> for ChunkNum {
type Output = Self;
fn add(self, rhs: u64) -> Self {
Self(self.0 + rhs)
}
}
impl std::ops::Sub for ChunkNum {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self(self.0 - rhs.0)
}
}
impl std::ops::Sub<u64> for ChunkNum {
type Output = Self;
fn sub(self, rhs: u64) -> Self {
Self(self.0 - rhs)
}
}
/// Lets `ChunkNum` be used as the element type of a range set; chunk 0 is
/// the smallest value.
impl range_collections::range_set::RangeSetEntry for ChunkNum {
    fn min_value() -> Self {
        Self(0)
    }

    fn is_min_value(&self) -> bool {
        matches!(*self, Self(0))
    }
}
/// Block size expressed as a base-2 logarithm of the number of chunks per
/// block: a block holds `2^log` chunks, i.e. `CHUNK_SIZE << log` bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BlockSize(pub u8);

impl BlockSize {
    /// One chunk per block.
    pub const ZERO: Self = Self(0);
    /// Four chunks per block.
    pub const DEFAULT: Self = Self(2);

    /// Builds a block size from the log2 of its chunk count.
    pub const fn from_chunk_log(log: u8) -> Self {
        Self(log)
    }

    /// Builds a block size from a byte count.
    ///
    /// Returns `None` unless `bytes` is a power of two that is at least
    /// `CHUNK_SIZE`.
    pub fn from_bytes(bytes: u64) -> Option<Self> {
        let chunk = CHUNK_SIZE as u64;
        if bytes.is_power_of_two() && bytes >= chunk {
            let chunks_per_block = bytes / chunk;
            Some(Self(chunks_per_block.trailing_zeros() as u8))
        } else {
            None
        }
    }

    /// Size of one block in bytes.
    pub fn bytes(self) -> usize {
        let Self(log) = self;
        CHUNK_SIZE << log
    }

    /// The log2 of the number of chunks per block.
    pub fn chunk_log(self) -> u8 {
        let Self(log) = self;
        log
    }

    /// Same value as [`BlockSize::chunk_log`], widened to `u32`.
    pub fn to_u32(self) -> u32 {
        u32::from(self.chunk_log())
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BaoTree {
size: u64,
block_size: BlockSize,
}
impl BaoTree {
pub fn new(size: u64, block_size: BlockSize) -> Self {
Self { size, block_size }
}
pub fn size(&self) -> u64 {
self.size
}
pub fn block_size(&self) -> BlockSize {
self.block_size
}
pub fn chunks(&self) -> ChunkNum {
ChunkNum::chunks(self.size)
}
pub fn blocks(&self) -> u64 {
let chunk_count = self.chunks().0;
let chunks_per_block = 1u64 << self.block_size.0;
chunk_count.div_ceil(chunks_per_block).max(1)
}
pub fn root(&self) -> TreeNode {
let blocks = self.blocks();
if blocks <= 1 {
return TreeNode(0);
}
TreeNode(blocks.next_power_of_two() - 1)
}
pub fn outboard_size(&self) -> u64 {
let blocks = self.blocks();
if blocks <= 1 {
return 0;
}
(blocks - 1) * PAIR_SIZE as u64
}
pub fn chunk_group_bytes(&self) -> usize {
self.block_size.bytes()
}
pub fn shifted(&self) -> (TreeNode, TreeNode) {
let level = self.block_size.0;
let size = self.size;
let shift = CHUNK_SIZE.trailing_zeros() + level as u32;
let mask = (1u64 << shift) - 1;
let full_blocks = size >> shift;
let open_block = u64::from((size & mask) != 0);
let blocks = (full_blocks + open_block).max(1);
let n = blocks.div_ceil(2);
let root = n.next_power_of_two() - 1;
let filled_size = n + n.saturating_sub(1);
(TreeNode(root), TreeNode(filled_size))
}
pub fn pre_order_offset(&self, node: TreeNode) -> Option<u64> {
let pre_order = self.pre_order_chunks();
let mut parent_idx = 0u64;
for chunk in &pre_order {
if let BaoChunk::Parent { node: n, .. } = chunk {
if *n == node {
return Some(parent_idx);
}
parent_idx += 1;
}
}
None
}
pub fn post_order_offset(&self, node: TreeNode) -> Option<u64> {
let post_order = self.post_order_chunks();
let mut parent_idx = 0u64;
for chunk in &post_order {
if let BaoChunk::Parent { node: n, .. } = chunk {
if *n == node {
return Some(parent_idx);
}
parent_idx += 1;
}
}
None
}
pub fn post_order_chunks_iter(&self) -> PostOrderChunkIter {
PostOrderChunkIter {
chunks: self.post_order_chunks(),
pos: 0,
}
}
pub fn leaf_byte_ranges3(&self, node: TreeNode) -> (u64, u64, u64) {
let Range { start, end } = node.byte_range();
let mid = node.mid().to_bytes();
(start, mid.min(self.size), end.min(self.size))
}
pub fn node_actual_chunk_range(&self, node: TreeNode) -> Range<ChunkNum> {
let raw = node.chunk_range();
let cpb = 1u64 << self.block_size.0;
ChunkNum((raw.start.0 / 2) * cpb)..ChunkNum((raw.end.0 / 2) * cpb)
}
pub fn is_relevant_for_outboard(&self, node: TreeNode) -> bool {
let level = node.level();
let bs = self.block_size.to_u32();
if level < bs {
false
} else if level > bs {
true
} else {
node.mid().to_bytes() < self.size
}
}
}
/// Iterator that hands out clones of a precomputed list of chunks in order.
#[derive(Debug)]
pub struct PostOrderChunkIter {
    chunks: Vec<BaoChunk>,
    pos: usize,
}

impl Iterator for PostOrderChunkIter {
    type Item = BaoChunk;

    fn next(&mut self) -> Option<BaoChunk> {
        // `get` yields `None` once `pos` has reached the end, leaving `pos`
        // untouched from then on.
        let item = self.chunks.get(self.pos)?.clone();
        self.pos += 1;
        Some(item)
    }
}
/// One item produced by the `BaoTree` traversals: either a parent node or a
/// leaf block of data.
#[derive(Debug, Clone)]
pub enum BaoChunk {
/// A parent node with two children.
Parent {
// Position of the parent in the tree.
node: TreeNode,
// True when this is the top-most item emitted by the traversal.
is_root: bool,
// `left` / `right`: always true in the unfiltered traversals; in
// `pre_order_chunks_filtered` they record whether the respective child's
// chunk range overlaps the requested ranges.
left: bool,
right: bool,
},
/// A leaf covering one block of data.
Leaf {
// Chunk number at which the block starts (block index times chunks per block).
start_chunk: u64,
// Number of data bytes in the block; the last block may be shorter.
size: usize,
// True when the whole tree consists of this single leaf.
is_root: bool,
},
}
impl BaoTree {
    /// Lists every parent and leaf of the tree, each parent before its
    /// children.
    ///
    /// A single-block tree yields one root leaf and no parents.
    pub fn pre_order_chunks(&self) -> Vec<BaoChunk> {
        let blocks = self.blocks();
        match blocks {
            0 => Vec::new(),
            1 => vec![self.whole_input_leaf()],
            _ => {
                let mut items = Vec::new();
                self.pre_order_recurse(self.root(), true, blocks, &mut items);
                items
            }
        }
    }

    /// Appends the subtree rooted at `node` to `out`, parent first.
    ///
    /// A node whose right subtree lies entirely past the last block is not
    /// emitted; its left child takes its place and inherits `is_root`.
    fn pre_order_recurse(
        &self,
        node: TreeNode,
        is_root: bool,
        total_blocks: u64,
        out: &mut Vec<BaoChunk>,
    ) {
        if node.level() == 0 {
            out.push(self.block_leaf(node, is_root));
            return;
        }
        let left_child = node.left_child();
        if !Self::right_subtree_in_bounds(node, total_blocks) {
            if let Some(left) = left_child {
                self.pre_order_recurse(left, is_root, total_blocks, out);
            }
            return;
        }
        out.push(BaoChunk::Parent {
            node,
            is_root,
            left: true,
            right: true,
        });
        if let Some(left) = left_child {
            self.pre_order_recurse(left, false, total_blocks, out);
        }
        if let Some(right) = node.right_child() {
            self.pre_order_recurse(right, false, total_blocks, out);
        }
    }

    /// Lists every parent and leaf of the tree, each parent after its
    /// children.
    ///
    /// A single-block tree yields one root leaf and no parents.
    pub fn post_order_chunks(&self) -> Vec<BaoChunk> {
        let blocks = self.blocks();
        match blocks {
            0 => Vec::new(),
            1 => vec![self.whole_input_leaf()],
            _ => {
                let mut items = Vec::new();
                self.post_order_recurse(self.root(), true, blocks, &mut items);
                items
            }
        }
    }

    /// Appends the subtree rooted at `node` to `out`, parent last.
    ///
    /// A node whose right subtree lies entirely past the last block is not
    /// emitted; its left child takes its place and inherits `is_root`.
    fn post_order_recurse(
        &self,
        node: TreeNode,
        is_root: bool,
        total_blocks: u64,
        out: &mut Vec<BaoChunk>,
    ) {
        if node.level() == 0 {
            out.push(self.block_leaf(node, is_root));
            return;
        }
        let left_child = node.left_child();
        if !Self::right_subtree_in_bounds(node, total_blocks) {
            if let Some(left) = left_child {
                self.post_order_recurse(left, is_root, total_blocks, out);
            }
            return;
        }
        if let Some(left) = left_child {
            self.post_order_recurse(left, false, total_blocks, out);
        }
        if let Some(right) = node.right_child() {
            self.post_order_recurse(right, false, total_blocks, out);
        }
        out.push(BaoChunk::Parent {
            node,
            is_root,
            left: true,
            right: true,
        });
    }

    /// The only item of a single-block tree: a root leaf holding all data.
    fn whole_input_leaf(&self) -> BaoChunk {
        let block_bytes = self.block_size.bytes() as u64;
        BaoChunk::Leaf {
            start_chunk: 0,
            size: self.size.min(block_bytes) as usize,
            is_root: true,
        }
    }

    /// Builds the leaf item for the level-0 node `node`.
    fn block_leaf(&self, node: TreeNode, is_root: bool) -> BaoChunk {
        // Leaves sit at even indices, so halving gives the block index.
        let block_idx = node.0 / 2;
        let block_bytes = self.block_size.bytes() as u64;
        let start_byte = block_idx * block_bytes;
        let end_byte = ((block_idx + 1) * block_bytes).min(self.size);
        // `end_byte` is clamped to the data size, so a block starting at or
        // past the end saturates to zero bytes.
        let size = end_byte.saturating_sub(start_byte) as usize;
        BaoChunk::Leaf {
            start_chunk: block_idx * (1u64 << self.block_size.0),
            size,
            is_root,
        }
    }

    /// Whether the right child of `node` starts at a block index below
    /// `total_blocks`. Always `false` for a leaf.
    fn right_subtree_in_bounds(node: TreeNode, total_blocks: u64) -> bool {
        match node.right_child() {
            Some(right) => {
                let first_block = right.chunk_range().start.0 / 2;
                first_block < total_blocks
            }
            None => false,
        }
    }
}
impl BaoTree {
    /// Pre-order traversal restricted to the parts of the tree whose chunk
    /// ranges overlap `ranges`.
    ///
    /// Returns an empty list when `ranges` is empty or nothing overlaps. Each
    /// emitted parent records in `left` / `right` which children overlap.
    pub fn pre_order_chunks_filtered(
        &self,
        ranges: &range_collections::RangeSetRef<ChunkNum>,
    ) -> Vec<BaoChunk> {
        let blocks = self.blocks();
        if blocks == 0 || ranges.is_empty() {
            return Vec::new();
        }
        if blocks > 1 {
            let mut items = Vec::new();
            self.pre_order_filtered_recurse(self.root(), true, blocks, ranges, &mut items);
            return items;
        }

        // Single block: the only candidate is the root leaf, which spans one
        // full block worth of chunks.
        let chunks_per_block = 1u64 << self.block_size.0;
        let whole_block =
            range_collections::RangeSet2::from(ChunkNum(0)..ChunkNum(chunks_per_block));
        if whole_block.is_disjoint(ranges) {
            return Vec::new();
        }
        let size = self.size.min(self.block_size.bytes() as u64) as usize;
        vec![BaoChunk::Leaf {
            start_chunk: 0,
            size,
            is_root: true,
        }]
    }

    /// Appends the part of the subtree rooted at `node` that overlaps
    /// `ranges` to `out`, parent first.
    ///
    /// A node whose right subtree lies entirely past the last block is not
    /// emitted; its left child takes its place and inherits `is_root`.
    fn pre_order_filtered_recurse(
        &self,
        node: TreeNode,
        is_root: bool,
        total_blocks: u64,
        ranges: &range_collections::RangeSetRef<ChunkNum>,
        out: &mut Vec<BaoChunk>,
    ) {
        // True when the chunk range covered by `candidate` intersects `ranges`.
        let overlaps = |candidate: TreeNode| -> bool {
            let span = self.node_actual_chunk_range(candidate);
            let covered = range_collections::RangeSet2::from(span.start..span.end);
            !covered.is_disjoint(ranges)
        };

        if !overlaps(node) {
            return;
        }

        if node.level() == 0 {
            // Leaves sit at even indices, so halving gives the block index.
            let block_idx = node.0 / 2;
            let block_bytes = self.block_size.bytes() as u64;
            let start_byte = block_idx * block_bytes;
            let end_byte = ((block_idx + 1) * block_bytes).min(self.size);
            // `end_byte` is clamped to the data size, so a block starting at
            // or past the end saturates to zero bytes.
            let size = end_byte.saturating_sub(start_byte) as usize;
            out.push(BaoChunk::Leaf {
                start_chunk: block_idx * (1u64 << self.block_size.0),
                size,
                is_root,
            });
            return;
        }

        let left_child = node.left_child();
        let right_child = node.right_child();

        let right_exists = match right_child {
            Some(rc) => {
                let right_block_start = rc.chunk_range().start.0 / 2;
                right_block_start < total_blocks
            }
            None => false,
        };
        if !right_exists {
            if let Some(left) = left_child {
                self.pre_order_filtered_recurse(left, is_root, total_blocks, ranges, out);
            }
            return;
        }

        let left_overlaps = matches!(left_child, Some(lc) if overlaps(lc));
        let right_overlaps = matches!(right_child, Some(rc) if overlaps(rc));
        if !(left_overlaps || right_overlaps) {
            return;
        }

        out.push(BaoChunk::Parent {
            node,
            is_root,
            left: left_overlaps,
            right: right_overlaps,
        });
        if let (true, Some(left)) = (left_overlaps, left_child) {
            self.pre_order_filtered_recurse(left, false, total_blocks, ranges, out);
        }
        if let (true, Some(right)) = (right_overlaps, right_child) {
            self.pre_order_filtered_recurse(right, false, total_blocks, ranges, out);
        }
    }
}
// Unit tests for the tree geometry helpers and the unfiltered traversals.
#[cfg(test)]
mod tests {
use super::*;
// The level of a node is the number of trailing one bits of its index.
#[test]
fn tree_node_level() {
assert_eq!(TreeNode(0).level(), 0); assert_eq!(TreeNode(1).level(), 1); assert_eq!(TreeNode(2).level(), 0); assert_eq!(TreeNode(3).level(), 2); assert_eq!(TreeNode(7).level(), 3); }
// Even indices are leaves, odd indices are parents.
#[test]
fn tree_node_is_leaf() {
assert!(TreeNode(0).is_leaf());
assert!(!TreeNode(1).is_leaf());
assert!(TreeNode(2).is_leaf());
assert!(!TreeNode(3).is_leaf());
}
// Children sit half a span to either side of the parent; leaves have none.
#[test]
fn tree_node_children() {
let parent = TreeNode(1);
assert_eq!(parent.left_child(), Some(TreeNode(0)));
assert_eq!(parent.right_child(), Some(TreeNode(2)));
let leaf = TreeNode(0);
assert_eq!(leaf.left_child(), None);
assert_eq!(leaf.right_child(), None);
let l2_parent = TreeNode(3);
assert_eq!(l2_parent.left_child(), Some(TreeNode(1)));
assert_eq!(l2_parent.right_child(), Some(TreeNode(5)));
}
// A node at level n covers 2^(n+1) chunk units.
#[test]
fn tree_node_chunk_range() {
assert_eq!(TreeNode(0).chunk_range(), ChunkNum(0)..ChunkNum(2));
assert_eq!(TreeNode(1).chunk_range(), ChunkNum(0)..ChunkNum(4));
assert_eq!(TreeNode(3).chunk_range(), ChunkNum(0)..ChunkNum(8));
}
// Block size in bytes is CHUNK_SIZE shifted left by the chunk log.
#[test]
fn block_size_bytes() {
assert_eq!(BlockSize::ZERO.bytes(), CHUNK_SIZE);
assert_eq!(BlockSize::DEFAULT.bytes(), CHUNK_SIZE * 4);
assert_eq!(BlockSize::from_chunk_log(1).bytes(), CHUNK_SIZE * 2);
}
// Only powers of two that are at least CHUNK_SIZE are accepted.
#[test]
fn block_size_from_bytes() {
assert_eq!(BlockSize::from_bytes(CHUNK_SIZE as u64), Some(BlockSize::ZERO));
assert_eq!(BlockSize::from_bytes(CHUNK_SIZE as u64 * 4), Some(BlockSize::DEFAULT));
assert_eq!(BlockSize::from_bytes(512), None);
assert_eq!(BlockSize::from_bytes(3000), None);
}
// A single full chunk is one block and needs no outboard.
#[test]
fn bao_tree_basic() {
let tree = BaoTree::new(CHUNK_SIZE as u64, BlockSize::ZERO);
assert_eq!(tree.chunks(), ChunkNum(1));
assert_eq!(tree.blocks(), 1);
assert_eq!(tree.outboard_size(), 0);
}
// Two blocks have one parent, so the outboard holds one pair.
#[test]
fn bao_tree_two_blocks() {
let tree = BaoTree::new(CHUNK_SIZE as u64 * 2, BlockSize::ZERO);
assert_eq!(tree.chunks(), ChunkNum(2));
assert_eq!(tree.blocks(), 2);
assert_eq!(tree.outboard_size(), PAIR_SIZE as u64); }
// Empty data still counts as one block.
#[test]
fn bao_tree_empty() {
let tree = BaoTree::new(0, BlockSize::ZERO);
assert_eq!(tree.blocks(), 1); assert_eq!(tree.outboard_size(), 0);
}
// Smoke test on a larger input with the default block size.
#[test]
fn bao_tree_large() {
let tree = BaoTree::new(1_000_000, BlockSize::DEFAULT);
assert!(tree.blocks() > 0);
assert!(tree.outboard_size() > 0);
}
// A single-block tree is one root leaf with no parents.
#[test]
fn pre_order_single_block() {
let tree = BaoTree::new(CHUNK_SIZE as u64, BlockSize::ZERO);
let chunks = tree.pre_order_chunks();
assert_eq!(chunks.len(), 1);
assert!(matches!(
chunks[0],
BaoChunk::Leaf {
start_chunk: 0,
size,
is_root: true,
} if size == CHUNK_SIZE
));
}
// Pre-order: the root parent comes first, then the left and right leaves.
#[test]
fn pre_order_two_blocks() {
let tree = BaoTree::new(CHUNK_SIZE as u64 * 2, BlockSize::ZERO);
let chunks = tree.pre_order_chunks();
assert_eq!(chunks.len(), 3);
assert!(matches!(chunks[0], BaoChunk::Parent { is_root: true, .. }));
assert!(matches!(
chunks[1],
BaoChunk::Leaf {
start_chunk: 0,
is_root: false,
..
}
));
assert!(matches!(
chunks[2],
BaoChunk::Leaf {
start_chunk: 1,
is_root: false,
..
}
));
}
// Post-order: both leaves come first, then the root parent.
#[test]
fn post_order_two_blocks() {
let tree = BaoTree::new(CHUNK_SIZE as u64 * 2, BlockSize::ZERO);
let chunks = tree.post_order_chunks();
assert_eq!(chunks.len(), 3);
assert!(matches!(
chunks[0],
BaoChunk::Leaf {
start_chunk: 0,
is_root: false,
..
}
));
assert!(matches!(
chunks[1],
BaoChunk::Leaf {
start_chunk: 1,
is_root: false,
..
}
));
assert!(matches!(chunks[2], BaoChunk::Parent { is_root: true, .. }));
}
// Four blocks give three parents plus four leaves, root first.
#[test]
fn pre_order_four_blocks() {
let tree = BaoTree::new(CHUNK_SIZE as u64 * 4, BlockSize::ZERO);
let chunks = tree.pre_order_chunks();
assert_eq!(chunks.len(), 7);
assert!(matches!(chunks[0], BaoChunk::Parent { is_root: true, .. }));
let leaf_count = chunks
.iter()
.filter(|c| matches!(c, BaoChunk::Leaf { .. }))
.count();
assert_eq!(leaf_count, 4);
}
}