// ---
// tags: nebu, trident
// crystal-type: source
// crystal-domain: comp
// ---
// Quartic extension field Fp4 = Fp[w] / (w^4 - 7).
//
// Elements are (c0, c1, c2, c3) representing c0 + c1*w + c2*w^2 + c3*w^3.
// Reduction rule: w^4 = 7.
//
// Tower decomposition: Fp4 = Fp2[v] / (v^2 - u) with Fp2 = Fp[u] / (u^2 - 7),
// identifying v = w and u = w^2.

module nebu.fp4

use nebu.fp2

// -- Type definition ---------------------------------------------------------

/// An element of the quartic extension, stored as four base-field
/// coefficients: c0 + c1*w + c2*w^2 + c3*w^3.
/// In the tower view used by this module, (c0, c2) is the Fp2 component A
/// and (c1, c3) is the Fp2 component B of A + B*v.
pub struct Fp4 {
    c0: Field,
    c1: Field,
    c2: Field,
    c3: Field,
}

// -- Constants ---------------------------------------------------------------

/// Additive identity: all four coefficients are zero.
pub const ZERO: Fp4 = Fp4 { c0: 0, c1: 0, c2: 0, c3: 0 };
/// Multiplicative identity: the base-field 1 in the c0 slot.
pub const ONE: Fp4 = Fp4 { c0: 1, c1: 0, c2: 0, c3: 0 };
/// The reduction constant: w^4 = 7. Used wherever a term wraps past w^3.
const SEVEN: Field = 7;

/// Frobenius constant: 7^((p-1)/4) = 2^48.
/// `frobenius` multiplies the c1 and c3 coefficients by this value.
const W_FROB: Field = 281474976710656;  // 0x0001000000000000

// -- Construction ------------------------------------------------------------

/// Build an Fp4 element from its four base-field coefficients,
/// i.e. c0 + c1*w + c2*w^2 + c3*w^3.
#[pure]
pub fn new(c0: Field, c1: Field, c2: Field, c3: Field) -> Fp4 {
    Fp4 {
        c0: c0,
        c1: c1,
        c2: c2,
        c3: c3,
    }
}

/// Lift a base-field scalar into Fp4 as the constant polynomial
/// a + 0*w + 0*w^2 + 0*w^3.
#[pure]
pub fn from_base(a: Field) -> Fp4 {
    Fp4 {
        c0: a,
        c1: 0,
        c2: 0,
        c3: 0,
    }
}

/// Lift an Fp2 element into Fp4 through the tower.
/// The real part lands on c0 and the imaginary part on c2 (the w^2 slot);
/// the odd coefficients c1 and c3 are zero.
#[pure]
pub fn from_fp2(x: fp2.Fp2) -> Fp4 {
    let even_lo: Field = x.re;
    let even_hi: Field = x.im;
    Fp4 { c0: even_lo, c1: 0, c2: even_hi, c3: 0 }
}

// -- Arithmetic --------------------------------------------------------------

/// Coefficient-wise addition of two Fp4 elements.
#[pure]
pub fn add(a: Fp4, b: Fp4) -> Fp4 {
    let s0: Field = a.c0 + b.c0;
    let s1: Field = a.c1 + b.c1;
    let s2: Field = a.c2 + b.c2;
    let s3: Field = a.c3 + b.c3;
    Fp4 { c0: s0, c1: s1, c2: s2, c3: s3 }
}

/// Coefficient-wise subtraction a - b, using the base-field `sub` on each
/// of the four coefficients.
#[pure]
pub fn fp4_sub(a: Fp4, b: Fp4) -> Fp4 {
    let d0: Field = sub(a.c0, b.c0);
    let d1: Field = sub(a.c1, b.c1);
    let d2: Field = sub(a.c2, b.c2);
    let d3: Field = sub(a.c3, b.c3);
    Fp4 { c0: d0, c1: d1, c2: d2, c3: d3 }
}

/// Additive inverse: applies the base-field `neg` to every coefficient.
#[pure]
pub fn fp4_neg(a: Fp4) -> Fp4 {
    let n0: Field = neg(a.c0);
    let n1: Field = neg(a.c1);
    let n2: Field = neg(a.c2);
    let n3: Field = neg(a.c3);
    Fp4 { c0: n0, c1: n1, c2: n2, c3: n3 }
}

/// Schoolbook product of two Fp4 elements, reduced with w^4 = 7.
///
/// The seven polynomial coefficients of a*b are split into the part that
/// stays below w^4 (`lo*`) and the part of degree 4..6 (`hi*`), which wraps
/// around to degree 0..2 with a factor of SEVEN.
/// Cost: 16 base muls + 3 mul-by-7 + 12 adds (9 in the partial sums,
/// 3 in the final fold).
#[pure]
pub fn mul(a: Fp4, b: Fp4) -> Fp4 {
    // Degrees 0..3: kept as-is.
    let lo0: Field = a.c0 * b.c0;
    let lo1: Field = a.c0 * b.c1 + a.c1 * b.c0;
    let lo2: Field = a.c0 * b.c2 + a.c1 * b.c1 + a.c2 * b.c0;
    let lo3: Field = a.c0 * b.c3 + a.c1 * b.c2 + a.c2 * b.c1 + a.c3 * b.c0;

    // Degrees 4..6: w^4 = 7 maps them onto degrees 0..2.
    let hi0: Field = a.c1 * b.c3 + a.c2 * b.c2 + a.c3 * b.c1;
    let hi1: Field = a.c2 * b.c3 + a.c3 * b.c2;
    let hi2: Field = a.c3 * b.c3;

    let r0: Field = lo0 + SEVEN * hi0;
    let r1: Field = lo1 + SEVEN * hi1;
    let r2: Field = lo2 + SEVEN * hi2;
    Fp4 { c0: r0, c1: r1, c2: r2, c3: lo3 }
}

/// Multiply every coefficient of an Fp4 element by the base-field scalar `s`.
#[pure]
pub fn scale(a: Fp4, s: Field) -> Fp4 {
    Fp4 {
        c0: a.c0 * s,
        c1: a.c1 * s,
        c2: a.c2 * s,
        c3: a.c3 * s,
    }
}

/// Squaring using w^4 = 7.
///
/// In a square every cross term a_i*a_j (i != j) occurs twice, so each cross
/// product is computed once and doubled by adding it to itself, rather than
/// being multiplied out twice.
/// Cost: 10 base muls (4 squares + 6 cross products) + 3 mul-by-7.
#[pure]
pub fn sqr(a: Fp4) -> Fp4 {
    // Cross products; each one contributes twice to the square.
    let x01: Field = a.c0 * a.c1;
    let x02: Field = a.c0 * a.c2;
    let x03: Field = a.c0 * a.c3;
    let x12: Field = a.c1 * a.c2;
    let x13: Field = a.c1 * a.c3;
    let x23: Field = a.c2 * a.c3;

    // Coefficients of w^0 .. w^6 before reduction.
    let s0: Field = a.c0 * a.c0;
    let s1: Field = x01 + x01;
    let s2: Field = a.c1 * a.c1 + x02 + x02;
    let half3: Field = x03 + x12;
    let s3: Field = half3 + half3;
    let s4: Field = a.c2 * a.c2 + x13 + x13;
    let s5: Field = x23 + x23;
    let s6: Field = a.c3 * a.c3;

    // Fold degrees 4..6 back onto 0..2 via w^4 = 7.
    Fp4 {
        c0: s0 + SEVEN * s4,
        c1: s1 + SEVEN * s5,
        c2: s2 + SEVEN * s6,
        c3: s3,
    }
}

/// Conjugation over the Fp2 subfield: A + B*v -> A - B*v.
/// The even coefficients c0 and c2 are kept; the odd coefficients c1 and c3
/// are negated.
#[pure]
pub fn conj(a: Fp4) -> Fp4 {
    let odd1: Field = neg(a.c1);
    let odd3: Field = neg(a.c3);
    Fp4 { c0: a.c0, c1: odd1, c2: a.c2, c3: odd3 }
}

/// Split an Fp4 element into its two Fp2 tower components (A, B),
/// where A is built from (c0, c2) and B from (c1, c3).
#[pure]
pub fn to_fp2_pair(a: Fp4) -> (fp2.Fp2, fp2.Fp2) {
    let even: fp2.Fp2 = fp2.new(a.c0, a.c2);
    let odd: fp2.Fp2 = fp2.new(a.c1, a.c3);
    (even, odd)
}

/// Relative norm from Fp4 down to Fp2: N = A^2 - u*B^2,
/// with A = (c0, c2) and B = (c1, c3) the Fp2 tower components.
#[pure]
pub fn norm_fp2(a: Fp4) -> fp2.Fp2 {
    let a_sq: fp2.Fp2 = fp2.sqr(fp2.new(a.c0, a.c2));
    let b_sq: fp2.Fp2 = fp2.sqr(fp2.new(a.c1, a.c3));

    // Multiplying (re, im) by u: u*(re + im*u) = 7*im + re*u,
    // so the parts swap and the new real part picks up the factor 7.
    let shifted_re: Field = SEVEN * b_sq.im;
    let shifted_im: Field = b_sq.re;

    fp2.fp2_sub(a_sq, fp2.new(shifted_re, shifted_im))
}

/// Norm from Fp4 down to the base field: first the Fp4 -> Fp2 norm,
/// then the Fp2 norm of that result.
#[pure]
pub fn norm(a: Fp4) -> Field {
    let n2: fp2.Fp2 = norm_fp2(a);
    fp2.norm(n2)
}

/// Multiplicative inverse via the tower norm:
/// 1 / (A + B*v) = (A - B*v) * N^-1, where N = norm_fp2(a) lies in Fp2.
///
/// NOTE(review): for a = ZERO the norm is zero, so the outcome is whatever
/// fp2.fp2_inv does on a zero input — not visible from this module; confirm.
#[pure]
pub fn fp4_inv(a: Fp4) -> Fp4 {
    let n_inv: fp2.Fp2 = fp2.fp2_inv(norm_fp2(a));

    // Even part: A * N^-1. Odd part: -(B * N^-1).
    let even: fp2.Fp2 = fp2.mul(fp2.new(a.c0, a.c2), n_inv);
    let odd: fp2.Fp2 = fp2.fp2_neg(fp2.mul(fp2.new(a.c1, a.c3), n_inv));

    // Interleave back into (c0, c1, c2, c3).
    Fp4 { c0: even.re, c1: odd.re, c2: even.im, c3: odd.im }
}

/// Frobenius endomorphism, determined by sigma(w) = W_FROB * w with
/// W_FROB = 2^48.
/// Coefficient-wise: c0 is fixed, c1 is scaled by W_FROB, c2 is negated,
/// and c3 is scaled by W_FROB and negated:
/// sigma(c0 + c1*w + c2*w^2 + c3*w^3) = c0 + 2^48*c1*w - c2*w^2 - 2^48*c3*w^3
#[pure]
pub fn frobenius(a: Fp4) -> Fp4 {
    let t1: Field = W_FROB * a.c1;
    let t3: Field = W_FROB * a.c3;
    Fp4 { c0: a.c0, c1: t1, c2: neg(a.c2), c3: neg(t3) }
}

/// Check equality: true exactly when all four coefficients of `a` and `b`
/// match.
/// Each comparison is parenthesized so the grouping does not depend on the
/// relative precedence of `==` and `&`.
#[pure]
pub fn eq(a: Fp4, b: Fp4) -> bool {
    (a.c0 == b.c0) & (a.c1 == b.c1) & (a.c2 == b.c2) & (a.c3 == b.c3)
}

// Local Graph