module std.fhe.lwe
// LWE encryption over the Goldilocks field (p = 2^64 - 2^32 + 1).
//
// Ciphertext (a, b) where b = <a, s> + m * Delta + e.
// Ciphertext modulus q = p — no impedance mismatch between the LWE
// ciphertext space and the proof system's field.
// Every LWE operation is native field arithmetic, natively provable
// inside a STARK trace.
//
// Memory layout for one ciphertext of LWE dimension n:
// ct_addr + 0 .. ct_addr + n-1 : vector a (n field elements)
// ct_addr + n : scalar b (1 field element)
// Stride per ciphertext: n + 1
//
// The prover pre-generates (a, s, e) outside the trace and stores
// them in RAM before execution. The STARK proof covers correct
// computation given these values. Randomness quality is the prover's
// responsibility, not the circuit's.
use vm.core.field
use vm.core.convert
use vm.core.assert
use vm.io.io
use vm.io.mem
// ---------------------------------------------------------------------------
// Inner product: <a, s> = sum(a[i] * s[i]) for i in 0..n.
// Both vectors stored in RAM.
// ---------------------------------------------------------------------------
pub fn inner_product(a_addr: Field, s_addr: Field, n: Field) -> Field {
    // Accumulate a[k] * s[k] for k = 0..n-1. All arithmetic is in the
    // Goldilocks field, so the running total simply wraps mod p.
    // The loop is bounded at 4096 iterations.
    let mut acc: Field = 0
    for k in 0..n bounded 4096 {
        let off: Field = convert.as_field(k)
        acc = acc + mem.read(a_addr + off) * mem.read(s_addr + off)
    }
    acc
}
// ---------------------------------------------------------------------------
// Encrypt: produce LWE ciphertext at ct_addr.
//
// a_addr: prover-supplied random vector (n elements in RAM)
// s_addr: secret key (n elements in RAM)
// ct_addr: output ciphertext (n+1 elements)
// m: plaintext message
// delta: scaling factor, typically floor(p / t) for plaintext modulus t
// e: prover-supplied noise (small field element)
// n: LWE dimension
//
// Writes: ct_addr[0..n-1] = a, ct_addr[n] = <a,s> + m*delta + e
// ---------------------------------------------------------------------------
pub fn encrypt(
    a_addr: Field,
    s_addr: Field,
    ct_addr: Field,
    m: Field,
    delta: Field,
    e: Field,
    n: Field
) {
    // Mask: ciphertext slots 0..n-1 receive the prover-supplied vector a
    // verbatim, one element per iteration (loop bounded at 4096).
    for k in 0..n bounded 4096 {
        let off: Field = convert.as_field(k)
        let a_k: Field = mem.read(a_addr + off)
        mem.write(ct_addr + off, a_k)
    }
    // Body: slot n holds <a, s> + m * delta + e. The inner product reads
    // a from a_addr (not from the copy), after the copy loop has run.
    let body: Field = inner_product(a_addr, s_addr, n) + m * delta + e
    mem.write(ct_addr + n, body)
}
// ---------------------------------------------------------------------------
// Decryption helpers.
// ---------------------------------------------------------------------------
// Check if hi-word of a < hi-word of b.
// Separate function to keep stack shallow during split.
//
// Only the high 32-bit words (as returned by convert.split) are compared.
// A true result therefore implies a < b as integers in [0, p). The test
// is conservative: a < b with equal high words yields false.
fn hi_lt(a: Field, b: Field) -> Bool {
    let (a_hi, a_lo) = convert.split(a)
    let (b_hi, b_lo) = convert.split(b)
    a_hi < b_hi
}

// Integer floor(x / 2) for a field element x in [0, p).
//
// x * inv(2) is the integer half only when x is even. For odd x it equals
// (x + p) / 2, which lies above p / 2. The two field candidates
//   h_even = x * inv(2)             (exact when x is even)
//   h_odd  = x * inv(2) - inv(2)    (= (x - 1) * inv(2), exact when x is odd)
// always differ by about p / 2 as integers, with the correct floor being
// the smaller one. Their high words differ by at least 2^31 - 1, so
// hi_lt reliably picks the smaller candidate.
fn floor_half(x: Field) -> Field {
    let inv2: Field = field.inv(2)
    let h_even: Field = x * inv2
    let h_odd: Field = h_even + field.neg(inv2)
    if hi_lt(h_even, h_odd) {
        h_even
    } else {
        h_odd
    }
}

// ---------------------------------------------------------------------------
// Decrypt: recover plaintext from ciphertext.
//
// phase = b - <a, s> = m * delta + e
//
// Rounding in a finite field: the prover supplies candidate m via
// io.divine(). The circuit checks that the implied noise phase - m * delta,
// or its negation, is below floor(delta / 2). The comparison looks at high
// 32-bit words only (see hi_lt).
//
// ct_addr: ciphertext (n+1 elements)
// s_addr:  secret key (n elements)
// delta:   scaling factor
// n:       LWE dimension
// Returns: decrypted plaintext m
//
// Scratch RAM: overwrites addresses 1073741864..1073741867.
//
// NOTE(review): m itself is not range-checked. With delta = floor(p / t),
// m + t (and other values congruent to m mod t) can also pass the noise
// check, so callers needing a unique plaintext must bound m < t.
// NOTE(review): because only high words are compared, the check always
// fails when floor(delta / 2) < 2^32 (i.e. delta < 2^33).
// ---------------------------------------------------------------------------
pub fn decrypt(ct_addr: Field, s_addr: Field, delta: Field, n: Field) -> Field {
    // phase = b - <a, s>; the mask a occupies ct_addr[0..n-1], b is at ct_addr[n].
    let b: Field = mem.read(ct_addr + n)
    let dot: Field = inner_product(ct_addr, s_addr, n)
    let phase: Field = b + field.neg(dot)
    // Prover supplies candidate plaintext
    let m: Field = io.divine()
    // Verify: noise = phase - m * delta
    let noise: Field = phase + field.neg(m * delta)
    // Check |noise| < delta/2 against the integer floor(delta / 2).
    // delta * inv(2) would be (delta + p) / 2 for odd delta, which would
    // accept almost any noise and hence any m.
    let half_delta: Field = floor_half(delta)
    let neg_noise: Field = field.neg(noise)
    // Store values in RAM to keep stack clean for hi_lt calls.
    mem.write(1073741864, noise)
    mem.write(1073741865, half_delta)
    mem.write(1073741866, neg_noise)
    mem.write(1073741867, m)
    // Compare noise hi < half_delta hi (small positive noise).
    let small_pos: Bool = hi_lt(mem.read(1073741864), mem.read(1073741865))
    if small_pos {
        mem.read(1073741867)
    } else {
        // Compare neg_noise hi < half_delta hi (small negative noise).
        let small_neg: Bool = hi_lt(mem.read(1073741866), mem.read(1073741865))
        assert.is_true(small_neg)
        mem.read(1073741867)
    }
}
// ---------------------------------------------------------------------------
// Homomorphic addition: ct_out = ct1 + ct2.
// Adds all n+1 elements pointwise.
// ---------------------------------------------------------------------------
pub fn ct_add(ct1_addr: Field, ct2_addr: Field, out_addr: Field, n: Field) {
    // A ciphertext is n mask elements plus one body element.
    let len: Field = n + 1
    for k in 0..len bounded 4096 {
        let off: Field = convert.as_field(k)
        // Both inputs are read before the output slot is written, so
        // out_addr may equal ct1_addr or ct2_addr (in-place accumulation).
        let sum: Field = mem.read(ct1_addr + off) + mem.read(ct2_addr + off)
        mem.write(out_addr + off, sum)
    }
}
// ---------------------------------------------------------------------------
// Homomorphic scalar multiply: ct_out = k * ct.
// Multiplies all n+1 elements by scalar k.
// ---------------------------------------------------------------------------
pub fn scale(ct_addr: Field, k: Field, out_addr: Field, n: Field) {
    // Multiply every one of the n + 1 ciphertext elements by k (mod p).
    let len: Field = n + 1
    for j in 0..len bounded 4096 {
        let off: Field = convert.as_field(j)
        mem.write(out_addr + off, k * mem.read(ct_addr + off))
    }
}
// ---------------------------------------------------------------------------
// Encrypted dot product: out = sum_i(w[i] * ct[i]).
//
// Computes weighted sum of ciphertexts (plaintext weights, encrypted data).
// Each ct[i] is an LWE ciphertext of dimension lwe_n at stride lwe_n+1.
//
// cts_addr: base of n_ct contiguous ciphertexts
// w_addr: n_ct plaintext weights in RAM
// out_addr: output ciphertext (lwe_n+1 elements)
// tmp_addr: scratch (lwe_n+1 elements)
// lwe_n: LWE dimension
// n_ct: number of input ciphertexts
// ---------------------------------------------------------------------------
pub fn private_dot(
cts_addr: Field,
w_addr: Field,
out_addr: Field,
tmp_addr: Field,
lwe_n: Field,
n_ct: Field
) {
// Ciphertexts are stored back to back, lwe_n + 1 field elements each.
// NOTE(review): tmp_addr should not overlap out_addr or the input
// ciphertexts. scale() writes tmp before ct_add() reads out and tmp, so
// aliasing would corrupt the sum.
let stride: Field = lwe_n + 1
// Zero output ciphertext so the loop below can accumulate into it.
for j in 0..stride bounded 4096 {
let jj: Field = convert.as_field(j)
mem.write(out_addr + jj, 0)
}
// Store params in RAM scratch area (addresses 1073741832..1073741837,
// i.e. 2^30 + 8 .. 2^30 + 13).
// Use literal addresses so compiler emits push instead of dup.
// These do not overlap private_linear's scratch (1073741824..1073741830),
// so calling this from private_linear leaves the caller's saved params intact.
mem.write(1073741832, cts_addr)
mem.write(1073741833, w_addr)
mem.write(1073741834, out_addr)
mem.write(1073741835, tmp_addr)
mem.write(1073741836, lwe_n)
mem.write(1073741837, stride)
// Accumulate: out += w[i] * ct[i]
for i in 0..n_ct bounded 4096 {
let ii: Field = convert.as_field(i)
// Reload every parameter from scratch RAM each iteration instead of
// keeping them live on the stack across the scale/ct_add calls.
let r_cts: Field = mem.read(1073741832)
let r_w: Field = mem.read(1073741833)
let r_out: Field = mem.read(1073741834)
let r_tmp: Field = mem.read(1073741835)
let r_lwe: Field = mem.read(1073741836)
let r_stride: Field = mem.read(1073741837)
// ct[i] starts i strides past the base of the ciphertext array.
let ct_base: Field = r_cts + ii * r_stride
let w: Field = mem.read(r_w + ii)
// tmp = w[i] * ct[i], then out = out + tmp in place.
scale(ct_base, w, r_tmp, r_lwe)
ct_add(r_out, r_tmp, r_out, r_lwe)
}
}
// ---------------------------------------------------------------------------
// Private linear layer: W * x where x is encrypted, W is plaintext.
//
// Each row of W dot-producted with the encrypted input vector gives
// one output ciphertext. This is the core of private neural inference.
//
// cts_addr: input encrypted vector (input_dim ciphertexts)
// w_addr: weight matrix (neurons x input_dim, plaintext, row-major)
// out_addr: output encrypted vector (neurons ciphertexts)
// tmp_addr: scratch (lwe_n+1 elements)
// lwe_n: LWE dimension
// input_dim: number of input ciphertexts
// neurons: number of output neurons
// ---------------------------------------------------------------------------
pub fn private_linear(
cts_addr: Field,
w_addr: Field,
out_addr: Field,
tmp_addr: Field,
lwe_n: Field,
input_dim: Field,
neurons: Field
) {
// Store params in RAM scratch area (addresses 1073741824..1073741830).
// Use literal addresses everywhere so the compiler emits push instead of dup.
// private_dot uses its own scratch range (1073741832..1073741837), so the
// values saved here survive each call to it.
mem.write(1073741824, cts_addr)
mem.write(1073741825, w_addr)
mem.write(1073741826, out_addr)
mem.write(1073741827, tmp_addr)
mem.write(1073741828, lwe_n)
mem.write(1073741829, input_dim)
// Stride between consecutive output ciphertexts.
mem.write(1073741830, lwe_n + 1)
// One private_dot per output neuron (loop bounded at 4096 neurons).
for row in 0..neurons bounded 4096 {
let row_idx: Field = convert.as_field(row)
// Reload parameters from scratch RAM each iteration.
let r_cts: Field = mem.read(1073741824)
let r_w: Field = mem.read(1073741825)
let r_out: Field = mem.read(1073741826)
let r_tmp: Field = mem.read(1073741827)
let r_lwe: Field = mem.read(1073741828)
let r_dim: Field = mem.read(1073741829)
let r_stride: Field = mem.read(1073741830)
// Row-major weights: row `row` starts row * input_dim elements in.
let row_w_addr: Field = r_w + row_idx * r_dim
// Output ciphertext for this neuron sits row strides past out_addr.
let row_out_addr: Field = r_out + row_idx * r_stride
private_dot(r_cts, row_w_addr, row_out_addr, r_tmp, r_lwe, r_dim)
}
}
trident/std/fhe/lwe.tri
ฯ 0.0%