module std.fhe.pbs
// Programmable Bootstrapping (PBS) over Goldilocks.
//
// PBS evaluates a lookup table on encrypted data — the core TFHE operation.
// The lookup table is read from the SAME RAM address as the neural network
// ReLU activation and the LUT sponge S-box. This is Reader #3 of the
// Rosetta Stone: one table, four readers.
//
// Reader 1: std.nn via lut.apply → neural activation
// Reader 2: std.crypto.lut_sponge → crypto S-box
// Reader 3: THIS MODULE via lut.read → FHE test polynomial
// Reader 4: STARK LogUp → proof authentication (upstream)
//
// Algorithm (simplified TFHE):
// 1. Build test polynomial from lookup table
// 2. Initialize accumulator as trivial RLWE encryption of test_poly
// 3. Rotate accumulator by encrypted amount (blind rotation)
// 4. Sample extract: RLWE → LWE
// 5. Key switch to original key
//
// Pitch parameters: N = 64 (ring dimension), LWE_N = 8, D = 1024.
// Structurally identical to production TFHE (N = 1024+).
use vm.core.convert
use vm.core.assert
use vm.io.io
use vm.io.mem
use std.math.lut
use std.fhe.lwe
// ---------------------------------------------------------------------------
// Build test polynomial from lookup table.
// ---------------------------------------------------------------------------
// Reads N entries (one per coefficient) from the D-entry table at lut_addr
// via lut.read — THIS IS the Rosetta Stone FHE reader. The test polynomial
// encodes the same function (ReLU) that the neural network uses for
// activation.
//
// For every N and D: poly[i] = T[floor(i * D / N)] for i in [0, N).
// When N < D this downsamples the table; when N >= D entries are repeated
// rather than zero-padded.
//
// lut_addr: address of shared lookup table (D entries)
// poly_addr: output polynomial (N coefficients)
// n: ring dimension N
// d: lookup table domain size D
pub fn build_test_poly(
    lut_addr: Field,
    poly_addr: Field,
    n: Field,
    d: Field
) {
    // Fill poly[i] = T[floor(i * d / n)] for every i in [0, n).
    //
    // Every coefficient in [0, n) is written exactly once by the loop below,
    // so no separate zeroing pass is needed (it would only cost VM cycles).
    //
    // Store params in RAM scratch (1073741848..1073741851); the loop body
    // reloads them from there instead of referencing the parameters.
    mem.write(1073741848, lut_addr)
    mem.write(1073741849, poly_addr)
    mem.write(1073741850, n)
    mem.write(1073741851, d)
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        let r_d: Field = mem.read(1073741851)
        let scaled: Field = idx * r_d
        // Prover hint: the table index for coefficient idx, expected to be
        // floor(idx * d / n). One divine() per coefficient.
        let table_idx: Field = io.divine()
        let r_n: Field = mem.read(1073741850)
        let lo: Field = table_idx * r_n
        // Enforce 0 <= idx*d - table_idx*n < n, which pins table_idx to
        // floor(idx * d / n). This relies on convert.as_u32 rejecting a
        // wrapped-negative difference (a value that does not fit in 32
        // bits) — assumed behavior of convert.as_u32, TODO confirm.
        assert.is_true(convert.as_u32(scaled + field.neg(lo)) < convert.as_u32(r_n))
        let r_lut: Field = mem.read(1073741848)
        let val: Field = lut.read(r_lut, table_idx)
        let r_poly: Field = mem.read(1073741849)
        mem.write(r_poly + idx, val)
    }
}
// ---------------------------------------------------------------------------
// Monomial multiply: poly *= X^k mod (X^N + 1).
// ---------------------------------------------------------------------------
// Shifts coefficients by k positions. Coefficients that wrap around
// are negated (because X^N = -1 in the negacyclic ring).
//
// poly_addr: polynomial (N coefficients, modified in-place)
// k: shift amount (must be in [0, 2N))
// n: ring dimension
// tmp_addr: scratch (N elements)
pub fn monomial_mul(
    poly_addr: Field,
    k: Field,
    n: Field,
    tmp_addr: Field
) {
    // poly *= X^k mod (X^n + 1), in place; k is expected in [0, 2n).
    //
    // Snapshot the input so every output coefficient is computed from the
    // original values, not from ones already rewritten in place.
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        mem.write(tmp_addr + idx, mem.read(poly_addr + idx))
    }
    // Store params in RAM scratch (1073741852..1073741855) for loop access.
    // NOTE(review): 1073741854..1073741855 are newly used here; confirm no
    // other module's scratch region overlaps them.
    mem.write(1073741852, tmp_addr)
    mem.write(1073741853, poly_addr)
    mem.write(1073741854, k)
    mem.write(1073741855, n)
    for j in 0..n bounded 4096 {
        let jj: Field = convert.as_field(j)
        // Prover hints for output coefficient j: the snapshot index it is
        // copied from and the sign applied to it (two divine() calls per
        // coefficient, in this order).
        let src: Field = io.divine()
        let sign: Field = io.divine()
        // src must index the n-coefficient snapshot (assumes convert.as_u32
        // rejects values that do not fit in 32 bits — TODO confirm).
        assert.is_true(convert.as_u32(src) < convert.as_u32(mem.read(1073741855)))
        // sign must be +1 or -1: the only square roots of 1 in a field.
        assert.is_true(sign * sign == 1)
        // Let e = src when sign = +1 and e = src + n when sign = -1, so that
        // sign * X^src == X^e modulo X^n + 1 and e lies in [0, 2n).
        // Coefficient j of X^k * poly comes from the e with e + k == j
        // (mod 2n); with k in [0, 2n) and j in [0, n), e + k - j is then
        // exactly 0 or 2n. `twice` is 2 * (e + k - j), written without
        // dividing by 2, and must therefore be 0 or 4n. Given valid k this
        // admits exactly one (src, sign) pair.
        let twice: Field = (src + mem.read(1073741854) + field.neg(jj)) * 2 + mem.read(1073741855) * (1 + field.neg(sign))
        assert.is_true(twice * (twice + field.neg(mem.read(1073741855) * 4)) == 0)
        let r_tmp: Field = mem.read(1073741852)
        let src_val: Field = mem.read(r_tmp + src)
        let r_poly: Field = mem.read(1073741853)
        mem.write(r_poly + jj, src_val * sign)
    }
}
// ---------------------------------------------------------------------------
// Blind rotation: the core PBS step.
// ---------------------------------------------------------------------------
// Rotates accumulator by the encrypted amount b - <a, s>.
// For each LWE dimension j, applies a CMux gate:
// ACC = ACC * (1 + (X^{a_j} - 1) * BSK_j)
// where BSK_j encrypts the j-th bit of the secret key.
//
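// In the simplified (pitch) version, the prover supplies the rotation
// amount directly. Nothing in this module derives it from the ciphertext
// or checks it against b - <a, s>; any such check must come from the
// caller (e.g. on the plaintext returned by bootstrap).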
//
// acc_addr: RLWE accumulator (2N coefficients), modified in-place
// rotation: the total rotation amount b - <a,s> mod 2N
// n: ring dimension
// tmp_addr: scratch (N elements)
pub fn blind_rotate(
    acc_addr: Field,
    rotation: Field,
    n: Field,
    tmp_addr: Field
) {
    // The accumulator stores its first component in acc_addr[0..n) and its
    // second component in acc_addr[n..2n). Multiplying the ciphertext by
    // X^rotation multiplies both halves by the same monomial, first half
    // first. The same n-element tmp_addr scratch is reused by both calls.
    let second_half: Field = acc_addr + n
    monomial_mul(acc_addr, rotation, n, tmp_addr)
    monomial_mul(second_half, rotation, n, tmp_addr)
}
// ---------------------------------------------------------------------------
// Sample extract: RLWE → LWE.
// ---------------------------------------------------------------------------
// Extracts the constant coefficient of the RLWE ciphertext as an LWE
// ciphertext. The LWE dimension equals the ring dimension N.
//
// rlwe_addr: RLWE ciphertext (2N coefficients: a(X), b(X))
// lwe_out_addr: output LWE ciphertext (N+1 elements: a[0..N-1], b)
// n: ring dimension
pub fn sample_extract(
    rlwe_addr: Field,
    lwe_out_addr: Field,
    n: Field
) {
    // Input: a(X) at rlwe_addr[0..n) followed by b(X) at rlwe_addr[n..2n).
    // Output: n + 1 values at lwe_out_addr — n mask entries, then the body
    // at offset n.
    //
    // Mask entry 0 is a_0, copied unchanged.
    mem.write(lwe_out_addr, mem.read(rlwe_addr))
    // Mask entries 1..n-1 are the other a-coefficients in reverse order,
    // negated: entry t receives -a_{n-t}.
    for t in 1..n bounded 4096 {
        let pos: Field = convert.as_field(t)
        let mirrored: Field = mem.read(rlwe_addr + n + field.neg(pos))
        mem.write(lwe_out_addr + pos, field.neg(mirrored))
    }
    // Body: the constant coefficient b_0, which sits at offset n.
    mem.write(lwe_out_addr + n, mem.read(rlwe_addr + n))
}
// ---------------------------------------------------------------------------
// Key switch: convert LWE ciphertext from ring key to original key.
// ---------------------------------------------------------------------------
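// Reduces LWE dimension from N (ring dimension) to lwe_n (original).
// Production TFHE performs this with a key-switching key; none is passed
// to this function.
//
// In this pitch version the prover supplies the switched ciphertext via
// divine(). in_addr and n are not read, so the output is not checked against
// the input ciphertext here; any consistency check (e.g. on the decrypted
// plaintext) must happen in the caller.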
//
// in_addr: LWE ciphertext under ring key (N+1 elements)
// out_addr: LWE ciphertext under original key (lwe_n+1 elements)
// n: ring dimension (input LWE dimension)
// lwe_n: output LWE dimension
pub fn key_switch(
    in_addr: Field,
    out_addr: Field,
    n: Field,
    lwe_n: Field
) {
    // Pitch version: each of the lwe_n + 1 output slots is filled, in
    // address order, with the next value from the prover's divine() stream.
    // in_addr and n are not read and no key-switching key is consulted, so
    // nothing here checks the output against the input ciphertext.
    // Correctness is expected to be verified at the integration level: the
    // bootstrapped plaintext must match the divine-decrypted value.
    let slots: Field = lwe_n + 1
    for slot in 0..slots bounded 4096 {
        mem.write(out_addr + convert.as_field(slot), io.divine())
    }
}
// ---------------------------------------------------------------------------
// Full programmable bootstrap.
// ---------------------------------------------------------------------------
// Evaluates the lookup table on an encrypted value. The lookup table
// at lut_addr is the SAME table used for NN activation (Reader 1) and
// crypto S-box (Reader 2). This function is Reader 3.
//
// ct_addr: input LWE ciphertext (lwe_n + 1 elements; currently not read —
//          the rotation amount is supplied by the prover instead)
// s_addr: LWE secret key (lwe_n elements)
// lut_addr: shared Rosetta Stone lookup table (D entries)
// out_addr: output LWE ciphertext (lwe_n + 1 elements)
// delta: scaling factor
// lwe_n: LWE dimension
// n: ring dimension for RLWE
// d: lookup table domain size
// acc_addr: scratch for RLWE accumulator (2N elements)
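// test_poly_addr: scratch for test polynomial (N + 1 elements: after the
//                 accumulator is initialized it is reused for the extracted
//                 LWE ciphertext, which has N + 1 elements)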
// tmp_addr: scratch (N elements)
//
// Returns: the decrypted bootstrapped plaintext (for pipeline assertion).
pub fn bootstrap(
ct_addr: Field,
s_addr: Field,
lut_addr: Field,
out_addr: Field,
delta: Field,
lwe_n: Field,
n: Field,
d: Field,
acc_addr: Field,
test_poly_addr: Field,
tmp_addr: Field
) -> Field {
// Order of io.divine() consumption along this path, per the functions as
// written in this file: n hints in build_test_poly, 1 rotation hint below,
// 2n hints in each of the two monomial_mul calls made by blind_rotate, then
// lwe_n + 1 values in key_switch. lwe.decrypt is defined in std.fhe.lwe;
// whether it consumes divine() values is not visible here.
// NOTE(review): ct_addr is never read in this function.
// Step 1: Build test polynomial from shared LUT (Rosetta Stone Reader 3)
build_test_poly(lut_addr, test_poly_addr, n, d)
// Step 2: Initialize accumulator as trivial RLWE encryption of test_poly
// Store loop params in RAM scratch (1073741856..1073741859).
// 1073741859 holds acc_addr + n, the start of the b-component.
mem.write(1073741856, acc_addr)
mem.write(1073741857, test_poly_addr)
mem.write(1073741858, delta)
mem.write(1073741859, acc_addr + n)
// Zero the a-component
for i in 0..n bounded 4096 {
let idx: Field = convert.as_field(i)
let r_acc: Field = mem.read(1073741856)
mem.write(r_acc + idx, 0)
}
// Fill b-component: b[i] = test_poly[i] * delta
for i in 0..n bounded 4096 {
let idx: Field = convert.as_field(i)
let r_test: Field = mem.read(1073741857)
let coeff: Field = mem.read(r_test + idx)
let r_acc_b: Field = mem.read(1073741859)
let r_delta: Field = mem.read(1073741858)
mem.write(r_acc_b + idx, coeff * r_delta)
}
// Step 3: Compute rotation amount
// rotation = round(b / delta) - sum(round(a_i / delta) * s_i)
// The prover supplies the rotation amount. It is neither derived from
// ct_addr nor range-checked in this function (monomial_mul expects it in
// [0, 2N)); any check against the formula above must come from the caller.
let rotation: Field = io.divine()
// Step 4: Blind rotate the accumulator
blind_rotate(acc_addr, rotation, n, tmp_addr)
// Step 5: Sample extract RLWE → LWE
// The test polynomial is not read after Step 2, so its region is reused.
// sample_extract writes N + 1 values starting at test_poly_addr — one more
// than the N-element scratch documented for test_poly_addr.
// NOTE(review): callers must reserve N + 1 elements there; confirm no live
// data (e.g. s_addr) sits at test_poly_addr + n.
let extracted_addr: Field = test_poly_addr
sample_extract(acc_addr, extracted_addr, n)
// Step 6: Key switch from ring dimension to original LWE dimension
// key_switch as written in this file fills out_addr (lwe_n + 1 values)
// from divine() and does not read extracted_addr.
key_switch(extracted_addr, out_addr, n, lwe_n)
// Step 7: Decrypt the bootstrapped ciphertext to return plaintext
// m is whatever lwe.decrypt returns for out_addr under s_addr with scale
// delta and dimension lwe_n.
let m: Field = lwe.decrypt(out_addr, s_addr, delta, lwe_n)
m
}
trident/std/fhe/pbs.tri
ฯ 0.0%