//! GPU backend for Hemera via wgpu compute shaders.
//!
//! Provides GPU-accelerated batch operations:
//! - `batch_permute` โ€” raw Poseidon2 permutations
//! - `batch_hash` โ€” plain sponge hash (no tree domain)
//! - `batch_keyed_hash` โ€” keyed sponge hash
//! - `batch_derive_key` โ€” two-phase key derivation
//! - `batch_hash_leaves` โ€” full sponge + tree leaf hashing
//! - `batch_hash_nodes` โ€” parent node hashing
//! - `batch_hash_nodes_nmt` โ€” namespace-aware parent node hashing
//! - `root_hash` โ€” full Merkle tree on GPU (leaves + level-by-level node merge)
//! - `outboard` โ€” GPU leaf hashing + CPU tree serialization
//! - `batch_verify_proofs` โ€” batch inclusion proof verification
//! - `batch_squeeze` โ€” batch XOF squeeze via GPU permutations
//!
//! The GPU path is optional โ€” the CPU backend (`cyber-hemera`) is always
//! available as fallback.

// GPU crate interfaces with external hardware (wgpu). All data is
// marshalled through GPU buffers โ€” heap allocation is unavoidable.
#![allow(unknown_lints, rs_no_vec, rs_unbounded_async)]

use std::num::NonZeroU64;

use cyber_hemera::field::Goldilocks;
use cyber_hemera::sparse::CompressedSparseProof;
use cyber_hemera::tree::{InclusionProof, Sibling};
use cyber_hemera::{Hash, CHUNK_SIZE, OUTPUT_BYTES, WIDTH};
use wgpu::util::DeviceExt;

/// Pre-compiled GPU compute pipelines and device handles.
#[derive(Debug)]
pub struct GpuContext {
    device: wgpu::Device,
    queue: wgpu::Queue,
    permute_pipeline: wgpu::ComputePipeline,
    hash_leaf_pipeline: wgpu::ComputePipeline,
    hash_node_pipeline: wgpu::ComputePipeline,
    hash_chunk_pipeline: wgpu::ComputePipeline,
    keyed_hash_pipeline: wgpu::ComputePipeline,
    derive_key_material_pipeline: wgpu::ComputePipeline,
    hash_node_nmt_pipeline: wgpu::ComputePipeline,
    bind_group_layout: wgpu::BindGroupLayout,
    rc_buffer: wgpu::Buffer,
    diag_buffer: wgpu::Buffer,
    dummy_buffer: wgpu::Buffer,
}

const FLAG_ROOT: u32 = 1;
const DOMAIN_HASH: u32 = 0;
const PARAMS_SIZE: u64 = 32; // DispatchParams struct: 8 ร— u32

impl GpuContext {
    /// Initialize GPU backend. Returns `None` if no suitable GPU is available.
    pub async fn new() -> Option<Self> {
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok()?;

        if !adapter
            .get_downlevel_capabilities()
            .flags
            .contains(wgpu::DownlevelFlags::COMPUTE_SHADERS)
        {
            return None;
        }

        let (device, queue) = adapter
            .request_device(&wgpu::DeviceDescriptor {
                label: Some("hemera GPU"),
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits::downlevel_defaults(),
                ..Default::default()
            })
            .await
            .ok()?;

        let shader_source = concat!(
            include_str!("shaders/params.wgsl"),
            include_str!("shaders/field.wgsl"),
            include_str!("shaders/encoding.wgsl"),
            include_str!("shaders/permutation.wgsl"),
            include_str!("shaders/sponge.wgsl"),
            include_str!("shaders/tree.wgsl"),
            include_str!("shaders/entry_points.wgsl"),
        );
        let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("hemera"),
            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
        });

        let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("hemera bgl"),
            entries: &[
                buffer_entry(0, wgpu::BufferBindingType::Storage { read_only: false }, 4),
                buffer_entry(1, wgpu::BufferBindingType::Storage { read_only: true }, 4),
                buffer_entry(2, wgpu::BufferBindingType::Uniform, PARAMS_SIZE),
                buffer_entry(3, wgpu::BufferBindingType::Storage { read_only: true }, 4),
                buffer_entry(4, wgpu::BufferBindingType::Storage { read_only: true }, 4),
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("hemera layout"),
            bind_group_layouts: &[&bgl],
            immediate_size: 0,
        });

        let pipe = |ep: &str| {
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some(ep),
                layout: Some(&pipeline_layout),
                module: &module,
                entry_point: Some(ep),
                compilation_options: wgpu::PipelineCompilationOptions::default(),
                cache: None,
            })
        };

        let rc_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("round constants"),
            contents: bytemuck::cast_slice(&generate_round_constants_u32()),
            usage: wgpu::BufferUsages::STORAGE,
        });
        let diag_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("matrix diag"),
            contents: bytemuck::cast_slice(&generate_matrix_diag_u32()),
            usage: wgpu::BufferUsages::STORAGE,
        });
        let dummy_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("dummy"),
            contents: &[0u8; 4],
            usage: wgpu::BufferUsages::STORAGE,
        });

        Some(Self {
            permute_pipeline: pipe("hemera_permute"),
            hash_leaf_pipeline: pipe("hemera_hash_leaf"),
            hash_node_pipeline: pipe("hemera_hash_node"),
            hash_chunk_pipeline: pipe("hemera_hash_chunk"),
            keyed_hash_pipeline: pipe("hemera_keyed_hash"),
            derive_key_material_pipeline: pipe("hemera_derive_key_material"),
            hash_node_nmt_pipeline: pipe("hemera_hash_node_nmt"),
            bind_group_layout: bgl,
            rc_buffer,
            diag_buffer,
            dummy_buffer,
            device,
            queue,
        })
    }

    fn bind(&self, io: &wgpu::Buffer, params: &wgpu::Buffer, aux: &wgpu::Buffer) -> wgpu::BindGroup {
        self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: io.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: self.rc_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: params.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: self.diag_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: aux.as_entire_binding() },
            ],
        })
    }

    fn dispatch_readback(
        &self,
        pipeline: &wgpu::ComputePipeline,
        bg: &wgpu::BindGroup,
        io: &wgpu::Buffer,
        count: u32,
    ) -> Vec<u32> {
        let dl = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: io.size(),
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let mut enc = self.device.create_command_encoder(&Default::default());
        {
            let mut pass = enc.begin_compute_pass(&Default::default());
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, bg, &[]);
            pass.dispatch_workgroups(count.div_ceil(64), 1, 1);
        }
        enc.copy_buffer_to_buffer(io, 0, &dl, 0, io.size());
        self.queue.submit([enc.finish()]);
        let slice = dl.slice(..);
        slice.map_async(wgpu::MapMode::Read, |r| {
            r.expect("GPU buffer map failed โ€” compute may have timed out");
        });
        self.device
            .poll(wgpu::PollType::wait_indefinitely())
            .unwrap();
        let mapped = slice.get_mapped_range();
        let out: Vec<u32> = bytemuck::cast_slice(&mapped).to_vec();
        drop(mapped);
        dl.unmap();
        out
    }

    fn params_buf(&self, p: [u32; 8]) -> wgpu::Buffer {
        self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: None,
            contents: bytemuck::cast_slice(&p),
            usage: wgpu::BufferUsages::UNIFORM,
        })
    }

    fn dispatch_hash(&self, pipeline: &wgpu::ComputePipeline, aux: &[u8], p: [u32; 8], n: u32) -> Vec<Hash> {
        let io = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (n as u64) * 16 * 4,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        let aux_buf = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: None,
            contents: aux,
            usage: wgpu::BufferUsages::STORAGE,
        });
        let pb = self.params_buf(p);
        let bg = self.bind(&io, &pb, &aux_buf);
        u32s_to_hashes(&self.dispatch_readback(pipeline, &bg, &io, n), n as usize)
    }

    // โ”€โ”€ Batch primitives โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

    /// Run batch Poseidon2 permutations on GPU.
    pub async fn batch_permute(&self, states: &Goldilocks; WIDTH) -> Vec<[Goldilocks; WIDTH]> {
        if states.is_empty() { return vec![]; }
        let n = states.len() as u32;
        let data = flatten_states(states);
        let io = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: None,
            contents: bytemuck::cast_slice(&data),
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        });
        let pb = self.params_buf([n, 0, 0, 0, 0, 0, 0, 0]);
        let bg = self.bind(&io, &pb, &self.dummy_buffer);
        unflatten_states(&self.dispatch_readback(&self.permute_pipeline, &bg, &io, n), states.len())
    }

    /// Batch sponge hash of data chunks (plain hash, no tree domain).
    pub async fn batch_hash(&self, data: &[u8], chunk_size: usize) -> Vec<Hash> {
        if data.is_empty() { return vec![]; }
        let n = data.len().div_ceil(chunk_size) as u32;
        self.dispatch_hash(
            &self.hash_chunk_pipeline,
            &pad4(data),
            [n, DOMAIN_HASH, chunk_size as u32, data.len() as u32, 0, 0, 0, 0],
            n,
        )
    }

    /// Batch keyed hash โ€” key read from aux[0..64) by GPU, no per-chunk duplication.
    pub async fn batch_keyed_hash(&self, key: &[u8; OUTPUT_BYTES], data: &[u8], chunk_size: usize) -> Vec<Hash> {
        if data.is_empty() { return vec![]; }
        let n = data.len().div_ceil(chunk_size) as u32;
        let mut buf = Vec::with_capacity(OUTPUT_BYTES + data.len());
        buf.extend_from_slice(key);
        buf.extend_from_slice(data);
        self.dispatch_hash(
            &self.keyed_hash_pipeline,
            &pad4(&buf),
            [n, 0, chunk_size as u32, data.len() as u32, 0, 0, 0, 0],
            n,
        )
    }

    /// Derive key: context hash on CPU, material hash on GPU.
    pub async fn batch_derive_key(&self, context: &str, data: &[u8], chunk_size: usize) -> Vec<Hash> {
        if data.is_empty() { return vec![]; }
        let ctx_hash = cyber_hemera::Hasher::new_derive_key_context(context).finalize();
        self.batch_derive_key_material(&ctx_hash, data, chunk_size).await
    }

    /// Derive key material phase, seeded by a pre-computed context hash.
    pub async fn batch_derive_key_material(&self, ctx: &Hash, data: &[u8], chunk_size: usize) -> Vec<Hash> {
        if data.is_empty() { return vec![]; }
        let n = data.len().div_ceil(chunk_size) as u32;
        let mut aux_u32s: Vec<u32> = Vec::new();
        push_hash_u32s(&mut aux_u32s, ctx);
        aux_u32s.extend(bytemuck::cast_slice::<u8, u32>(&pad4(data)).iter());
        self.dispatch_hash(
            &self.derive_key_material_pipeline,
            bytemuck::cast_slice(&aux_u32s),
            [n, 0, chunk_size as u32, data.len() as u32, 0, 0, 0, 0],
            n,
        )
    }

    /// Maximum leaves per GPU dispatch to avoid command buffer timeouts.
    const LEAF_BATCH: usize = 4096;

    /// Hash leaf chunks on GPU (sponge + tree domain).
    ///
    /// For large inputs, dispatches in batches to avoid GPU timeouts.
    /// The counter offset (param slot ns_min_lo) ensures each leaf gets
    /// its correct global chunk index regardless of batch boundaries.
    pub async fn batch_hash_leaves(&self, data: &[u8], chunk_size: usize, is_root: bool) -> Vec<Hash> {
        if data.is_empty() { return vec![]; }
        let total_chunks = data.len().div_ceil(chunk_size);

        if total_chunks <= Self::LEAF_BATCH {
            let n = total_chunks as u32;
            let flags = if is_root { FLAG_ROOT } else { 0 };
            return self.dispatch_hash(
                &self.hash_leaf_pipeline,
                &pad4(data),
                [n, flags, chunk_size as u32, data.len() as u32, 0, 0, 0, 0],
                n,
            );
        }

        // Batch: split data into sub-slices and dispatch each independently.
        let mut all_hashes = Vec::with_capacity(total_chunks);
        let mut offset = 0;
        while offset < total_chunks {
            let batch_end = (offset + Self::LEAF_BATCH).min(total_chunks);
            let byte_start = offset * chunk_size;
            let byte_end = (batch_end * chunk_size).min(data.len());
            let slice = &data[byte_start..byte_end];
            let n = (batch_end - offset) as u32;
            let counter_offset = offset as u32;
            let hashes = self.dispatch_hash(
                &self.hash_leaf_pipeline,
                &pad4(slice),
                [n, 0, chunk_size as u32, slice.len() as u32, counter_offset, 0, 0, 0],
                n,
            );
            all_hashes.extend(hashes);
            offset = batch_end;
        }
        all_hashes
    }

    /// Hash leaf chunks with progress reporting (used by root_hash_with_progress).
    async fn batch_hash_leaves_progress(
        &self,
        data: &[u8],
        chunk_size: usize,
        is_root: bool,
        total: usize,
        progress: &impl Fn(usize, usize),
    ) -> Vec<Hash> {
        if data.is_empty() { return vec![]; }
        let total_chunks = data.len().div_ceil(chunk_size);

        if total_chunks <= Self::LEAF_BATCH {
            let n = total_chunks as u32;
            let flags = if is_root { FLAG_ROOT } else { 0 };
            let result = self.dispatch_hash(
                &self.hash_leaf_pipeline,
                &pad4(data),
                [n, flags, chunk_size as u32, data.len() as u32, 0, 0, 0, 0],
                n,
            );
            progress(total_chunks, total);
            return result;
        }

        let mut all_hashes = Vec::with_capacity(total_chunks);
        let mut offset = 0;
        while offset < total_chunks {
            let batch_end = (offset + Self::LEAF_BATCH).min(total_chunks);
            let byte_start = offset * chunk_size;
            let byte_end = (batch_end * chunk_size).min(data.len());
            let slice = &data[byte_start..byte_end];
            let n = (batch_end - offset) as u32;
            let counter_offset = offset as u32;
            let hashes = self.dispatch_hash(
                &self.hash_leaf_pipeline,
                &pad4(slice),
                [n, 0, chunk_size as u32, slice.len() as u32, counter_offset, 0, 0, 0],
                n,
            );
            all_hashes.extend(hashes);
            offset = batch_end;
            progress(offset, total);
        }
        all_hashes
    }

    /// Combine pairs of child hashes into parent hashes.
    pub async fn batch_hash_nodes(&self, pairs: &[(Hash, Hash)], is_root: bool) -> Vec<Hash> {
        if pairs.is_empty() { return vec![]; }
        let n = pairs.len() as u32;
        let flags = if is_root { FLAG_ROOT } else { 0 };
        self.dispatch_hash(
            &self.hash_node_pipeline,
            &flatten_pairs(pairs),
            [n, flags, 0, 0, 0, 0, 0, 0],
            n,
        )
    }

    /// Combine pairs with namespace bounds (NMT). Full u64 ns support.
    pub async fn batch_hash_nodes_nmt(
        &self, pairs: &[(Hash, Hash)], ns_min: u64, ns_max: u64, is_root: bool,
    ) -> Vec<Hash> {
        if pairs.is_empty() { return vec![]; }
        let n = pairs.len() as u32;
        let flags = if is_root { FLAG_ROOT } else { 0 };
        self.dispatch_hash(
            &self.hash_node_nmt_pipeline,
            &flatten_pairs(pairs),
            [n, flags, 0, 0,
             ns_min as u32, (ns_min >> 32) as u32,
             ns_max as u32, (ns_max >> 32) as u32],
            n,
        )
    }

    // โ”€โ”€ High-level tree operations โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

    /// Compute the Merkle root hash on GPU.
    ///
    /// Hashes all leaves in one GPU dispatch, then merges level-by-level
    /// until a single root remains. Matches `cyber_hemera::tree::root_hash`.
    pub async fn root_hash(&self, data: &[u8]) -> Hash {
        self.root_hash_with_progress(data, |_, _| {}).await
    }

    /// Compute the Merkle root hash on GPU with progress callback.
    ///
    /// `progress` receives `(completed, total)` operation counts.
    pub async fn root_hash_with_progress(
        &self,
        data: &[u8],
        progress: impl Fn(usize, usize),
    ) -> Hash {
        if data.is_empty() {
            return cyber_hemera::tree::hash_leaf(data, 0, true);
        }
        let n = data.len().div_ceil(CHUNK_SIZE);
        if n == 1 {
            return self.batch_hash_leaves(data, CHUNK_SIZE, true).await.remove(0);
        }
        let total = 2 * n - 1;
        progress(0, total);
        let leaves = self.batch_hash_leaves_progress(data, CHUNK_SIZE, false, total, &progress).await;
        self.merge_tree_with_progress(leaves, n, total, &progress).await
    }

    /// Merge tree with progress reporting.
    async fn merge_tree_with_progress(
        &self,
        leaves: Vec<Hash>,
        done_start: usize,
        total: usize,
        progress: &impl Fn(usize, usize),
    ) -> Hash {
        let mut done = done_start;
        let n = leaves.len();
        let segments = left_balanced_decompose(n);
        let num_segments = segments.len();

        // Each segment reduces to one root hash. Track per-segment.
        let mut seg_roots: Vec<Option<Hash>> = vec![None; num_segments];
        let mut current_level: Vec<Hash> = leaves;
        let max_height = segments[0].1.trailing_zeros();

        // Initialize size-1 segments as already complete.
        for (si, &(_, size)) in segments.iter().enumerate() {
            if size == 1 {
                seg_roots[si] = Some(current_level[segments[si].0]);
            }
        }

        for round in 0..max_height {
            let mut next_level = Vec::new();
            let mut pairs: Vec<(Hash, Hash)> = Vec::new();
            let mut pair_segment: Vec<usize> = Vec::new();

            let mut offset = 0;
            for (si, &(_, size)) in segments.iter().enumerate() {
                let height = size.trailing_zeros();
                if height <= round {
                    // Already complete or size-1.
                    if size > 1 { offset += size >> round; }
                    continue;
                }
                let cur_count = size >> round;
                for i in (0..cur_count).step_by(2) {
                    pairs.push((current_level[offset + i], current_level[offset + i + 1]));
                    pair_segment.push(si);
                }
                offset += cur_count;
            }

            if pairs.is_empty() { break; }

            // Determine is_root: only if single segment and final round.
            let is_root = num_segments == 1 && round + 1 == max_height;
            let num_pairs = pairs.len();
            let results = self.batch_hash_nodes(&pairs, is_root).await;
            done += num_pairs;
            if total > 0 { progress(done, total); }

            let mut ri = 0;
            for (si, &(_, size)) in segments.iter().enumerate() {
                let height = size.trailing_zeros();
                if height <= round || size == 1 {
                    continue;
                }
                let result_count = size >> (round + 1);
                if height == round + 1 {
                    // Segment completing this round.
                    seg_roots[si] = Some(results[ri]);
                    ri += 1;
                } else {
                    next_level.extend_from_slice(&results[ri..ri + result_count]);
                    ri += result_count;
                }
            }

            current_level = next_level;
        }

        // All segments now have roots. Fold right-to-left along the spine.
        // The tree structure: seg[0] is left child of root spine,
        // seg[1] is left child of next spine node, etc.
        // Fold: start from rightmost two, merge, then merge with next-left.
        let mut roots: Vec<Hash> = seg_roots.into_iter().map(|r| r.unwrap()).collect();

        while roots.len() > 1 {
            let right = roots.pop().unwrap();
            let left = roots.pop().unwrap();
            let is_root = roots.is_empty();
            let result = self.batch_hash_nodes(&[(left, right)], is_root).await;
            roots.push(result[0]);
            done += 1;
            if total > 0 { progress(done, total); }
        }

        roots.pop().unwrap()
    }

    /// Compute the outboard (hash tree without data) on GPU.
    ///
    /// GPU hashes all leaves in parallel. CPU builds the tree structure
    /// and serializes parent pairs in pre-order.
    /// Returns `(root_hash, outboard_bytes)` matching `cyber_hemera::stream::outboard`.
    pub async fn outboard(&self, data: &[u8]) -> (Hash, Vec<u8>) {
        let n = if data.is_empty() { 1 } else { data.len().div_ceil(CHUNK_SIZE) };

        if n <= 1 {
            let root = self.batch_hash_leaves(data, CHUNK_SIZE, true).await.remove(0);
            let mut out = Vec::with_capacity(8);
            out.extend_from_slice(&(data.len() as u64).to_le_bytes());
            return (root, out);
        }

        // GPU: batch hash all leaves.
        let leaves = self.batch_hash_leaves(data, CHUNK_SIZE, false).await;

        // CPU: recursive tree build + pre-order serialization.
        let num_parents = n - 1;
        let mut out = Vec::with_capacity(8 + num_parents * OUTPUT_BYTES * 2);
        out.extend_from_slice(&(data.len() as u64).to_le_bytes());
        let root = outboard_subtree_from_leaves(&leaves, 0, n, true, &mut out);
        (root, out)
    }

    /// Verify multiple inclusion proofs in batch on GPU.
    ///
    /// Each entry is `(chunk_data, proof, expected_root)`.
    /// Returns a Vec of bools indicating which proofs verified successfully.
    pub async fn batch_verify_proofs(
        &self,
        proofs: &[(&[u8], &InclusionProof, &Hash)],
    ) -> Vec<bool> {
        if proofs.is_empty() { return vec![]; }

        let max_depth = proofs.iter().map(|(_, p, _)| p.depth()).max().unwrap_or(0);

        // Compute initial hashes (leaf or subtree root).
        let mut current: Vec<Hash> = proofs.iter().map(|(chunk, proof, _)| {
            let start = proof.start_chunk;
            let end = proof.end_chunk;
            if end - start == 1 {
                cyber_hemera::tree::hash_leaf(chunk, start, proof.num_chunks == 1)
            } else {
                cyber_hemera::tree::root_hash(chunk)
            }
        }).collect();

        if max_depth == 0 {
            return proofs.iter().enumerate()
                .map(|(i, (_, _, root))| current[i] == **root)
                .collect();
        }

        // Walk proof levels leaf-to-root, batching hash_node calls on GPU.
        // Siblings are stored root-to-leaf, so index from the end.
        for level in 0..max_depth {
            let mut root_pairs: Vec<(Hash, Hash)> = Vec::new();
            let mut root_indices: Vec<usize> = Vec::new();
            let mut inner_pairs: Vec<(Hash, Hash)> = Vec::new();
            let mut inner_indices: Vec<usize> = Vec::new();

            for (i, (_, proof, _)) in proofs.iter().enumerate() {
                let siblings = proof.siblings();
                // level 0 = leaf-most sibling (last in array).
                let sib_idx = siblings.len().checked_sub(1 + level);
                let Some(sib_idx) = sib_idx else { continue; };
                let is_root = sib_idx == 0;
                let pair = match siblings[sib_idx] {
                    Sibling::Left(sib) => (sib, current[i]),
                    Sibling::Right(sib) => (current[i], sib),
                };
                if is_root {
                    root_pairs.push(pair);
                    root_indices.push(i);
                } else {
                    inner_pairs.push(pair);
                    inner_indices.push(i);
                }
            }

            if !inner_pairs.is_empty() {
                let results = self.batch_hash_nodes(&inner_pairs, false).await;
                for (j, &idx) in inner_indices.iter().enumerate() {
                    current[idx] = results[j];
                }
            }
            if !root_pairs.is_empty() {
                let results = self.batch_hash_nodes(&root_pairs, true).await;
                for (j, &idx) in root_indices.iter().enumerate() {
                    current[idx] = results[j];
                }
            }
        }

        proofs.iter().enumerate().map(|(i, (_, _, root))| current[i] == **root).collect()
    }

    /// Verify multiple sparse Merkle proofs in batch on GPU.
    ///
    /// Each entry is `(proof, value_if_inclusion, expected_root)`.
    /// `value` is `Some(data)` for inclusion proofs, `None` for non-inclusion.
    /// Returns a Vec of bools indicating which proofs verified successfully.
    pub async fn batch_verify_sparse_proofs(
        &self,
        proofs: &[(&CompressedSparseProof, Option<&[u8]>, &Hash)],
        depth: u32,
    ) -> Vec<bool> {
        if proofs.is_empty() { return vec![]; }

        // Compute sentinel table on CPU (done once).
        let sentinels = cyber_hemera::sparse::sentinel_table(depth);

        // Compute initial leaf hashes on CPU.
        let n = proofs.len();
        let mut current_hashes: Vec<Hash> = proofs.iter().map(|(proof, value, _)| {
            match value {
                Some(v) => cyber_hemera::tree::hash_leaf(
                    &[proof.key.as_slice(), *v].concat(), 0, false,
                ),
                None => sentinels[0],
            }
        }).collect();
        let mut cursors = vec![0usize; n];

        for level in 0..depth {
            let is_root_level = level + 1 == depth;
            let mut root_pairs: Vec<(Hash, Hash)> = Vec::new();
            let mut root_indices: Vec<usize> = Vec::new();
            let mut inner_pairs: Vec<(Hash, Hash)> = Vec::new();
            let mut inner_indices: Vec<usize> = Vec::new();

            for (i, (proof, _, _)) in proofs.iter().enumerate() {
                let byte_idx = (level / 8) as usize;
                let bit_in_byte = level % 8;
                let has_real = (proof.bitmask[byte_idx] >> bit_in_byte) & 1 == 1;

                let sibling = if has_real {
                    if cursors[i] >= proof.siblings.len() {
                        // Invalid proof โ€” mark with sentinel to fail later.
                        current_hashes[i] = Hash::from_bytes([0xFF; 64]);
                        continue;
                    }
                    let s = proof.siblings[cursors[i]];
                    cursors[i] += 1;
                    s
                } else {
                    sentinels[level as usize]
                };

                let bit_pos = depth - 1 - level;
                let is_right = key_bit_static(&proof.key, bit_pos);

                let pair = if is_right {
                    (sibling, current_hashes[i])
                } else {
                    (current_hashes[i], sibling)
                };

                if is_root_level {
                    root_pairs.push(pair);
                    root_indices.push(i);
                } else {
                    inner_pairs.push(pair);
                    inner_indices.push(i);
                }
            }

            if !inner_pairs.is_empty() {
                let results = self.batch_hash_nodes(&inner_pairs, false).await;
                for (j, &idx) in inner_indices.iter().enumerate() {
                    current_hashes[idx] = results[j];
                }
            }
            if !root_pairs.is_empty() {
                let results = self.batch_hash_nodes(&root_pairs, true).await;
                for (j, &idx) in root_indices.iter().enumerate() {
                    current_hashes[idx] = results[j];
                }
            }
        }

        proofs.iter().enumerate().map(|(i, (proof, _, root))| {
            current_hashes[i] == **root && cursors[i] == proof.siblings.len()
        }).collect()
    }

    /// Batch XOF squeeze: given finalized sponge states, produce `count`
    /// output blocks (64 bytes each) per state using GPU permutations.
    pub async fn batch_squeeze(
        &self,
        states: &Goldilocks; WIDTH,
        count: usize,
    ) -> Vec<Vec<[u8; OUTPUT_BYTES]>> {
        if states.is_empty() || count == 0 { return vec![vec![]; states.len()]; }

        let n = states.len();
        let mut result = vec![Vec::with_capacity(count); n];

        // Block 0: extract from initial state (before any permutation).
        for (i, state) in states.iter().enumerate() {
            result[i].push(extract_output(state));
        }

        // Blocks 1..count: permute, then extract.
        let mut current_states: Vec<[Goldilocks; WIDTH]> = states.to_vec();
        for _ in 1..count {
            current_states = self.batch_permute(&current_states).await;
            for (i, state) in current_states.iter().enumerate() {
                result[i].push(extract_output(state));
            }
        }

        result
    }
}

// โ”€โ”€ Helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

fn buffer_entry(binding: u32, ty: wgpu::BufferBindingType, min: u64) -> wgpu::BindGroupLayoutEntry {
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty,
            has_dynamic_offset: false,
            min_binding_size: Some(NonZeroU64::new(min).unwrap()),
        },
        count: None,
    }
}

fn pad4(data: &[u8]) -> Vec<u8> {
    let mut v = data.to_vec();
    v.resize(v.len().next_multiple_of(4), 0);
    v
}

fn flatten_states(states: &Goldilocks; WIDTH) -> Vec<u32> {
    let mut out = Vec::with_capacity(states.len() * WIDTH * 2);
    for state in states {
        for e in state {
            let v = e.as_canonical_u64();
            out.push(v as u32);
            out.push((v >> 32) as u32);
        }
    }
    out
}

fn unflatten_states(u32s: &[u32], count: usize) -> Vec<[Goldilocks; WIDTH]> {
    (0..count)
        .map(|idx| {
            core::array::from_fn(|i| {
                let off = idx * WIDTH * 2 + i * 2;
                Goldilocks::new(u32s[off] as u64 | ((u32s[off + 1] as u64) << 32))
            })
        })
        .collect()
}

fn push_hash_u32s(out: &mut Vec<u32>, hash: &Hash) {
    for i in 0..8 {
        let v = u64::from_le_bytes(hash.as_bytes()[i * 8..(i + 1) * 8].try_into().unwrap());
        out.push(v as u32);
        out.push((v >> 32) as u32);
    }
}

fn flatten_pairs(pairs: &[(Hash, Hash)]) -> Vec<u8> {
    let mut u32s = Vec::with_capacity(pairs.len() * 32);
    for (l, r) in pairs {
        push_hash_u32s(&mut u32s, l);
        push_hash_u32s(&mut u32s, r);
    }
    bytemuck::cast_slice(&u32s).to_vec()
}

fn u32s_to_hashes(u32s: &[u32], count: usize) -> Vec<Hash> {
    (0..count)
        .map(|idx| {
            let mut bytes = [0u8; OUTPUT_BYTES];
            for i in 0..8 {
                let off = idx * 16 + i * 2;
                let v = u32s[off] as u64 | ((u32s[off + 1] as u64) << 32);
                bytes[i * 8..(i + 1) * 8].copy_from_slice(&v.to_le_bytes());
            }
            Hash::from_bytes(bytes)
        })
        .collect()
}

fn generate_round_constants_u32() -> Vec<u32> {
    cyber_hemera::constants::ROUND_CONSTANTS_U64
        .iter()
        .flat_map(|&v| [v as u32, (v >> 32) as u32])
        .collect()
}

fn generate_matrix_diag_u32() -> Vec<u32> {
    cyber_hemera::field::MATRIX_DIAG_16
        .iter()
        .flat_map(|e| {
            let v = e.as_canonical_u64();
            [v as u32, (v >> 32) as u32]
        })
        .collect()
}

fn extract_output(state: &[Goldilocks; WIDTH]) -> [u8; OUTPUT_BYTES] {
    let mut out = [0u8; OUTPUT_BYTES];
    for i in 0..8 {
        out[i * 8..(i + 1) * 8].copy_from_slice(&state[i].as_canonical_u64().to_le_bytes());
    }
    out
}

/// Extract the i-th path bit from a key (MSB-first), matching `sparse::key_bit`.
fn key_bit_static(key: &[u8; 32], bit_index: u32) -> bool {
    let byte_idx = (bit_index / 8) as usize;
    let bit_in_byte = 7 - (bit_index % 8);
    (key[byte_idx] >> bit_in_byte) & 1 == 1
}

// โ”€โ”€ Tree helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

/// Decompose n into complete binary subtrees (left-balanced).
/// Returns [(start_leaf, size)] where each size is a power of 2.
fn left_balanced_decompose(n: usize) -> Vec<(usize, usize)> {
    let mut result = Vec::new();
    let mut offset = 0;
    let mut remaining = n;
    while remaining > 0 {
        if remaining.is_power_of_two() {
            result.push((offset, remaining));
            break;
        }
        let split = left_subtree_chunks(remaining);
        result.push((offset, split));
        offset += split;
        remaining -= split;
    }
    result
}

fn left_subtree_chunks(count: usize) -> usize {
    debug_assert!(count > 1);
    1 << (usize::BITS - (count - 1).leading_zeros() - 1)
}

/// Recursive outboard subtree using pre-computed leaf hashes (CPU hash_node).
fn outboard_subtree_from_leaves(
    leaves: &[Hash],
    offset: usize,
    count: usize,
    is_root: bool,
    out: &mut Vec<u8>,
) -> Hash {
    if count == 1 {
        return leaves[offset];
    }

    let split = left_subtree_chunks(count);

    // Reserve slot for this parent's hash pair (pre-order).
    let pair_start = out.len();
    out.extend_from_slice(&[0u8; OUTPUT_BYTES * 2]);

    let left = outboard_subtree_from_leaves(leaves, offset, split, false, out);
    let right = outboard_subtree_from_leaves(leaves, offset + split, count - split, false, out);

    // Fill in the reserved slot.
    out[pair_start..pair_start + OUTPUT_BYTES].copy_from_slice(left.as_ref());
    out[pair_start + OUTPUT_BYTES..pair_start + OUTPUT_BYTES * 2].copy_from_slice(right.as_ref());

    cyber_hemera::tree::hash_node(&left, &right, is_root)
}

Dimensions

zheng/src/lib.rs
trident/src/lib.rs
optica/src/lib.rs
nebu/rs/lib.rs
bbg/src/lib.rs
nox/src/lib.rs
rs/macros/src/lib.rs
rs/core/src/lib.rs
hemera/rs/src/lib.rs
nebu/wgsl/src/lib.rs
cw-cyber/contracts/litium-mine/src/lib.rs
cw-cyber/packages/hub-base/src/lib.rs
cw-cyber/packages/cyber-std-test/src/lib.rs
cw-cyber/contracts/litium-wrap/src/lib.rs
cw-cyber/contracts/hub-skills/src/lib.rs
cw-cyber/contracts/cw-cyber-gift/src/lib.rs
cw-cyber/contracts/litium-stake/src/lib.rs
cw-cyber/packages/cyber-std/src/lib.rs
cw-cyber/contracts/cybernet/src/lib.rs
trident/editor/zed/src/lib.rs
cw-cyber/contracts/hub-protocols/src/lib.rs
cw-cyber/contracts/std-test/src/lib.rs
cw-cyber/contracts/litium-core/src/lib.rs
cw-cyber/contracts/hub-networks/src/lib.rs
cw-cyber/contracts/graph-filter/src/lib.rs
cw-cyber/contracts/cw-cyber-subgraph/src/lib.rs
cw-cyber/contracts/cw-cyber-passport/src/lib.rs
cw-cyber/contracts/hub-channels/src/lib.rs
cw-cyber/contracts/litium-refer/src/lib.rs
cw-cyber/contracts/hub-libs/src/lib.rs
cw-cyber/contracts/hub-tokens/src/lib.rs
rs/tests/macro-integration/src/lib.rs

Local Graph