module std.fhe.pbs

// Programmable Bootstrapping (PBS) over Goldilocks.
//
// PBS evaluates a lookup table on encrypted data — the core TFHE operation.
// The lookup table is read from the SAME RAM address as the neural network
// ReLU activation and the LUT sponge S-box. This is Reader #3 of the
// Rosetta Stone: one table, four readers.
//
//   Reader 1: std.nn via lut.apply       — neural activation
//   Reader 2: std.crypto.lut_sponge      — crypto S-box
//   Reader 3: THIS MODULE via lut.read   — FHE test polynomial
//   Reader 4: STARK LogUp                — proof authentication (upstream)
//
// Algorithm (simplified TFHE):
//   1. Build test polynomial from lookup table
//   2. Initialize accumulator as trivial RLWE encryption of test_poly
//   3. Rotate accumulator by encrypted amount (blind rotation)
//   4. Sample extract: RLWE → LWE
//   5. Key switch to original key
//
// Pitch parameters: N = 64 (ring dimension), LWE_N = 8, D = 1024.
// Structurally identical to production TFHE (N = 1024+).
use vm.core.convert

use vm.core.assert

use vm.io.io

use vm.io.mem

use std.math.lut

use std.fhe.lwe

// ---------------------------------------------------------------------------
// Build test polynomial from lookup table.
// ---------------------------------------------------------------------------
// Reads entries from lut_addr via lut.read — THIS IS the Rosetta Stone
// FHE reader. The test polynomial encodes the same function (ReLU) that
// the neural network uses for activation.
//
// Every coefficient is resampled from the table:
//   poly[i] = T[floor(i * D / N)]   for i in [0, N)
// For N < D this downsamples the table. For N >= D each table entry is
// repeated over consecutive coefficients; there is no zero padding.
//
// The index floor(i * D / N) is supplied by the prover via io.divine() and
// checked in-circuit:
//   - table_idx < D (range check), and
//   - 0 <= i * D - table_idx * N < N (floor check).
// The range check is what makes the floor unique. Without it, a prover could
// choose a wrapped field element with table_idx * N == i * D - r (mod p) for
// some other r in [0, N), pass the floor check, and hand lut.read an index
// outside the table. Both checks assume convert.as_u32 rejects values that
// do not fit in 32 bits.
//
// lut_addr: address of shared lookup table (D entries)
// poly_addr: output polynomial (N coefficients)
// n: ring dimension N
// d: lookup table domain size D
pub fn build_test_poly(
    lut_addr: Field,
    poly_addr: Field,
    n: Field,
    d: Field
) {
    // Store params in RAM scratch (1073741848..1073741851) for loop access.
    mem.write(1073741848, lut_addr)
    mem.write(1073741849, poly_addr)
    mem.write(1073741850, n)
    mem.write(1073741851, d)
    // Fill from lookup table. Every coefficient in [0, N) is written here,
    // so no separate zeroing pass is needed.
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        let r_d: Field = mem.read(1073741851)
        let scaled: Field = idx * r_d
        // Prover hint: floor(i * D / N).
        let table_idx: Field = io.divine()
        // Range check: the hint must be a genuine index into the D-entry table.
        assert.is_true(convert.as_u32(table_idx) < convert.as_u32(r_d))
        let r_n: Field = mem.read(1073741850)
        let lo: Field = table_idx * r_n
        // Floor check: 0 <= i * D - table_idx * N < N.
        assert.is_true(convert.as_u32(scaled + field.neg(lo)) < convert.as_u32(r_n))
        let r_lut: Field = mem.read(1073741848)
        let val: Field = lut.read(r_lut, table_idx)
        let r_poly: Field = mem.read(1073741849)
        mem.write(r_poly + idx, val)
    }
}

// ---------------------------------------------------------------------------
// Monomial multiply: poly *= X^k  mod (X^N + 1).
// ---------------------------------------------------------------------------
// Shifts coefficients by k positions. Coefficients that wrap around
// are negated (because X^N = -1 in the negacyclic ring).
//
// For each output coefficient j, the prover supplies via io.divine() a
// source index src and a sign. Each (src, sign) pair is checked against k,
// so the written value must be the coefficient of poly * X^k.
//
// poly_addr: polynomial (N coefficients, modified in-place)
// k: shift amount (must be in [0, 2N))
// n: ring dimension
// tmp_addr: scratch (N elements)
pub fn monomial_mul(
    poly_addr: Field,
    k: Field,
    n: Field,
    tmp_addr: Field
) {
    // Copy original to tmp, so the in-place writes below never read an
    // already-overwritten coefficient.
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        mem.write(tmp_addr + idx, mem.read(poly_addr + idx))
    }
    // Store params in RAM scratch (1073741852..1073741855) for loop access.
    mem.write(1073741852, tmp_addr)
    mem.write(1073741853, poly_addr)
    mem.write(1073741854, k)
    mem.write(1073741855, n)
    for j in 0..n bounded 4096 {
        let jj: Field = convert.as_field(j)
        // Prover hints for output coefficient j.
        let src: Field = io.divine()
        let sign: Field = io.divine()
        // Reject hints that do not realize multiplication by X^k.
        check_monomial_hint(jj, src, sign)
        let r_tmp: Field = mem.read(1073741852)
        let src_val: Field = mem.read(r_tmp + src)
        let r_poly: Field = mem.read(1073741853)
        mem.write(r_poly + jj, src_val * sign)
    }
}

// Asserts that (src, sign) is a valid hint for output coefficient jj of
// poly * X^k mod (X^N + 1). k and N are read from RAM scratch
// 1073741854 / 1073741855, as written by monomial_mul.
//
// Coefficient j receives +p[src] when src + k == j (mod 2N), and -p[src]
// when src + k == j + N (mod 2N). With src, j in [0, N) and k in [0, 2N),
// diff = src + k - j lies in (-N, 3N). So the admissible differences are
// 0 or 2N for sign = +1 and N for sign = -1. Negative differences wrap to
// large field elements and match none of these.
fn check_monomial_hint(jj: Field, src: Field, sign: Field) {
    let r_k: Field = mem.read(1073741854)
    let r_n: Field = mem.read(1073741855)
    // src must index inside the N-element snapshot in tmp.
    assert.is_true(convert.as_u32(src) < convert.as_u32(r_n))
    // sign must be +1 or -1: (sign + 1) * (1 - sign) == 0.
    let plus: Field = sign + 1
    let minus: Field = 1 + field.neg(sign)
    assert.is_true(plus * minus == 0)
    let diff: Field = src + r_k + field.neg(jj)
    let two_n: Field = r_n + r_n
    // sign = +1 (plus = 2, minus = 0): diff * (diff - 2N) == 0.
    // sign = -1 (plus = 0, minus = 2): diff - N == 0.
    let when_plus: Field = plus * diff * (diff + field.neg(two_n))
    let when_minus: Field = minus * (diff + field.neg(r_n))
    assert.is_true(when_plus + when_minus == 0)
}

// ---------------------------------------------------------------------------
// Blind rotation: the core PBS step.
// ---------------------------------------------------------------------------
// Multiplies the RLWE accumulator by X^rotation in the ring mod (X^N + 1).
//
// Full TFHE derives this rotation homomorphically: for each LWE dimension j
// it applies a CMux gate
//   ACC = ACC * (1 + (X^{a_j} - 1) * BSK_j)
// where BSK_j encrypts the j-th bit of the secret key.
//
// This simplified (pitch) version receives the total rotation as a plain
// argument and rotates both ciphertext components by it directly.
//
// acc_addr: RLWE accumulator (2N coefficients: a(X) then b(X)), modified in-place
// rotation: the total rotation amount b - <a,s> mod 2N
// n: ring dimension
// tmp_addr: scratch (N elements), overwritten by each monomial_mul call
pub fn blind_rotate(
    acc_addr: Field,
    rotation: Field,
    n: Field,
    tmp_addr: Field
) {
    // The two components sit back-to-back: a(X) first, then b(X).
    let mask_poly: Field = acc_addr
    let body_poly: Field = acc_addr + n
    monomial_mul(mask_poly, rotation, n, tmp_addr)
    monomial_mul(body_poly, rotation, n, tmp_addr)
}

// ---------------------------------------------------------------------------
// Sample extract: RLWE → LWE.
// ---------------------------------------------------------------------------
// Turns the constant coefficient of an RLWE ciphertext into an LWE
// ciphertext whose dimension equals the ring dimension N:
//   lwe.a[0] = rlwe.a[0]
//   lwe.a[j] = -rlwe.a[N - j]   for j in 1..N
//   lwe.b    = rlwe.b[0]
// The negated, reversed tail comes from X^N = -1 when the constant term of
// a(X) * s(X) is expanded in the negacyclic ring.
//
// rlwe_addr: RLWE ciphertext (2N coefficients: a(X), b(X))
// lwe_out_addr: output LWE ciphertext (N+1 elements: a[0..N-1], b)
// n: ring dimension
pub fn sample_extract(
    rlwe_addr: Field,
    lwe_out_addr: Field,
    n: Field
) {
    // The constant mask coefficient passes through unchanged.
    let head: Field = mem.read(rlwe_addr)
    mem.write(lwe_out_addr, head)
    // The remaining mask coefficients are the mirrored, negated tail of a(X).
    for j in 1..n bounded 4096 {
        let pos: Field = convert.as_field(j)
        let mirrored: Field = mem.read(rlwe_addr + n + field.neg(pos))
        mem.write(lwe_out_addr + pos, field.neg(mirrored))
    }
    // The body is the constant coefficient of b(X), stored right after a(X).
    let body: Field = mem.read(rlwe_addr + n)
    mem.write(lwe_out_addr + n, body)
}

// ---------------------------------------------------------------------------
// Key switch: convert LWE ciphertext from ring key to original key.
// ---------------------------------------------------------------------------
// Intended to reduce the LWE dimension from N (ring dimension) to lwe_n
// (original dimension) using a key-switching key.
//
// In this simplified version, the lwe_n + 1 output elements are taken
// entirely from io.divine(). in_addr and n are not read, and no check is
// performed here. Consistency must be established by the caller, e.g. by
// asserting on the plaintext obtained when decrypting out_addr.
//
// in_addr: LWE ciphertext under ring key (N+1 elements)
// out_addr: LWE ciphertext under original key (lwe_n+1 elements)
// n: ring dimension (input LWE dimension)
// lwe_n: output LWE dimension
pub fn key_switch(
    in_addr: Field,
    out_addr: Field,
    n: Field,
    lwe_n: Field
) {
    // lwe_n mask coefficients plus one body element.
    let count: Field = lwe_n + 1
    for slot in 0..count bounded 4096 {
        mem.write(out_addr + convert.as_field(slot), io.divine())
    }
}

// ---------------------------------------------------------------------------
// Full programmable bootstrap.
// ---------------------------------------------------------------------------
// Evaluates the lookup table on an encrypted value. The lookup table
// at lut_addr is the SAME table used for NN activation (Reader 1) and
// crypto S-box (Reader 2). This function is Reader 3.
//
// Pipeline: build test polynomial -> trivial RLWE accumulator -> rotate by a
// prover-supplied amount -> sample extract -> key switch -> decrypt.
//
// NOTE(review): ct_addr is never read in this body. The rotation that
// should be derived from it is taken from io.divine() in Step 3, and
// key_switch fills out_addr from io.divine() without reading its input.
// The returned plaintext is therefore only as trustworthy as whatever check
// the caller performs on it.
//
// ct_addr: input LWE ciphertext (lwe_n + 1 elements)
// s_addr: LWE secret key (lwe_n elements)
// lut_addr: shared Rosetta Stone lookup table (D entries)
// out_addr: output LWE ciphertext (lwe_n + 1 elements)
// delta: scaling factor
// lwe_n: LWE dimension
// n: ring dimension for RLWE
// d: lookup table domain size
// acc_addr: scratch for RLWE accumulator (2N elements)
// test_poly_addr: scratch for test polynomial. It is later reused for the
//   extracted LWE ciphertext, which occupies N+1 elements, so the caller
//   must reserve N+1 slots here, not just N.
// tmp_addr: scratch (N elements)
//
// Returns: the decrypted bootstrapped plaintext (for pipeline assertion).
pub fn bootstrap(
    ct_addr: Field,
    s_addr: Field,
    lut_addr: Field,
    out_addr: Field,
    delta: Field,
    lwe_n: Field,
    n: Field,
    d: Field,
    acc_addr: Field,
    test_poly_addr: Field,
    tmp_addr: Field
) -> Field {
    // Step 1: Build test polynomial from shared LUT (Rosetta Stone Reader 3)
    build_test_poly(lut_addr, test_poly_addr, n, d)
    // Step 2: Initialize accumulator as trivial RLWE encryption of test_poly,
    // i.e. a(X) = 0 and b(X) = delta * test_poly(X).
    // Store loop params in RAM scratch (1073741856..1073741859).
    mem.write(1073741856, acc_addr)
    mem.write(1073741857, test_poly_addr)
    mem.write(1073741858, delta)
    mem.write(1073741859, acc_addr + n)
    // Zero the a-component
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        let r_acc: Field = mem.read(1073741856)
        mem.write(r_acc + idx, 0)
    }
    // Fill b-component: b[i] = test_poly[i] * delta
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        let r_test: Field = mem.read(1073741857)
        let coeff: Field = mem.read(r_test + idx)
        let r_acc_b: Field = mem.read(1073741859)
        let r_delta: Field = mem.read(1073741858)
        mem.write(r_acc_b + idx, coeff * r_delta)
    }
    // Step 3: Compute rotation amount
    // rotation = round(b / delta) - sum(round(a_i / delta) * s_i)
    // The prover supplies the rotation amount via io.divine(). Nothing in this
    // body ties it to ct_addr (see NOTE(review) in the header).
    let rotation: Field = io.divine()
    // Step 4: Blind rotate the accumulator
    blind_rotate(acc_addr, rotation, n, tmp_addr)
    // Step 5: Sample extract RLWE → LWE
    // Reuse test_poly_addr for the extracted LWE ciphertext. sample_extract
    // writes N+1 elements there (a[0..N-1] plus b at offset N), one more than
    // the test polynomial itself occupies.
    let extracted_addr: Field = test_poly_addr
    sample_extract(acc_addr, extracted_addr, n)
    // Step 6: Key switch from ring dimension to original LWE dimension
    key_switch(extracted_addr, out_addr, n, lwe_n)
    // Step 7: Decrypt the bootstrapped ciphertext to return plaintext
    let m: Field = lwe.decrypt(out_addr, s_addr, delta, lwe_n)
    m
}

Local Graph