module std.fhe.rlwe

// Ring-LWE encryption over R_q = F_p[X]/(X^N + 1).
//
// Ciphertext (a(X), b(X)) where b(X) = a(X)*s(X) + m(X)*delta + e(X)
// modulo (X^N + 1). Ciphertext modulus q = p (Goldilocks), no impedance
// mismatch between FHE ring and STARK field.
//
// All ring operations use NTT from std.private.poly. The Goldilocks
// multiplicative group contains a subgroup of order 2^32 (p - 1 is divisible
// by 2^32), so power-of-two NTTs up to that size are native and efficient.
//
// Memory layout for one RLWE ciphertext of ring dimension N:
//   ct_addr + 0    .. ct_addr + N-1    : polynomial a(X), N coefficients
//   ct_addr + N    .. ct_addr + 2N-1   : polynomial b(X), N coefficients
//   Stride per ciphertext: 2N
//
// Pitch parameters: N = 64 (structurally identical to production N = 1024+).
use vm.core.convert

use vm.core.assert

use vm.core.field

use vm.io.io

use vm.io.mem

use std.private.poly

// ---------------------------------------------------------------------------
// Encrypt: produce RLWE ciphertext ct = (a, b).
// ---------------------------------------------------------------------------
// b(X) = a(X)*s(X) + m(X)*delta + e(X)  mod (X^N + 1)
//
// a_addr: prover-supplied random polynomial (N coefficients)
// s_addr: secret key polynomial (N coefficients)
// m_addr: message polynomial (N coefficients)
// ct_addr: output ciphertext (2N coefficients: a then b)
// delta: scaling factor (p / plaintext_space)
// e_addr: prover-supplied noise polynomial (N coefficients)
// n: ring dimension N
// omega, omega_inv, n_inv, log_n: NTT parameters
// tmp1_addr: scratch passed to the NTT multiply (N elements)
// tmp2_addr: not referenced by encrypt; kept in the signature next to tmp1_addr
pub fn encrypt(
    a_addr: Field,
    s_addr: Field,
    m_addr: Field,
    ct_addr: Field,
    delta: Field,
    e_addr: Field,
    n: Field,
    omega: Field,
    omega_inv: Field,
    n_inv: Field,
    log_n: Field,
    tmp1_addr: Field,
    tmp2_addr: Field
) {
    // Builds ct = (a, b) at ct_addr with b = a*s + m*delta + e.
    // tmp2_addr is accepted but never referenced in this body.
    //
    // Step 1: the first n words of the ciphertext are a verbatim copy of a(X).
    for k in 0..n bounded 4096 {
        let off: Field = convert.as_field(k)
        let a_coeff: Field = mem.read(a_addr + off)
        mem.write(ct_addr + off, a_coeff)
    }
    // Step 2: the second half (offset n) receives the product a*s;
    // tmp1_addr is handed to poly_mul as its scratch buffer.
    let b_half: Field = ct_addr + n
    poly.poly_mul(a_addr, s_addr, b_half, tmp1_addr, omega, omega_inv, n_inv, log_n)
    // Step 3: fold the scaled message into b, one coefficient at a time.
    for k in 0..n bounded 4096 {
        let off: Field = convert.as_field(k)
        let slot: Field = b_half + off
        let acc: Field = mem.read(slot)
        let scaled_msg: Field = mem.read(m_addr + off) * delta
        mem.write(slot, acc + scaled_msg)
    }
    // Step 4: add the noise polynomial in place.
    poly.add(b_half, e_addr, b_half, n)
}

// ---------------------------------------------------------------------------
// Decrypt: recover message polynomial from RLWE ciphertext.
// ---------------------------------------------------------------------------
// phase(X) = b(X) - a(X)*s(X)  = m(X)*delta + e(X)
// For each coefficient: m_i = round(phase_i / delta) via divine().
//
// ct_addr: ciphertext (2N coefficients)
// s_addr: secret key polynomial (N coefficients)
// out_addr: decrypted message polynomial (N coefficients)
// delta: scaling factor
// n, omega, omega_inv, n_inv, log_n: NTT parameters
// tmp1_addr, tmp2_addr: scratch (N elements each)
pub fn decrypt(
    ct_addr: Field,
    s_addr: Field,
    out_addr: Field,
    delta: Field,
    n: Field,
    omega: Field,
    omega_inv: Field,
    n_inv: Field,
    log_n: Field,
    tmp1_addr: Field,
    tmp2_addr: Field
) {
    // Recovers the message polynomial into out_addr[0..n).
    //
    // For every coefficient the prover divines a candidate m_i and this
    // function asserts that noise = phase_i - m_i*delta satisfies
    // |noise| < delta/2, i.e. either noise or -noise is below delta/2 when
    // read as a two-limb (hi, lo) integer.
    //
    // NOTE(review): half_delta is computed with a field inverse of 2, so it
    // equals the integer delta/2 only when delta is even — confirm callers
    // always pass an even delta.
    // NOTE(review): m_i itself is not range-checked against a plaintext
    // modulus here (none is passed in) — confirm whether callers rely on that.

    // b(X) is stored n words after a(X) inside the ciphertext.
    let b_addr: Field = ct_addr + n
    // tmp1 <- a * s   (tmp2 is handed to poly_mul as its scratch buffer)
    poly.poly_mul(ct_addr, s_addr, tmp1_addr, tmp2_addr, omega, omega_inv, n_inv, log_n)
    // tmp1 <- phase = b - a*s
    poly.sub(b_addr, tmp1_addr, tmp1_addr, n)
    // Noise bound delta/2, split into limbs for comparison.
    let half_delta: Field = delta * field.inv(2)
    let (half_hi, half_lo) = convert.split(half_delta)
    // Field copy of the high limb, hoisted so the equality tests below can use it.
    let half_hi_f: Field = convert.as_field(half_hi)
    for i in 0..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        let phase_i: Field = mem.read(tmp1_addr + idx)
        // Prover supplies candidate plaintext coefficient
        let m_i: Field = io.divine()
        // noise = phase - m * delta, together with its negation
        let noise: Field = phase_i + field.neg(m_i * delta)
        let neg_noise: Field = field.neg(noise)
        let (noise_hi, noise_lo) = convert.split(noise)
        let (neg_hi, neg_lo) = convert.split(neg_noise)
        // Lexicographic test noise < half_delta: the high limbs decide unless
        // they are equal, in which case the low limbs decide. Comparing only
        // the high limbs would reject every coefficient when delta/2 < 2^32.
        let mut in_bound: Bool = noise_hi < half_hi
        if convert.as_field(noise_hi) == half_hi_f {
            in_bound = noise_lo < half_lo
        }
        // Same test for -noise; either side being small is sufficient.
        if neg_hi < half_hi {
            in_bound = true
        }
        if convert.as_field(neg_hi) == half_hi_f {
            if neg_lo < half_lo {
                in_bound = true
            }
        }
        assert.is_true(in_bound)
        mem.write(out_addr + idx, m_i)
    }
}

// ---------------------------------------------------------------------------
// Homomorphic addition: ct_out = ct1 + ct2.
// ---------------------------------------------------------------------------
// Adds both polynomial components (a and b) coefficient-wise.
pub fn add(
    ct1_addr: Field,
    ct2_addr: Field,
    out_addr: Field,
    n: Field
) {
    // Each ciphertext stores a(X) in its first n words and b(X) in the next n,
    // so the b-halves sit at a fixed offset of n from each base address.
    let lhs_b: Field = ct1_addr + n
    let rhs_b: Field = ct2_addr + n
    let dst_b: Field = out_addr + n
    // a_out = a1 + a2
    poly.add(ct1_addr, ct2_addr, out_addr, n)
    // b_out = b1 + b2
    poly.add(lhs_b, rhs_b, dst_b, n)
}

// ---------------------------------------------------------------------------
// Homomorphic scalar multiply: ct_out = k * ct.
// ---------------------------------------------------------------------------
// Scales both polynomial components by a scalar.
pub fn scale_ct(
    ct_addr: Field,
    k: Field,
    out_addr: Field,
    n: Field
) {
    // First half: a_out = k * a
    poly.scale(ct_addr, k, out_addr, n)
    // Second half (offset n on both input and output): b_out = k * b
    poly.scale(ct_addr + n, k, out_addr + n, n)
}

// ---------------------------------------------------------------------------
// External product: multiply RLWE ciphertext by a plaintext polynomial.
// ---------------------------------------------------------------------------
// ct_out = poly * ct (multiply both components of ct by poly via NTT).
// This is the core building block for blind rotation (CMux gate).
pub fn external_product(
    ct_addr: Field,
    poly_addr: Field,
    out_addr: Field,
    n: Field,
    omega: Field,
    omega_inv: Field,
    n_inv: Field,
    log_n: Field,
    tmp1_addr: Field,
    tmp2_addr: Field
) {
    // Locate the b-halves of the input and output ciphertexts (offset n).
    let src_b: Field = ct_addr + n
    let dst_b: Field = out_addr + n
    // a-half: out.a = a * poly, with tmp1_addr passed as poly_mul's scratch.
    poly.poly_mul(ct_addr, poly_addr, out_addr, tmp1_addr, omega, omega_inv, n_inv, log_n)
    // b-half: out.b = b * poly, with tmp2_addr passed as poly_mul's scratch.
    poly.poly_mul(src_b, poly_addr, dst_b, tmp2_addr, omega, omega_inv, n_inv, log_n)
}

// Local Graph