module std.crypto.lut_sponge

// Sponge hash with lookup-table S-box over the Goldilocks field.
//
// The S-box reads from a RAM-based lookup table — the SAME table used
// for neural network activation (ReLU) and FHE programmable bootstrapping.
// This is the Rosetta Stone: one table, four readers.
//
//   Reader 1: std.nn via lut.apply — neural activation
//   Reader 2: THIS MODULE via lut.read โ€” crypto S-box
//   Reader 3: std.fhe.pbs via lut.read โ€” FHE test polynomial
//   Reader 4: STARK LogUp โ€” proof authentication (upstream)
//
// Construction: Rescue-style sponge with bounded S-box.
//   State width: 8, Rate: 4, Capacity: 4
//   S-box: lut.read(lut_addr, x mod D) where D = table domain size
//   MDS: circulant(2,1,1,...,1) โ€” same as Poseidon2 external linear
//   Rounds: 14 (conservative for 10-bit S-box)
//   Round constants: 14 * 8 = 112 field elements from RAM
//
// After MDS, state elements exceed [0, D). The prover supplies the
// reduced value via divine(), the circuit constrains it:
//   r = divine(), k = divine(), assert(x == k * D + r), assert(r < D)
//
// For D = 1024 = 2^10: r fits in 10 bits, so split(r) gives hi=0, lo<1024.
use vm.core.convert

use vm.core.assert

use vm.io.io

use vm.io.mem

use std.math.lut

// ---------------------------------------------------------------------------
// State: 8 field elements (same shape as poseidon2.State)
// ---------------------------------------------------------------------------
// Sponge state. Lanes s0..s3 are the rate (absorb4 adds inputs there and
// squeeze1 reads s0); lanes s4..s7 are the capacity (hash4_to_digest puts its
// domain-separation tag in s4). Field order is the struct layout, so do not
// reorder.
pub struct State {
    s0: Field,
    s1: Field,
    s2: Field,
    s3: Field,
    s4: Field,
    s5: Field,
    s6: Field,
    s7: Field,
}

// The all-zero sponge state: every rate and capacity lane set to 0.
pub fn zero_state() -> State {
    State {
        s0: 0,
        s1: 0,
        s2: 0,
        s3: 0,
        s4: 0,
        s5: 0,
        s6: 0,
        s7: 0
    }
}

// ---------------------------------------------------------------------------
// Modular reduction: reduce field element to [0, D).
// ---------------------------------------------------------------------------
// Returns r = x mod d, where x is taken as its canonical integer in [0, p),
// p = 2^64 - 2^32 + 1 (Goldilocks). The prover supplies r and then the
// quotient k = floor(x / d) via divine(), in that order. The circuit checks
//
//   x == k * d + r   as an identity over the INTEGERS,   and   0 <= r < d.
//
// Why not just assert.eq(x, k * d + r)? That equation only holds mod p, and
// d is invertible mod p. For ANY r < d the prover can set k = (x - r) / d in
// the field and satisfy it, which would make every S-box input prover-chosen.
// Instead, the identity is checked on canonical 32-bit limbs, where no
// intermediate value can reach p:
//
//   x    = x_hi * 2^32 + x_lo        (split(x), canonical)
//   k    = k_hi * 2^32 + k_lo        (split(k), canonical)
//   low  = k_lo * d + r              <= 2^64 - 2^32 - 1 < p   (exact)
//   low  = carry * 2^32 + low_lo     (split(low))
//   require  low_lo == x_lo   and   k_hi * d + carry == x_hi
//
// k_hi * d + carry <= 2^64 - 2^32 - 1 < p as well, so both field equalities
// are integer equalities. Together they give x = k * d + r exactly, and
// uniqueness of Euclidean division then pins r.
//
// Requires 1 <= d < 2^32. d >= 2^32 is rejected below, and d = 0 makes
// r < d unsatisfiable. For D = 1024, r is the low 10 bits of x.
fn reduce_mod(x: Field, d: Field) -> Field {
    // Witnesses, read in the same order as before: remainder, then quotient.
    let r: Field = io.divine()
    let k: Field = io.divine()

    // d must fit in a u32; the no-wraparound bounds above depend on it.
    let (d_hi, d_u32) = convert.split(d)
    assert.eq(convert.as_field(d_hi), 0)

    // Range check 0 <= r < d: r must be a u32 (hi limb zero) below d.
    let (r_hi, r_lo) = convert.split(r)
    assert.eq(convert.as_field(r_hi), 0)
    assert.is_true(r_lo < d_u32)

    // Canonical limbs of the input and of the claimed quotient.
    let (x_hi, x_lo) = convert.split(x)
    let (k_hi, k_lo) = convert.split(k)

    // Low half: k_lo * d + r is computed exactly (< p); its low 32 bits must
    // match x_lo and its high bits carry into the upper half.
    let low: Field = convert.as_field(k_lo) * d + r
    let (carry, low_lo) = convert.split(low)
    assert.eq(convert.as_field(low_lo), convert.as_field(x_lo))

    // High half: k_hi * d + carry < p, so this field check is an integer check.
    assert.eq(convert.as_field(k_hi) * d + convert.as_field(carry), convert.as_field(x_hi))

    r
}

// ---------------------------------------------------------------------------
// S-box layer: reduce + lookup for all 8 state elements.
// ---------------------------------------------------------------------------
// One S-box lane: reduce x into [0, domain) with reduce_mod, then use the
// residue as the index into the shared table at lut_addr. This table read is
// the module's crypto "reader" of the shared LUT.
fn sbox_lane(x: Field, lut_addr: Field, domain: Field) -> Field {
    let index: Field = reduce_mod(x, domain)
    lut.read(lut_addr, index)
}

// Applies sbox_lane to every lane, s0 first and s7 last. Lane order matters:
// each reduce_mod call consumes divine() inputs, so the prover's
// non-deterministic input stream is laid out in lane order.
//
// NOTE(review): this layer is not injective. Each lane is collapsed to one of
// `domain` residues before the table read, so (assuming lut.read is a pure
// table lookup — confirm) lane values x and x + domain map to the same output
// whenever neither wraps mod p.
fn sbox_layer(st: State, lut_addr: Field, domain: Field) -> State {
    State {
        s0: sbox_lane(st.s0, lut_addr, domain),
        s1: sbox_lane(st.s1, lut_addr, domain),
        s2: sbox_lane(st.s2, lut_addr, domain),
        s3: sbox_lane(st.s3, lut_addr, domain),
        s4: sbox_lane(st.s4, lut_addr, domain),
        s5: sbox_lane(st.s5, lut_addr, domain),
        s6: sbox_lane(st.s6, lut_addr, domain),
        s7: sbox_lane(st.s7, lut_addr, domain)
    }
}

// ---------------------------------------------------------------------------
// Linear layer: circulant(2,1,1,...,1) = I + J
// ---------------------------------------------------------------------------
// new[i] = state[i] + sum(state). Each output lane depends on all 8 input lanes
// after a single application.
//
// Despite the function name, I + J is NOT maximum distance separable at width
// 8. The input difference (1, -1, 0, ..., 0) has zero sum and maps to itself,
// giving input+output weight 4 instead of the MDS branch number 9.
// Intended to match poseidon2.external_linear — NOTE(review): confirm
// against that module.
fn mds(st: State) -> State {
    let total: Field = st.s0 + st.s1 + st.s2 + st.s3 + st.s4 + st.s5 + st.s6 + st.s7
    State {
        s0: total + st.s0,
        s1: total + st.s1,
        s2: total + st.s2,
        s3: total + st.s3,
        s4: total + st.s4,
        s5: total + st.s5,
        s6: total + st.s6,
        s7: total + st.s7
    }
}

// ---------------------------------------------------------------------------
// Add 8 round constants from RAM.
// ---------------------------------------------------------------------------
// Round-constant addition (ARK). Reads the 8 consecutive words at
// rc_addr .. rc_addr + 7 with mem.read, in address order, and adds word j to
// lane sj.
fn add_constants(st: State, rc_addr: Field) -> State {
    let c0: Field = mem.read(rc_addr)
    let c1: Field = mem.read(rc_addr + 1)
    let c2: Field = mem.read(rc_addr + 2)
    let c3: Field = mem.read(rc_addr + 3)
    let c4: Field = mem.read(rc_addr + 4)
    let c5: Field = mem.read(rc_addr + 5)
    let c6: Field = mem.read(rc_addr + 6)
    let c7: Field = mem.read(rc_addr + 7)
    State {
        s0: st.s0 + c0,
        s1: st.s1 + c1,
        s2: st.s2 + c2,
        s3: st.s3 + c3,
        s4: st.s4 + c4,
        s5: st.s5 + c5,
        s6: st.s6 + c6,
        s7: st.s7 + c7
    }
}

// ---------------------------------------------------------------------------
// One round: add constants + S-box (LUT) + MDS
// ---------------------------------------------------------------------------
// A single round, applied innermost-first: round-constant addition from the 8
// words at rc_addr, then the table-lookup S-box layer, then the linear layer.
fn round(st: State, lut_addr: Field, domain: Field, rc_addr: Field) -> State {
    mds(sbox_layer(add_constants(st, rc_addr), lut_addr, domain))
}

// ---------------------------------------------------------------------------
// Full permutation: 14 rounds.
// ---------------------------------------------------------------------------
// Applies `round` 14 times. Round i takes its 8 constants from
// rc_addr + 8 * i, so the region at rc_addr must hold 14 * 8 = 112 field
// elements.
//
// NOTE(review): despite the name, this map is not a permutation of the state
// space. The S-box layer keeps only each lane's residue mod `domain`, so
// distinct inputs can yield identical outputs.
pub fn permute(st: State, lut_addr: Field, domain: Field, rc_addr: Field) -> State {
    let mut acc: State = st
    for i in 0..14 bounded 14 {
        let round_rc: Field = rc_addr + 8 * convert.as_field(i)
        acc = round(acc, lut_addr, domain, round_rc)
    }
    acc
}

// ---------------------------------------------------------------------------
// Sponge: absorb 4 elements into rate portion of state.
// ---------------------------------------------------------------------------
// Adds a, b, c, d to rate lanes s0..s3 respectively. The capacity lanes
// s4..s7 pass through untouched. No permutation is applied here.
pub fn absorb4(st: State, a: Field, b: Field, c: Field, d: Field) -> State {
    State {
        s0: st.s0 + a,
        s1: st.s1 + b,
        s2: st.s2 + c,
        s3: st.s3 + d,
        s4: st.s4,
        s5: st.s5,
        s6: st.s6,
        s7: st.s7
    }
}

// ---------------------------------------------------------------------------
// Squeeze: extract first element from state.
// ---------------------------------------------------------------------------
// Returns rate lane s0 as the single squeezed output; the state is not
// modified.
pub fn squeeze1(st: State) -> Field {
    let out: Field = st.s0
    out
}

// ---------------------------------------------------------------------------
// Hash 4 field elements to 1 digest element.
// ---------------------------------------------------------------------------
// Loads a, b, c, d into rate lanes s0..s3 and the domain-separation tag 4
// (the number of inputs) into capacity lane s4; s5..s7 start at zero. Runs the
// 14-round permute, with S-box lookups in the shared table at lut_addr, and
// returns lane s0.
//
// NOTE(review): the digest is a single Goldilocks element (p ~ 2^64), so
// generic collision resistance is at most ~2^32. In addition, the S-box layer
// keeps only each lane's residue mod `domain`. Two inputs that differ by a
// multiple of `domain` in one rate lane therefore hash to the same digest,
// provided the first round-constant addition does not wrap mod p for either.
pub fn hash4_to_digest(
    a: Field,
    b: Field,
    c: Field,
    d: Field,
    lut_addr: Field,
    domain: Field,
    rc_addr: Field
) -> Field {
    let init: State = State { s0: a, s1: b, s2: c, s3: d, s4: 4, s5: 0, s6: 0, s7: 0 }
    let result: State = permute(init, lut_addr, domain, rc_addr)
    squeeze1(result)
}
