// ---
// tags: nebu, trident
// crystal-type: source
// crystal-domain: comp
// ---
// Number Theoretic Transform over the Goldilocks field.
//
// Forward NTT: Cooley-Tukey decimation-in-time.
// Inverse NTT: Gentleman-Sande decimation-in-frequency.
// Primitive root g = 7.

module nebu.ntt

use nebu.field

// -- Helpers -----------------------------------------------------------------

/// Mirror the lowest `k` bits of `val`: bit 0 moves to bit k-1, bit 1 to
/// bit k-2, and so on. Bits of `val` at position k and above do not appear
/// in the result. Requires k <= 16, since the loop is statically capped at
/// 16 iterations.
#[pure]
fn bit_reverse(val: U32, k: U32) -> U32 {
    let mut rest: U32 = val;
    let mut acc: U32 = 0;
    for step in 0..16 bounded 16 {
        if step < k {
            // Peel the lowest bit off `rest` (the remainder of /% 2) and
            // shift it into the bottom of `acc`.
            let qr: (U32, U32) = rest /% 2;
            acc = (acc * 2) + qr.1;
            rest = qr.0;
        }
    }
    acc
}

/// Principal N-th root of unity in the Goldilocks field, where N = 2^log_n.
///
/// Returns 7^((p-1)/N), with p - 1 = 0xFFFFFFFF00000000 = 2^32 * (2^32 - 1).
/// For log_n <= 32 the exponent factors as (2^32 - 1) * 2^(32 - log_n). So the
/// result is obtained by raising 7 to the odd part 2^32 - 1 once, then squaring
/// (32 - log_n) times. Requires log_n <= 32.
#[pure]
fn root_of_unity(log_n: U32) -> Field {
    // 4294967295 = 2^32 - 1, the odd factor of p - 1.
    let mut root: Field = field.exp(7, 4294967295);
    let squarings: U32 = 32 - log_n;
    // Each squaring doubles the exponent, supplying the 2^(32 - log_n) factor.
    for step in 0..32 bounded 32 {
        if step < squarings {
            root = root * root;
        }
    }
    root
}

// -- Fixed-size NTT (N = 1024, log_n = 10) ----------------------------------

/// Forward NTT of length 1024 (iterative Cooley-Tukey, decimation in time).
///
/// Computes out[k] = sum_j data[j] * w^(j*k), where w = root_of_unity(10) is the
/// principal 1024-th root of unity. Input and output are both in natural
/// order; the input is bit-reverse permuted internally before the butterflies.
pub fn ntt_1024(data: [Field; 1024]) -> [Field; 1024] {
    let mut a: [Field; 1024] = data;

    // Bit-reversal permutation (10-bit). Swapping only when i < j handles
    // each pair once, so no element is swapped back.
    for i in 0..1024 bounded 1024 {
        let j: U32 = bit_reverse(i, 10);
        if i < j {
            let tmp: Field = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }

    // 10 butterfly stages. Stage s works on blocks of m = 2^(s+1) elements.
    // It pairs index lo with lo + half (half = 2^s) and uses twiddles
    // omega_m^i for i in 0..half, where omega_m is the principal m-th root.
    // Stages 0 and 1 need no special-casing: with half = 1 the twiddle stays 1.
    for s in 0..10 bounded 10 {
        let log_m: U32 = s + 1;
        let omega_m: Field = root_of_unity(log_m);
        let half: U32 = 1 << s;
        let m: U32 = 1 << log_m;

        // Every stage performs exactly 512 butterflies. Walk them with one flat
        // index k = g * half + i (g = block, i = offset in the lower half).
        // This avoids two guarded 512-iteration loops, which would spend
        // 512 * 512 iterations per stage on mostly-false guards.
        let mut w: Field = 1;
        for k in 0..512 bounded 512 {
            let qr: (U32, U32) = k /% half;
            let g: U32 = qr.0;
            let i: U32 = qr.1;
            // Start of a new block: reset the twiddle to omega_m^0.
            if i == 0 {
                w = 1;
            }
            let lo: U32 = g * m + i;
            let hi: U32 = lo + half;
            let t: Field = w * a[hi];
            a[hi] = sub(a[lo], t);
            a[lo] = a[lo] + t;
            w = w * omega_m;
        }
    }

    a
}

/// Inverse NTT of length 1024 (Gentleman-Sande, decimation in frequency).
///
/// Inverts ntt_1024: out[j] = 1024^{-1} * sum_k data[k] * w^(-j*k), where
/// w = root_of_unity(10). Input and output are in natural order. The DIF
/// stages leave the data in bit-reversed order, so a bit-reversal permutation
/// follows them, and the result is finally scaled by 1024^{-1}.
pub fn intt_1024(data: [Field; 1024]) -> [Field; 1024] {
    let mut a: [Field; 1024] = data;

    // DIF stages run from the largest block (s = 9, m = 1024) down to m = 2.
    for s_rev in 0..10 bounded 10 {
        let s: U32 = 9 - s_rev;
        let log_m: U32 = s + 1;
        let half: U32 = 1 << s;
        let m: U32 = 1 << log_m;
        // Inverse twiddle base. omega_m has order m, so omega_m^(m-1) equals
        // inv(omega_m); one field inversion replaces the exponentiation.
        let omega_m_inv: Field = inv(root_of_unity(log_m));

        // Every stage performs exactly 512 butterflies. Walk them with one flat
        // index k = g * half + i (g = block, i = offset in the lower half).
        // This avoids two guarded 512-iteration loops (512 * 512 iterations
        // per stage).
        let mut w: Field = 1;
        for k in 0..512 bounded 512 {
            let qr: (U32, U32) = k /% half;
            let g: U32 = qr.0;
            let i: U32 = qr.1;
            // Start of a new block: reset the twiddle to omega_m^0.
            if i == 0 {
                w = 1;
            }
            let lo: U32 = g * m + i;
            let hi: U32 = lo + half;
            let u: Field = a[lo];
            let v: Field = a[hi];
            a[lo] = u + v;
            a[hi] = w * sub(u, v);
            w = w * omega_m_inv;
        }
    }

    // Bit-reversal permutation (10-bit) to restore natural order.
    for i in 0..1024 bounded 1024 {
        let j: U32 = bit_reverse(i, 10);
        if i < j {
            let tmp: Field = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }

    // Scale by N^{-1} mod p.
    let n_inv: Field = inv(1024);
    for i in 0..1024 bounded 1024 {
        a[i] = a[i] * n_inv;
    }

    a
}

Dimensions

jali/tri/ntt.tri

Local Graph