// ---
// tags: jali, trident
// crystal-type: circuit
// crystal-domain: comp
// ---

//  Negacyclic NTT for R_q = F_p[x]/(x^1024 + 1).
//
//  The twisting trick converts negacyclic convolution (mod x^n+1) into
//  standard cyclic convolution (mod x^n-1):
//    forward: twist by psi^i, then standard NTT
//    inverse: standard INTT, then untwist by psi^{-i}, scale by 1/n
//
//  All arithmetic is native Field ops -- Field IS Goldilocks, so roots
//  of unity, twiddle factors, and butterflies are all native.
//
//  Fixed n = 1024, log2(n) = 10 butterfly stages.

module jali.ntt

/// Primitive 2048-th root of unity: psi = 7^((p-1)/2048) mod p.
/// psi^2048 = 1, psi^1024 = -1 (negacyclic property).
/// Successive powers psi^i are the twist factors applied by `forward`.
const PSI: Field = 455906449640507599

/// Inverse of PSI: psi^{-1} mod p.
/// Successive powers psi^{-i} are the untwist factors applied by `inverse`;
/// `inverse` also squares it to obtain omega^{-1}.
const PSI_INV: Field = 8548973421900915981

/// Primitive 1024-th root of unity: omega = psi^2.
/// Used as the base twiddle factor for standard NTT butterflies.
const OMEGA: Field = 11353340290879379826

/// Inverse of n = 1024 in the field: used to scale after inverse NTT.
/// For the Goldilocks prime p = 2^64 - 2^32 + 1 this equals p - (p-1)/1024.
const N_INV: Field = 18428729670909296641

/// Bit-reversal permutation of a length-1024 array (10-bit indices).
///
/// Returns `out` with out[rev(i)] = input[i], where rev(i) reverses the
/// ten low bits of i. The decimation-in-time butterfly network in
/// `cyclic_ntt` needs its input in this order to emit natural-order output.
fn bit_reverse_permute(input: [Field; 1024]) -> [Field; 1024] {
    // bit_weight[s] = 2^(9-s): where bit s of an index lands after reversal.
    let mut bit_weight: [U32; 10] = [0; 10]
    bit_weight[9] = 1
    let mut p: U32 = 1
    for s in 0..9 {
        p = p + p
        bit_weight[8 - s] = p
    }

    // Build the index table by doubling. For i < 2^s:
    //   rev(2^s + i) = rev(i) + 2^(9-s)
    // rev[0] = 0 is already in place from the zero initialiser.
    let mut rev: [U32; 1024] = [0; 1024]
    let mut span: U32 = 1
    for s in 0..10 {
        let weight: U32 = bit_weight[s]
        for i in 0..512 {
            if i < span {
                let src: U32 = i
                rev[span + src] = rev[src] + weight
            }
        }
        span = span + span
    }

    // Scatter the input into bit-reversed positions.
    let mut out: [Field; 1024] = [0; 1024]
    for i in 0..1024 {
        let dst: U32 = rev[i]
        out[dst] = input[i]
    }
    out
}

/// Standard (cyclic) length-1024 NTT with respect to `root`.
///
/// `root` must be a primitive 1024-th root of unity. Returns X in natural
/// order with X[k] = sum_i input[i] * root^(i*k). No 1/n scaling is applied,
/// so calling this with root^{-1} gives n times the inverse transform.
///
/// Implementation: bit-reversal permutation followed by 10 radix-2
/// decimation-in-time (Cooley-Tukey) butterfly stages. The permutation is
/// required: running the same butterflies on natural-order input yields the
/// transform of the bit-reversed input instead.
fn cyclic_ntt(input: [Field; 1024], root: Field) -> [Field; 1024] {
    let mut a: [Field; 1024] = bit_reverse_permute(input)

    // stage_root[s] = root^(2^(9-s)), a primitive 2^(s+1)-th root of unity:
    //   stage_root[0] = root^512 (first stage, half = 1)
    //   stage_root[9] = root     (last stage,  half = 512)
    // Filled from the top down by repeated squaring.
    let mut stage_root: [Field; 10] = [0; 10]
    stage_root[9] = root
    let mut sq: Field = root
    for s in 0..9 {
        sq = sq * sq
        stage_root[8 - s] = sq
    }

    // 10 butterfly stages; `half` is the distance between butterfly partners.
    let mut half: U32 = 1
    for stage in 0..10 {
        let w_step: Field = stage_root[stage]
        let group_size: U32 = half + half
        // w runs through w_step^j for the twiddle index j in 0..half.
        let mut w: Field = 1
        for j in 0..512 {
            if j < half {
                // Apply the butterfly with twiddle w at offset j of every
                // group. There are 1024 / group_size <= 512 groups; the
                // `k < 1024` guard stops once the last group is passed.
                let mut k: U32 = j
                for g in 0..512 {
                    if k < 1024 {
                        let lo: Field = a[k]
                        let hi: Field = a[k + half]
                        let t: Field = hi * w
                        a[k] = lo + t
                        a[k + half] = lo + field.neg(t)
                        k = k + group_size
                    }
                }
                w = w * w_step
            }
        }
        half = half + half
    }

    a
}

/// Forward negacyclic NTT: coefficient form -> NTT evaluation form.
///
/// Output, in natural order: X[k] = sum_i data[i] * psi^i * omega^(i*k),
/// i.e. the polynomial evaluated at psi^(2k+1), the roots of x^1024 + 1.
///
/// Steps:
///   1. Twist: multiply data[i] by psi^i (absorbs the x^n+1 reduction)
///   2. Standard cyclic NTT with base root OMEGA
pub fn forward(data: [Field; 1024]) -> [Field; 1024] {
    // Step 1: twist -- multiply each coefficient by psi^i
    let mut twisted: [Field; 1024] = [0; 1024]
    let mut psi_pow: Field = 1
    for i in 0..1024 {
        twisted[i] = data[i] * psi_pow
        psi_pow = psi_pow * PSI
    }

    // Step 2: standard NTT (bit-reversal + 10 butterfly stages)
    cyclic_ntt(twisted, OMEGA)
}

/// Inverse negacyclic NTT: NTT evaluation form -> coefficient form.
///
/// Exact inverse of `forward`: inverse(forward(x)) == x.
///
/// Steps:
///   1. Standard cyclic NTT with base root omega^{-1} (unscaled inverse)
///   2. Untwist and scale: multiply entry i by psi^{-i} / n
pub fn inverse(data: [Field; 1024]) -> [Field; 1024] {
    // omega = psi^2, so omega^{-1} = (psi^{-1})^2.
    let omega_inv: Field = PSI_INV * PSI_INV

    // Step 1: unscaled inverse transform
    let mut a: [Field; 1024] = cyclic_ntt(data, omega_inv)

    // Step 2: untwist and scale in one pass.
    // `factor` runs through N_INV * psi_inv^i, so the 1/n scaling costs no
    // extra multiplication per element.
    let mut factor: Field = N_INV
    for i in 0..1024 {
        a[i] = a[i] * factor
        factor = factor * PSI_INV
    }

    a
}

// Dimensions
//
// nebu/tri/ntt.tri
//
// Local Graph