module std.fhe.rlwe
// Ring-LWE encryption over R_q = F_p[X]/(X^N + 1).
//
// Ciphertext (a(X), b(X)) where b(X) = a(X)*s(X) + m(X)*delta + e(X)
// modulo (X^N + 1). Ciphertext modulus q = p (Goldilocks), no impedance
// mismatch between FHE ring and STARK field.
//
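// Polynomial multiplication uses the NTT from std.private.poly. The Goldilocks
// field contains a primitive 2^32-th root of unity (p - 1 = 2^32 * (2^32 - 1)),
// so power-of-two NTTs up to that size are native and efficient.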
//
// Memory layout for one RLWE ciphertext of ring dimension N:
// ct_addr + 0 .. ct_addr + N-1 : polynomial a(X), N coefficients
// ct_addr + N .. ct_addr + 2N-1 : polynomial b(X), N coefficients
// Stride per ciphertext: 2N
//
// Pitch parameters: N = 64 (structurally identical to production N = 1024+).
use vm.core.convert
use vm.core.assert
use vm.core.field
use vm.io.io
use vm.io.mem
use std.private.poly
// ---------------------------------------------------------------------------
// encrypt: write the RLWE ciphertext ct = (a, b) to memory.
// ---------------------------------------------------------------------------
// Computes b(X) = a(X)*s(X) + delta*m(X) + e(X) and stores the pair so that
// a occupies ct_addr .. ct_addr + N-1 and b occupies ct_addr + N .. ct_addr + 2N-1.
// The reduction modulo (X^N + 1) is presumably performed inside
// poly.poly_mul (not visible from this file).
//
// Inputs (all polynomials are N consecutive field elements in memory):
//   a_addr    - random mask polynomial a, supplied by the prover
//   s_addr    - secret key polynomial s
//   m_addr    - message polynomial m
//   ct_addr   - destination of the ciphertext (2N elements: a, then b)
//   delta     - scaling factor applied to every message coefficient
//   e_addr    - noise polynomial e, supplied by the prover
//   n         - ring dimension N (loops are bounded at 4096 iterations)
//   omega, omega_inv, n_inv, log_n - NTT parameters forwarded to poly.poly_mul
//   tmp1_addr - scratch buffer handed to poly.poly_mul
//   tmp2_addr - not referenced by this function
pub fn encrypt(
    a_addr: Field,
    s_addr: Field,
    m_addr: Field,
    ct_addr: Field,
    delta: Field,
    e_addr: Field,
    n: Field,
    omega: Field,
    omega_inv: Field,
    n_inv: Field,
    log_n: Field,
    tmp1_addr: Field,
    tmp2_addr: Field
) {
    // The b half of the ciphertext starts right after the a half.
    let b_out: Field = ct_addr + n

    // Step 1: mirror a into the first half of the ciphertext.
    for k in 0..n bounded 4096 {
        let off: Field = convert.as_field(k)
        let a_coeff: Field = mem.read(a_addr + off)
        mem.write(ct_addr + off, a_coeff)
    }

    // Step 2: b <- a * s (polynomial product through the NTT helper).
    poly.poly_mul(a_addr, s_addr, b_out, tmp1_addr, omega, omega_inv, n_inv, log_n)

    // Step 3: b <- b + delta * m, one coefficient at a time.
    for k in 0..n bounded 4096 {
        let off: Field = convert.as_field(k)
        let scaled_msg: Field = mem.read(m_addr + off) * delta
        let acc: Field = mem.read(b_out + off)
        mem.write(b_out + off, acc + scaled_msg)
    }

    // Step 4: b <- b + e (in place on the b half).
    poly.add(b_out, e_addr, b_out, n)
}
// ---------------------------------------------------------------------------
// Decrypt: recover message polynomial from RLWE ciphertext.
// ---------------------------------------------------------------------------
// phase(X) = b(X) - a(X)*s(X) = m(X)*delta + e(X)
// For each coefficient the prover supplies a candidate m_i via io.divine()
// and the program checks that noise_i = phase_i - m_i*delta satisfies
// |noise_i| < delta/2, i.e. either noise_i or -noise_i, read as an integer in
// [0, p), is strictly below delta/2. The comparison is done on both 32-bit
// limbs returned by convert.split, so the enforced bound is exactly delta/2
// rather than delta/2 rounded down to a multiple of 2^32.
//
// ct_addr: ciphertext (2N coefficients: a then b)
// s_addr: secret key polynomial (N coefficients)
// out_addr: decrypted message polynomial (N coefficients)
// delta: scaling factor
// n, omega, omega_inv, n_inv, log_n: ring dimension and NTT parameters
// tmp1_addr: scratch (N elements); holds a*s, then the phase
// tmp2_addr: scratch handed to poly.poly_mul
//
// Preconditions / review notes:
// - delta is assumed to be even. delta/2 is computed as delta * inv(2) in the
//   field, which equals the integer half only when delta is even; for odd
//   delta the result is a huge field element and the bound becomes
//   meaningless. TODO confirm callers only pass even delta.
// - m_i is not range-checked against the plaintext space here; callers that
//   need a canonical plaintext must constrain the output themselves.
// - Execution aborts (assert) if any coefficient fails the noise bound.
pub fn decrypt(
    ct_addr: Field,
    s_addr: Field,
    out_addr: Field,
    delta: Field,
    n: Field,
    omega: Field,
    omega_inv: Field,
    n_inv: Field,
    log_n: Field,
    tmp1_addr: Field,
    tmp2_addr: Field
) {
    let b_addr: Field = ct_addr + n
    // tmp1 <- a * s
    poly.poly_mul(ct_addr, s_addr, tmp1_addr, tmp2_addr, omega, omega_inv, n_inv, log_n)
    // tmp1 <- b - a*s (the phase), computed in place
    poly.sub(b_addr, tmp1_addr, tmp1_addr, n)
    // Noise bound delta/2, split into (hi, lo) 32-bit limbs once, outside the loop.
    let half_delta: Field = delta * field.inv(2)
    let (half_hi, half_lo) = convert.split(half_delta)
    // Field copy of the high limb, used for limb-equality tests below.
    let half_hi_f: Field = convert.as_field(half_hi)
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        let phase_i: Field = mem.read(tmp1_addr + idx)
        // Prover supplies candidate plaintext coefficient
        let m_i: Field = io.divine()
        // noise = phase - m * delta, and its negation for the "negative" case
        let noise: Field = phase_i + field.neg(m_i * delta)
        let neg_noise: Field = field.neg(noise)
        let (noise_hi, noise_lo) = convert.split(noise)
        let (neg_hi, neg_lo) = convert.split(neg_noise)
        // 64-bit comparison noise < delta/2:
        // high limb smaller, or high limbs equal and low limb smaller.
        let mut small_pos: Bool = noise_hi < half_hi
        if convert.as_field(noise_hi) == half_hi_f {
            small_pos = noise_lo < half_lo
        }
        // Same comparison for -noise < delta/2.
        let mut small_neg: Bool = neg_hi < half_hi
        if convert.as_field(neg_hi) == half_hi_f {
            small_neg = neg_lo < half_lo
        }
        // Accept when either side is within the bound; otherwise abort.
        let mut within_bound: Bool = small_pos
        if small_neg {
            within_bound = true
        }
        assert.is_true(within_bound)
        mem.write(out_addr + idx, m_i)
    }
}
// ---------------------------------------------------------------------------
// add: homomorphic addition, ct_out = ct1 + ct2.
// ---------------------------------------------------------------------------
// Calls poly.add once on the a halves and once on the b halves of the two
// input ciphertexts; each b half starts n elements after its ciphertext base.
//
// ct1_addr, ct2_addr: input ciphertexts (2N elements each)
// out_addr: output ciphertext (2N elements)
// n: ring dimension N
pub fn add(
    ct1_addr: Field,
    ct2_addr: Field,
    out_addr: Field,
    n: Field
) {
    // a_out = a1 + a2
    poly.add(ct1_addr, ct2_addr, out_addr, n)
    // b_out = b1 + b2
    poly.add(ct1_addr + n, ct2_addr + n, out_addr + n, n)
}
// ---------------------------------------------------------------------------
// scale_ct: homomorphic scalar multiplication, ct_out = k * ct.
// ---------------------------------------------------------------------------
// Calls poly.scale with the scalar k on the a half and then on the b half of
// the ciphertext; the b half starts n elements after the ciphertext base.
//
// ct_addr: input ciphertext (2N elements)
// k: scalar multiplier
// out_addr: output ciphertext (2N elements)
// n: ring dimension N
pub fn scale_ct(
    ct_addr: Field,
    k: Field,
    out_addr: Field,
    n: Field
) {
    // a_out = k * a
    poly.scale(ct_addr, k, out_addr, n)
    // b_out = k * b
    poly.scale(ct_addr + n, k, out_addr + n, n)
}
// ---------------------------------------------------------------------------
// external_product: multiply an RLWE ciphertext by a plaintext polynomial.
// ---------------------------------------------------------------------------
// ct_out = poly * ct: both halves of the ciphertext are multiplied by the
// same polynomial through poly.poly_mul. Described in this module as the
// building block for blind rotation (CMux gate).
//
// ct_addr: input ciphertext (2N elements: a then b)
// poly_addr: plaintext polynomial multiplier (N coefficients)
// out_addr: output ciphertext (2N elements)
// n: ring dimension N
// omega, omega_inv, n_inv, log_n: NTT parameters forwarded to poly.poly_mul
// tmp1_addr: scratch handed to the multiplication of the a half
// tmp2_addr: scratch handed to the multiplication of the b half
pub fn external_product(
    ct_addr: Field,
    poly_addr: Field,
    out_addr: Field,
    n: Field,
    omega: Field,
    omega_inv: Field,
    n_inv: Field,
    log_n: Field,
    tmp1_addr: Field,
    tmp2_addr: Field
) {
    // Locate the b halves of the source and destination ciphertexts.
    let src_b: Field = ct_addr + n
    let dst_b: Field = out_addr + n
    // a_out = poly * a
    poly.poly_mul(ct_addr, poly_addr, out_addr, tmp1_addr, omega, omega_inv, n_inv, log_n)
    // b_out = poly * b
    poly.poly_mul(src_b, poly_addr, dst_b, tmp2_addr, omega, omega_inv, n_inv, log_n)
}
trident/std/fhe/rlwe.tri
ฯ 0.0%