module std.private.poly

// Polynomial arithmetic over the Goldilocks field for FHE.
//
// All FHE schemes (TFHE, BGV, CKKS) operate on polynomial rings
// R_q = F_p[X]/(X^N + 1). The core operations (polynomial multiply,
// NTT, modular reduction) map directly to Goldilocks arithmetic.
//
// Goldilocks (p = 2^64 - 2^32 + 1) has a multiplicative subgroup of order
// 2^32, so primitive 2^k-th roots of unity exist for every k <= 32. That
// makes power-of-two NTTs native and efficient, which is exactly the
// property the field was chosen for.
//
// Polynomials stored in RAM as coefficient arrays: a[i] = coefficient of X^i.
use vm.core.field

use vm.core.convert

use vm.io.mem

// ---------------------------------------------------------------------------
// Horner evaluation: P(x) = a[n-1]*x^(n-1) + ... + a[1]*x + a[0]
//
// coeff_addr: RAM address of a[0]; a[i] is read from coeff_addr + i
// x: evaluation point
// n: number of coefficients (the polynomial has degree at most n - 1)
//
// Returns P(x). If the loop body never runs (n = 0) the result is 0.
// ---------------------------------------------------------------------------
pub fn eval(coeff_addr: Field, x: Field, n: Field) -> Field {
    // Horner accumulator: after k steps it holds the value of the k highest
    // coefficients evaluated at x.
    let mut acc: Field = 0
    // Starts one past the highest coefficient and is stepped down before each
    // read, so coefficients are consumed from a[n-1] down to a[0].
    let mut cursor: Field = coeff_addr + n
    for step in 0..n bounded 4096 {
        cursor = cursor + field.neg(1)
        acc = acc * x + mem.read(cursor)
    }
    acc
}

// ---------------------------------------------------------------------------
// Coefficient-wise addition: out[i] = a[i] + b[i] for i in 0..n
//
// a_addr, b_addr: RAM addresses of the two input coefficient arrays
// out_addr: RAM address the n sums are written to
// n: number of coefficients to process
//
// Each index is fully read before it is written, so out_addr may be the same
// address as a_addr or b_addr.
// ---------------------------------------------------------------------------
pub fn add(a_addr: Field, b_addr: Field, out_addr: Field, n: Field) {
    for i in 0..n bounded 4096 {
        let offset: Field = convert.as_field(i)
        let sum: Field = mem.read(a_addr + offset) + mem.read(b_addr + offset)
        mem.write(out_addr + offset, sum)
    }
}

// ---------------------------------------------------------------------------
// Coefficient-wise subtraction: out[i] = a[i] - b[i] for i in 0..n
//
// a_addr: RAM address of the minuend coefficients
// b_addr: RAM address of the subtrahend coefficients
// out_addr: RAM address the n differences are written to
// n: number of coefficients to process
//
// Subtraction is expressed as addition of the field negation. Each index is
// fully read before it is written, so out_addr may be the same address as
// a_addr or b_addr.
// ---------------------------------------------------------------------------
pub fn sub(a_addr: Field, b_addr: Field, out_addr: Field, n: Field) {
    for i in 0..n bounded 4096 {
        let offset: Field = convert.as_field(i)
        let minuend: Field = mem.read(a_addr + offset)
        let negated: Field = field.neg(mem.read(b_addr + offset))
        mem.write(out_addr + offset, minuend + negated)
    }
}

// ---------------------------------------------------------------------------
// Coefficient-wise (pointwise) product: out[i] = a[i] * b[i] for i in 0..n
//
// a_addr, b_addr: RAM addresses of the two input arrays
// out_addr: RAM address the n products are written to
// n: number of elements to process
//
// In the NTT domain this is the operation that polynomial multiplication
// reduces to. Each index is fully read before it is written, so out_addr may
// be the same address as a_addr or b_addr.
// ---------------------------------------------------------------------------
pub fn pointwise_mul(a_addr: Field, b_addr: Field, out_addr: Field, n: Field) {
    for i in 0..n bounded 4096 {
        let offset: Field = convert.as_field(i)
        let product: Field = mem.read(a_addr + offset) * mem.read(b_addr + offset)
        mem.write(out_addr + offset, product)
    }
}

// ---------------------------------------------------------------------------
// Scalar multiply: out[i] = a[i] * s for i in 0..n
//
// a_addr: RAM address of the input coefficients
// s: field element every coefficient is multiplied by
// out_addr: RAM address the n scaled coefficients are written to
// n: number of coefficients to process
//
// Each index is read before it is written, so out_addr may equal a_addr
// (in-place scaling).
// ---------------------------------------------------------------------------
pub fn scale(a_addr: Field, s: Field, out_addr: Field, n: Field) {
    for i in 0..n bounded 4096 {
        let offset: Field = convert.as_field(i)
        let scaled: Field = mem.read(a_addr + offset) * s
        mem.write(out_addr + offset, scaled)
    }
}

// ---------------------------------------------------------------------------
// Number Theoretic Transform (iterative, in-place, decimation in frequency).
//
// Transforms a polynomial from coefficient form to evaluation form.
// a_addr: N coefficients in RAM, natural order (N = 2^log_n)
// omega: primitive N-th root of unity in Goldilocks
// log_n: log2(N), the number of butterfly stages
//
// Output order: the result is left in BIT-REVERSED order. After the call,
// a_addr + k holds P(omega^rev(k)), where rev(k) reverses the log_n-bit
// binary representation of k. No separate permutation pass is run:
//   - pointwise operations (add, sub, pointwise_mul, scale) do not care
//     about the order as long as both operands use the same one;
//   - intt consumes exactly this order and returns natural-order
//     coefficients.
//
// Stage s (s = 0 .. log_n - 1) uses
//   half    = N / 2^(s+1)   distance between the two butterfly inputs
//   groups  = 2^s           number of independent blocks of size 2*half
//   w_step  = omega^(2^s)   primitive (2*half)-th root of unity
// and applies (lo, hi) -> (lo + hi, (lo - hi) * w_step^j) to the pair at
// offsets g*2*half + j and g*2*half + j + half.
//
// After NTT, polynomial multiply is pointwise multiply.
// ---------------------------------------------------------------------------
pub fn ntt(a_addr: Field, omega: Field, log_n: Field) {
    // groups doubles and the twiddle base squares from one stage to the next.
    let mut groups: Field = 1
    let mut w_step: Field = omega
    // Number of stages that still follow the current one (after the decrement
    // at the top of the loop body).
    let mut remaining: Field = log_n
    for stage in 0..log_n bounded 32 {
        remaining = remaining + field.neg(1)
        // half = 2^remaining. It shrinks every stage; the module has no
        // halving primitive, so it is rebuilt by doubling.
        let mut half: Field = 1
        for d in 0..remaining bounded 32 {
            half = half + half
        }
        let block: Field = half + half
        // w runs through w_step^0, w_step^1, ..., w_step^(half-1).
        let mut w: Field = 1
        for j in 0..half bounded 4096 {
            // Address of the low input of butterfly j in the first block.
            let mut lo_addr: Field = a_addr + convert.as_field(j)
            // Exactly `groups` blocks exist at this stage, so the walk never
            // leaves the N-element array.
            for g in 0..groups bounded 4096 {
                let hi_addr: Field = lo_addr + half
                let lo: Field = mem.read(lo_addr)
                let hi: Field = mem.read(hi_addr)
                let diff: Field = lo + field.neg(hi)
                mem.write(lo_addr, lo + hi)
                mem.write(hi_addr, diff * w)
                lo_addr = lo_addr + block
            }
            w = w * w_step
        }
        groups = groups + groups
        w_step = w_step * w_step
    }
}

// ---------------------------------------------------------------------------
// Inverse NTT: evaluation form -> coefficient form (in-place, decimation in
// time).
//
// a_addr: N evaluations in RAM, in the bit-reversed order produced by ntt
// omega_inv: inverse of the primitive N-th root of unity used by ntt
// n_inv: multiplicative inverse of N in the field
// log_n: log2(N)
//
// The butterflies undo the ntt stages in reverse order. Stage s uses
//   half    = 2^s
//   groups  = N / 2^(s+1)
//   w_step  = omega_inv^groups   primitive (2*half)-th root of unity
// and applies (lo, hi) -> (lo + hi*w, lo - hi*w) with w = w_step^j.
// The result is N times the original coefficients in natural order, so a
// final scale by n_inv restores them.
// ---------------------------------------------------------------------------
pub fn intt(a_addr: Field, omega_inv: Field, n_inv: Field, log_n: Field) {
    // half doubles every stage and equals N once all stages have run.
    let mut half: Field = 1
    // Number of stages that still follow the current one (after the decrement
    // at the top of the loop body).
    let mut remaining: Field = log_n
    for stage in 0..log_n bounded 32 {
        remaining = remaining + field.neg(1)
        // groups = 2^remaining and w_step = omega_inv^(2^remaining). Both
        // shrink from stage to stage, so they are rebuilt by repeated
        // doubling / squaring.
        let mut groups: Field = 1
        let mut w_step: Field = omega_inv
        for d in 0..remaining bounded 32 {
            groups = groups + groups
            w_step = w_step * w_step
        }
        let block: Field = half + half
        // w runs through w_step^0, w_step^1, ..., w_step^(half-1).
        let mut w: Field = 1
        for j in 0..half bounded 4096 {
            // Address of the low input of butterfly j in the first block.
            let mut lo_addr: Field = a_addr + convert.as_field(j)
            for g in 0..groups bounded 4096 {
                let hi_addr: Field = lo_addr + half
                let lo: Field = mem.read(lo_addr)
                let t: Field = mem.read(hi_addr) * w
                mem.write(lo_addr, lo + t)
                mem.write(hi_addr, lo + field.neg(t))
                lo_addr = lo_addr + block
            }
            w = w * w_step
        }
        half = half + half
    }
    // half == 2^log_n == N here; multiply every coefficient by 1/N.
    scale(a_addr, n_inv, a_addr, half)
}

// ---------------------------------------------------------------------------
// Polynomial multiply via NTT.
//
// a_addr, b_addr: input polynomials (N coefficients each, N = 2^log_n)
// out_addr: result polynomial (N coefficients)
// tmp_addr: scratch space (N elements), overwritten
// omega: primitive N-th root of unity
// omega_inv: inverse root
// n_inv: inverse of N
// log_n: log2(N)
//
// Steps: copy a -> out and b -> tmp, NTT(out), NTT(tmp), pointwise multiply
// into out, INTT(out). The transforms run in place, which is why the inputs
// are copied first.
//
// Aliasing: both inputs are read at an index before anything is written at
// that index, so out_addr may equal a_addr or b_addr, and tmp_addr may equal
// a_addr or b_addr. out_addr and tmp_addr must be distinct, non-overlapping
// regions.
//
// NOTE(review): with an N-th root of unity and a mutually inverse ntt/intt
// pair this yields the cyclic product (mod X^N - 1). The ring named in the
// module header, X^N + 1, additionally needs a twist by a 2N-th root of
// unity, which is not performed here - confirm what callers expect.
// ---------------------------------------------------------------------------
pub fn poly_mul(
    a_addr: Field,
    b_addr: Field,
    out_addr: Field,
    tmp_addr: Field,
    omega: Field,
    omega_inv: Field,
    n_inv: Field,
    log_n: Field
) {
    // size = N = 2^log_n, built by doubling.
    let mut size: Field = 1
    for s in 0..log_n bounded 32 {
        size = size + size
    }
    // Copy a to out and b to tmp. Both reads happen before either write so
    // that an output region sharing its address with an input cannot clobber
    // a coefficient that has not been copied yet.
    for i in 0..size bounded 4096 {
        let idx: Field = convert.as_field(i)
        let a_val: Field = mem.read(a_addr + idx)
        let b_val: Field = mem.read(b_addr + idx)
        mem.write(out_addr + idx, a_val)
        mem.write(tmp_addr + idx, b_val)
    }
    // Forward NTT on both operands
    ntt(out_addr, omega, log_n)
    ntt(tmp_addr, omega, log_n)
    // Pointwise multiply in the evaluation domain
    pointwise_mul(out_addr, tmp_addr, out_addr, size)
    // Back to coefficient form
    intt(out_addr, omega_inv, n_inv, log_n)
}

Local Graph