// ---
// tags: nebu, trident
// crystal-type: source
// crystal-domain: comp
// ---
// Number Theoretic Transform over the Goldilocks field.
//
// Forward NTT: Cooley-Tukey decimation-in-time.
// Inverse NTT: Gentleman-Sande decimation-in-frequency.
// Primitive root g = 7.
module nebu.ntt
use nebu.field
// -- Helpers -----------------------------------------------------------------
/// Mirror the low `k` bits of `val`: bit 0 moves to bit k-1, bit 1 to bit k-2,
/// and so on. Bits of `val` at position `k` and above do not reach the result.
/// The loop is fixed at 16 rounds, so `k` must be <= 16; rounds whose index is
/// not below `k` leave the state untouched.
#[pure]
fn bit_reverse(val: U32, k: U32) -> U32 {
    let mut reversed: U32 = 0;
    let mut remaining: U32 = val;
    for round in 0..16 bounded 16 {
        if round < k {
            // Peel off the lowest bit: .0 is remaining / 2, .1 is the bit itself.
            let halves: (U32, U32) = remaining /% 2;
            // Shift the accumulator left by one and append the peeled bit.
            reversed = (reversed * 2) + halves.1;
            remaining = halves.0;
        }
    }
    reversed
}
/// Principal N-th root of unity for N = 2^log_n, computed as 7^((p-1)/N)
/// with generator g = 7.
///
/// Since p - 1 = 0xFFFFFFFF00000000 = 2^32 * (2^32 - 1), the exponent factors as
///   (p-1)/N = (2^32 - 1) * 2^(32 - log_n)
/// so the result is 7^(2^32 - 1) squared (32 - log_n) times.
/// Requires log_n <= 32; nothing here guards the subtraction for larger values.
#[pure]
fn root_of_unity(log_n: U32) -> Field {
    // Odd factor of the exponent: 7^(2^32 - 1).
    let mut acc: Field = field.exp(7, 4294967295);
    // Remaining factor 2^(32 - log_n): each squaring doubles the exponent.
    let squarings: U32 = 32 - log_n;
    for step in 0..32 bounded 32 {
        if step < squarings {
            acc = acc * acc;
        }
    }
    acc
}
// -- Fixed-size NTT (N = 1024, log_n = 10) ----------------------------------
/// Forward NTT of a 1024-element array (Cooley-Tukey, decimation in time).
///
/// `data` is read in natural order and the transform is returned in natural
/// order: the input is first permuted into bit-reversed order, then ten
/// butterfly stages are applied. Stage s works on blocks of size m = 2^(s+1)
/// (half = 2^s) with twiddle step `root_of_unity(s + 1)`.
pub fn ntt_1024(data: [Field; 1024]) -> [Field; 1024] {
    let mut a: [Field; 1024] = data;
    // Bit-reversal permutation (10-bit). The `i < j` guard swaps each pair
    // exactly once and skips fixed points.
    for i in 0..1024 bounded 1024 {
        let j: U32 = bit_reverse(i, 10);
        if i < j {
            let tmp: Field = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }
    // Stage 0: m=2, half=1. Each block has a single butterfly whose twiddle
    // is 1, so no root of unity is needed here.
    for j in 0..512 bounded 512 {
        let base: U32 = j * 2;
        let t: Field = a[base + 1];
        a[base + 1] = sub(a[base], t);
        a[base] = a[base] + t;
    }
    // Stage 1: m=4, half=2. Twiddles per block are 1 and omega.
    let omega: Field = root_of_unity(2);
    for j in 0..256 bounded 256 {
        let base: U32 = j * 4;
        let mut w: Field = 1;
        for i in 0..2 bounded 2 {
            let t: Field = w * a[base + i + 2];
            a[base + i + 2] = sub(a[base + i], t);
            a[base + i] = a[base + i] + t;
            w = w * omega;
        }
    }
    // Stages 2..9: general butterfly. Both inner loops run a fixed 512
    // iterations; the `g < groups` and `i < half` guards select the
    // iterations that belong to the current stage.
    for s in 2..10 bounded 8 {
        let log_m: U32 = s + 1;
        let omega_m: Field = root_of_unity(log_m);
        let half: U32 = 1 << s;
        let m: U32 = 1 << log_m;
        let groups: U32 = 1024 / m;
        for g in 0..512 bounded 512 {
            if g < groups {
                let base: U32 = g * m;
                // w walks through omega_m^0, omega_m^1, ... within the block.
                let mut w: Field = 1;
                for i in 0..512 bounded 512 {
                    if i < half {
                        let t: Field = w * a[base + i + half];
                        a[base + i + half] = sub(a[base + i], t);
                        a[base + i] = a[base + i] + t;
                        w = w * omega_m;
                    }
                }
            }
        }
    }
    a
}
/// Inverse NTT of a 1024-element array (Gentleman-Sande, decimation in
/// frequency).
///
/// `data` is read in natural order and the result is returned in natural
/// order. The ten butterfly passes run from the widest block (1024) down to
/// the narrowest (2) using inverse twiddles, then the array is put through
/// the 10-bit bit-reversal permutation and every element is multiplied by
/// the inverse of 1024.
pub fn intt_1024(data: [Field; 1024]) -> [Field; 1024] {
    let mut buf: [Field; 1024] = data;
    // Passes are counted upward, but the stage index runs 9, 8, ..., 0.
    for pass in 0..10 bounded 10 {
        let stage: U32 = 9 - pass;
        let log_span: U32 = stage + 1;
        let stride: U32 = 1 << stage;
        let span: U32 = 1 << log_span;
        let blocks: U32 = 1024 / span;
        // Inverse of the span-th root of unity: root^(span - 1), since
        // root^span = 1.
        let root: Field = root_of_unity(log_span);
        let root_inv: Field = field.exp(root, (1 << log_span) - 1);
        // Fixed 512-iteration loops; the guards keep only the blocks and
        // offsets that exist at this stage.
        for blk in 0..512 bounded 512 {
            if blk < blocks {
                let start: U32 = blk * span;
                let mut twiddle: Field = 1;
                for k in 0..512 bounded 512 {
                    if k < stride {
                        let top: U32 = start + k;
                        let bot: U32 = top + stride;
                        let lo: Field = buf[top];
                        let hi: Field = buf[bot];
                        // DIF butterfly: sum on top, twiddled difference below.
                        buf[top] = lo + hi;
                        buf[bot] = twiddle * sub(lo, hi);
                        twiddle = twiddle * root_inv;
                    }
                }
            }
        }
    }
    // Undo the bit-reversed ordering left by the DIF passes. Each pair is
    // exchanged once, when the smaller index is visited.
    for idx in 0..1024 bounded 1024 {
        let mirror: U32 = bit_reverse(idx, 10);
        if idx < mirror {
            let first: Field = buf[idx];
            let second: Field = buf[mirror];
            buf[idx] = second;
            buf[mirror] = first;
        }
    }
    // Final normalisation by 1024^{-1} mod p.
    let scale: Field = inv(1024);
    for idx in 0..1024 bounded 1024 {
        buf[idx] = buf[idx] * scale;
    }
    buf
}
nebu/tri/ntt.tri
ฯ 0.0%