module std.fhe.lwe

// LWE encryption over the Goldilocks field (p = 2^64 - 2^32 + 1).
//
// Ciphertext (a, b) where b = <a, s> + m * Delta + e.
// Ciphertext modulus q = p — no impedance mismatch between the LWE
// modulus and the STARK field.
// Every LWE operation is native field arithmetic, natively provable
// inside a STARK trace.
//
// Memory layout for one ciphertext of LWE dimension n:
//   ct_addr + 0 .. ct_addr + n-1  : vector a  (n field elements)
//   ct_addr + n                   : scalar b  (1 field element)
//   Stride per ciphertext: n + 1
//
// The prover pre-generates (a, s, e) outside the trace and stores
// them in RAM before execution. The STARK proof covers correct
// computation given these values. Randomness quality is the prover's
// responsibility, not the circuit's.
use vm.core.field

use vm.core.convert

use vm.core.assert

use vm.io.io

use vm.io.mem

// ---------------------------------------------------------------------------
// Inner product <a, s> = sum over i in 0..n of a[i] * s[i].
//
// All arithmetic is field arithmetic (mod p). Both vectors live in RAM.
//
// a_addr: base address of the first n-element vector
// s_addr: base address of the second n-element vector
// n: vector length (the loop is declared with a 4096-iteration bound)
// Returns: the accumulated sum; 0 when n = 0
// ---------------------------------------------------------------------------
pub fn inner_product(a_addr: Field, s_addr: Field, n: Field) -> Field {
    let mut acc: Field = 0
    for k in 0..n bounded 4096 {
        let off: Field = convert.as_field(k)
        acc = acc + mem.read(a_addr + off) * mem.read(s_addr + off)
    }
    acc
}

// ---------------------------------------------------------------------------
// Encrypt: write an LWE ciphertext (a, b) to ct_addr.
//
// The a-part is copied verbatim from the prover-supplied random vector. The
// body b = <a, s> + m * delta + e is stored immediately after it.
//
// a_addr: prover-supplied random vector (n elements in RAM)
// s_addr: secret key (n elements in RAM)
// ct_addr: output ciphertext (n+1 elements)
// m: plaintext message
// delta: scaling factor (p / t for plaintext modulus t)
// e: prover-supplied noise (small field element)
// n: LWE dimension
//
// Writes: ct_addr[0..n-1] = a, ct_addr[n] = <a,s> + m*delta + e
// NOTE(review): decrypt halves delta with field.inv(2), which is only exact
// for even delta — prefer an even delta.
// ---------------------------------------------------------------------------
pub fn encrypt(
    a_addr: Field,
    s_addr: Field,
    ct_addr: Field,
    m: Field,
    delta: Field,
    e: Field,
    n: Field
) {
    // a-part: element-by-element copy of the random vector.
    for k in 0..n bounded 4096 {
        let off: Field = convert.as_field(k)
        let a_k: Field = mem.read(a_addr + off)
        mem.write(ct_addr + off, a_k)
    }
    // Body: mask <a, s>, then add the scaled message and the noise.
    let masked: Field = inner_product(a_addr, s_addr, n)
    mem.write(ct_addr + n, masked + m * delta + e)
}

// ---------------------------------------------------------------------------
// hi_lt: compare only the high 32-bit words of two field elements.
//
// Returns a_hi < b_hi, where (hi, lo) come from convert.split. The low words
// are ignored, so this is a coarse comparison: hi(a) < hi(b) implies a < b,
// but a < b does not imply hi(a) < hi(b) (e.g. when both share a high word).
// NOTE(review): assumes convert.split returns (high, low) 32-bit limbs in
// that order, as the binding names suggest — confirm against vm.core.convert.
//
// Kept as a separate function to keep the stack shallow during split.
// ---------------------------------------------------------------------------
fn hi_lt(a: Field, b: Field) -> Bool {
    // a_lo / b_lo are intentionally unused: only the high words are compared.
    let (a_hi, a_lo) = convert.split(a)
    let (b_hi, b_lo) = convert.split(b)
    a_hi < b_hi
}

// ---------------------------------------------------------------------------
// Decrypt: recover plaintext from ciphertext.
//
// phase = b - <a, s> = m * delta + e
//
// Rounding in a finite field: the prover supplies a candidate m via
// io.divine(). The circuit then checks that noise = phase - m * delta is
// small in one direction or the other.
//
// What is actually checked (high words only, via hi_lt):
//   hi(noise)  < hi(half_delta)   -- small positive noise, or
//   hi(-noise) < hi(half_delta)   -- small negative noise,
// where half_delta = delta * field.inv(2). This approximates
// |noise| < delta/2 and is stricter than the exact bound.
//
// NOTE(review): points to confirm before relying on this for soundness:
//   - Assuming field.inv(2) is the field inverse (p+1)/2, half_delta equals
//     delta/2 only for even delta. For odd delta it is (delta + p)/2
//     (about 2^63), which makes the bound almost vacuous.
//   - If half_delta < 2^32, its high word is 0. Assuming unsigned high words,
//     neither comparison can then succeed and the assertion always fails.
//   - The divined m is not range-checked against the plaintext space; only
//     the noise is constrained.
//
// ct_addr: ciphertext (n+1 elements)
// s_addr: secret key (n elements)
// delta: scaling factor
// n: LWE dimension
// Returns: decrypted plaintext m (the divined value) once a noise check holds
// Fails: via assert.is_true when neither noise check holds
// Side effect: overwrites RAM scratch cells 1073741864..1073741867
// ---------------------------------------------------------------------------
pub fn decrypt(ct_addr: Field, s_addr: Field, delta: Field, n: Field) -> Field {
    // phase = b - <a, s>
    // b sits at offset n. inner_product reads the first n cells (the a-part).
    let b: Field = mem.read(ct_addr + n)
    let dot: Field = inner_product(ct_addr, s_addr, n)
    let phase: Field = b + field.neg(dot)
    // Prover supplies candidate plaintext
    let m: Field = io.divine()
    // Verify: noise = phase - m * delta
    let noise: Field = phase + field.neg(m * delta)
    // Check |noise| < delta/2 (approximated on high words; see header).
    // NOTE(review): exact delta/2 only when delta is even.
    let half_delta: Field = delta * field.inv(2)
    let neg_noise: Field = field.neg(noise)
    // Store values in RAM to keep stack clean for hi_lt calls.
    // Cells 2^30 + 40..43 are disjoint from private_linear's (2^30 + 0..6)
    // and private_dot's (2^30 + 8..13) scratch cells.
    mem.write(1073741864, noise)
    mem.write(1073741865, half_delta)
    mem.write(1073741866, neg_noise)
    mem.write(1073741867, m)
    // Compare noise hi < half_delta hi (small positive noise).
    let small_pos: Bool = hi_lt(mem.read(1073741864), mem.read(1073741865))
    if small_pos {
        // Positive-noise check passed: return the candidate m.
        mem.read(1073741867)
    } else {
        // Compare neg_noise hi < half_delta hi (small negative noise).
        let small_neg: Bool = hi_lt(mem.read(1073741866), mem.read(1073741865))
        assert.is_true(small_neg)
        mem.read(1073741867)
    }
}

// ---------------------------------------------------------------------------
// Homomorphic addition: ct_out = ct1 + ct2.
//
// Adds all n+1 elements pointwise with field addition. The loop covers the
// n-element a-parts; the b components at offset n are added afterwards.
// Keeping the loop at n iterations lets n = 4096 fit the 4096 bound, which is
// the same maximum that inner_product, encrypt and decrypt accept.
//
// Elements are processed in index order, each read before it is written, so
// out_addr may equal ct1_addr or ct2_addr. private_dot calls this with
// out == ct1.
//
// ct1_addr, ct2_addr: input ciphertexts (n+1 elements each)
// out_addr: output ciphertext (n+1 elements)
// n: LWE dimension
// ---------------------------------------------------------------------------
pub fn ct_add(ct1_addr: Field, ct2_addr: Field, out_addr: Field, n: Field) {
    // a-part: n elements.
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        let v1: Field = mem.read(ct1_addr + idx)
        let v2: Field = mem.read(ct2_addr + idx)
        mem.write(out_addr + idx, v1 + v2)
    }
    // b component, stored right after the a-part.
    let b1: Field = mem.read(ct1_addr + n)
    let b2: Field = mem.read(ct2_addr + n)
    mem.write(out_addr + n, b1 + b2)
}

// ---------------------------------------------------------------------------
// Homomorphic scalar multiply: ct_out = k * ct.
//
// Multiplies all n+1 elements by the plaintext scalar k using field
// multiplication. The loop covers the n-element a-part; the b component at
// offset n is scaled afterwards. Keeping the loop at n iterations lets
// n = 4096 fit the 4096 bound, which is the same maximum that inner_product,
// encrypt and decrypt accept.
//
// ct_addr: input ciphertext (n+1 elements)
// k: scalar multiplier
// out_addr: output ciphertext (n+1 elements). It may equal ct_addr, since
//           each element is read before it is written, in index order.
// n: LWE dimension
// ---------------------------------------------------------------------------
pub fn scale(ct_addr: Field, k: Field, out_addr: Field, n: Field) {
    // a-part: n elements.
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        let val: Field = mem.read(ct_addr + idx)
        mem.write(out_addr + idx, val * k)
    }
    // b component, stored right after the a-part.
    let b: Field = mem.read(ct_addr + n)
    mem.write(out_addr + n, b * k)
}

// ---------------------------------------------------------------------------
// Encrypted dot product: out = sum_i(w[i] * ct[i]).
//
// Computes a weighted sum of ciphertexts (plaintext weights, encrypted data).
// Each ct[i] is an LWE ciphertext of dimension lwe_n at stride lwe_n+1.
//
// cts_addr: base of n_ct contiguous ciphertexts
// w_addr: n_ct plaintext weights in RAM
// out_addr: output ciphertext (lwe_n+1 elements). Zeroed first, then
//           accumulated into.
// tmp_addr: scratch (lwe_n+1 elements), overwritten on every iteration
// lwe_n: LWE dimension
// n_ct: number of input ciphertexts
//
// Side effects: overwrites out_addr and tmp_addr as above, plus RAM scratch
// cells 1073741832..1073741837.
// ---------------------------------------------------------------------------
pub fn private_dot(
    cts_addr: Field,
    w_addr: Field,
    out_addr: Field,
    tmp_addr: Field,
    lwe_n: Field,
    n_ct: Field
) {
    let stride: Field = lwe_n + 1
    // Zero the output ciphertext. The loop covers the lwe_n-element a-part and
    // the b slot is zeroed separately, so this loop needs only lwe_n
    // iterations and lwe_n = 4096 fits its bound.
    for j in 0..lwe_n bounded 4096 {
        let jj: Field = convert.as_field(j)
        mem.write(out_addr + jj, 0)
    }
    mem.write(out_addr + lwe_n, 0)
    // Store params in RAM scratch area (addresses 1073741832..1073741837).
    // Use literal addresses so compiler emits push instead of dup.
    mem.write(1073741832, cts_addr)
    mem.write(1073741833, w_addr)
    mem.write(1073741834, out_addr)
    mem.write(1073741835, tmp_addr)
    mem.write(1073741836, lwe_n)
    mem.write(1073741837, stride)
    // Accumulate: out += w[i] * ct[i]
    for i in 0..n_ct bounded 4096 {
        let ii: Field = convert.as_field(i)
        // Reload every parameter from scratch on each iteration instead of
        // keeping them live across the calls below.
        let r_cts: Field = mem.read(1073741832)
        let r_w: Field = mem.read(1073741833)
        let r_out: Field = mem.read(1073741834)
        let r_tmp: Field = mem.read(1073741835)
        let r_lwe: Field = mem.read(1073741836)
        let r_stride: Field = mem.read(1073741837)
        let ct_base: Field = r_cts + ii * r_stride
        let w: Field = mem.read(r_w + ii)
        // tmp = w * ct[i], then out = out + tmp (in place).
        scale(ct_base, w, r_tmp, r_lwe)
        ct_add(r_out, r_tmp, r_out, r_lwe)
    }
}

// ---------------------------------------------------------------------------
// Private linear layer: W * x where x is encrypted, W is plaintext.
//
// Each row of W, dot-producted with the encrypted input vector, gives one
// output ciphertext. This is the core of private neural inference.
//
// Layout, as fixed by the address arithmetic below:
//   weights of row r : starting at w_addr + r * input_dim (input_dim elements)
//   output of row r  : one ciphertext at out_addr + r * (lwe_n + 1)
//
// cts_addr: input encrypted vector (input_dim ciphertexts)
// w_addr: weight matrix (neurons x input_dim, plaintext, row-major)
// out_addr: output encrypted vector (neurons ciphertexts)
// tmp_addr: scratch (lwe_n+1 elements), passed to private_dot for every row
// lwe_n: LWE dimension
// input_dim: number of input ciphertexts
// neurons: number of output neurons
//
// Side effects: overwrites RAM scratch cells 1073741824..1073741830 (2^30 +
// 0..6). The private_dot calls use cells 1073741832..1073741837, which do
// not overlap this range.
// ---------------------------------------------------------------------------
pub fn private_linear(
    cts_addr: Field,
    w_addr: Field,
    out_addr: Field,
    tmp_addr: Field,
    lwe_n: Field,
    input_dim: Field,
    neurons: Field
) {
    // Store params in RAM scratch area (addresses 1073741824..1073741830).
    // Use literal addresses everywhere so the compiler emits push instead of dup.
    mem.write(1073741824, cts_addr)
    mem.write(1073741825, w_addr)
    mem.write(1073741826, out_addr)
    mem.write(1073741827, tmp_addr)
    mem.write(1073741828, lwe_n)
    mem.write(1073741829, input_dim)
    // Ciphertext stride (lwe_n + 1), used to step through the output rows.
    mem.write(1073741830, lwe_n + 1)
    for row in 0..neurons bounded 4096 {
        let row_idx: Field = convert.as_field(row)
        // Reload every parameter from scratch on each row.
        let r_cts: Field = mem.read(1073741824)
        let r_w: Field = mem.read(1073741825)
        let r_out: Field = mem.read(1073741826)
        let r_tmp: Field = mem.read(1073741827)
        let r_lwe: Field = mem.read(1073741828)
        let r_dim: Field = mem.read(1073741829)
        let r_stride: Field = mem.read(1073741830)
        let row_w_addr: Field = r_w + row_idx * r_dim
        let row_out_addr: Field = r_out + row_idx * r_stride
        // out[row] = sum_j W[row][j] * ct[j]
        private_dot(r_cts, row_w_addr, row_out_addr, r_tmp, r_lwe, r_dim)
    }
}
