// ---
// tags: jali, trident
// crystal-type: circuit
// crystal-domain: comp
// ---
// Negacyclic NTT for R_q = F_p[x]/(x^1024 + 1).
//
// The twisting trick converts negacyclic convolution (mod x^n+1) into
// standard cyclic convolution (mod x^n-1):
// forward: twist by psi^i, then standard NTT
// inverse: standard INTT, then untwist by psi^{-i}, scale by 1/n
//
// All arithmetic is native Field ops -- Field IS Goldilocks, so roots
// of unity, twiddle factors, and butterflies are all native.
//
// Fixed n = 1024, log2(n) = 10 butterfly stages.
module jali.ntt
/// 1/n for n = 1024: the single scale factor applied after the inverse
/// butterflies. For the Goldilocks prime p = 2^64 - 2^32 + 1 this equals
/// p - (p - 1)/1024, so 1024 * N_INV = 1 (mod p).
const N_INV: Field = 18428729670909296641
/// psi: the 2048-th root of unity that drives the negacyclic twist.
/// Documented as 7^((p-1)/2048) mod p with psi^2048 = 1 and psi^1024 = -1
/// (the property that turns mod x^n-1 into mod x^n+1).
/// NOTE(review): the order of psi is not re-derived in this file -- confirm.
const PSI: Field = 455906449640507599
/// psi^{-1} (PSI * PSI_INV = 1 mod p); powers of it undo the twist in `inverse`.
const PSI_INV: Field = 8548973421900915981
/// omega = psi^2 (mod p): the 1024-th root of unity whose powers are the
/// twiddle factors of the forward butterfly network.
const OMEGA: Field = 11353340290879379826
/// Bit-reversal permutation for 10-bit indices: rev[i] is i with its low
/// 10 bits reversed (e.g. rev[1] = 512, rev[2] = 256).
///
/// Built by size doubling, using only U32 additions: if r is the table for
/// m entries, the table for 2m entries is rev[i] = 2*r[i] and
/// rev[i + m] = 2*r[i] + 1 for i < m.
fn bit_reverse_table() -> [U32; 1024] {
    let mut rev: [U32; 1024] = [0; 1024]
    let mut size: U32 = 1
    for level in 0..10 {
        // Only the first `size` entries are valid at this level; size <= 512.
        for i in 0..512 {
            if i < size {
                let doubled: U32 = rev[i] + rev[i]
                rev[i] = doubled
                rev[i + size] = doubled + 1
            }
        }
        size = size + size
    }
    rev
}
/// Radix-2 Cooley-Tukey decimation-in-time butterfly network, n = 1024.
///
/// Precondition: `input` is in BIT-REVERSED order, i.e. input[j] = x[rev(j)].
/// Result: out[k] = sum_i x[i] * root^(i*k), in natural order.
/// `root` must be a primitive 1024-th root of unity (OMEGA or its inverse).
fn dit_butterflies(input: [Field; 1024], root: Field) -> [Field; 1024] {
    let mut a: [Field; 1024] = input
    // bases[s] = root^(2^(9-s)) = root^(512/half) for the stage with half = 2^s:
    //   bases[0] = root^512 (half = 1), ..., bases[9] = root (half = 512)
    let mut bases: [Field; 10] = [0; 10]
    bases[9] = root
    let mut pw: Field = root
    for s in 0..9 {
        pw = pw * pw
        bases[8 - s] = pw
    }
    let mut half: U32 = 1
    for stage in 0..10 {
        let w_step: Field = bases[stage]
        let group_size: U32 = half + half
        // w = w_step^j for twiddle index j in 0..half
        let mut w: Field = 1
        for j in 0..512 {
            if j < half {
                // Visit k = j, j + group_size, ... across every group. There are
                // 1024 / group_size <= 512 groups, so 512 iterations suffice.
                let mut k: U32 = j
                for g in 0..512 {
                    if k < 1024 {
                        let lo: Field = a[k]
                        let hi: Field = a[k + half]
                        let t: Field = hi * w
                        a[k] = lo + t
                        a[k + half] = lo + field.neg(t)
                        k = k + group_size
                    }
                }
                w = w * w_step
            }
        }
        half = half + half
    }
    a
}
/// Forward negacyclic NTT: coefficient form -> NTT evaluation form.
///
/// Steps:
/// 1. Twist: data[i] * psi^i (absorbs the x^n+1 reduction), written to the
///    bit-reversed slot rev(i) that the DIT network expects.
/// 2. Radix-2 DIT butterfly network with base root OMEGA, 10 stages.
///
/// Output (natural order): out[k] = sum_i data[i] * psi^i * omega^(i*k).
pub fn forward(data: [Field; 1024]) -> [Field; 1024] {
    let rev: [U32; 1024] = bit_reverse_table()
    let mut a: [Field; 1024] = [0; 1024]
    // Bit reversal is an involution, so a[rev(i)] = x[i] is the same as
    // a[j] = x[rev(j)]: the twisted input in bit-reversed order.
    let mut psi_pow: Field = 1
    for i in 0..1024 {
        let dst: U32 = rev[i]
        a[dst] = data[i] * psi_pow
        psi_pow = psi_pow * PSI
    }
    dit_butterflies(a, OMEGA)
}
/// Inverse negacyclic NTT: NTT evaluation form -> coefficient form.
///
/// Steps:
/// 1. Bit-reverse the input, then run the same DIT network with omega^{-1}.
/// 2. Untwist: multiply entry i by psi^{-i}.
/// 3. Scale by 1/n (folded into the untwist pass).
///
/// inverse(forward(x)) == x for every x.
pub fn inverse(data: [Field; 1024]) -> [Field; 1024] {
    let rev: [U32; 1024] = bit_reverse_table()
    let mut a: [Field; 1024] = [0; 1024]
    for i in 0..1024 {
        let dst: U32 = rev[i]
        a[dst] = data[i]
    }
    // omega^{-1} = (psi^2)^{-1} = psi_inv^2, since OMEGA = PSI * PSI.
    let omega_inv: Field = PSI_INV * PSI_INV
    let mut out: [Field; 1024] = dit_butterflies(a, omega_inv)
    // Untwist and scale in one pass: out[i] = out[i] * psi_inv^i * n_inv
    let mut psi_inv_pow: Field = 1
    for i in 0..1024 {
        out[i] = out[i] * psi_inv_pow * N_INV
        psi_inv_pow = psi_inv_pow * PSI_INV
    }
    out
}
// file: jali/tri/ntt.tri
// ฯ 0.0%