// ---
// tags: nebu, trident
// crystal-type: source
// crystal-domain: comp
// ---
// Cubic extension field Fp3 = Fp[t] / (t^3 - t - 1).
//
// Elements are (c0, c1, c2) representing c0 + c1*t + c2*t^2.
// Reduction rule: t^3 = t + 1.

module nebu.fp3

// -- Type definition ---------------------------------------------------------

/// An element c0 + c1*t + c2*t^2 of Fp3 = Fp[t] / (t^3 - t - 1),
/// stored as its three base-field coefficients.
pub struct Fp3 {
    c0: Field, // constant coefficient
    c1: Field, // coefficient of t
    c2: Field, // coefficient of t^2
}

// -- Constants ---------------------------------------------------------------

/// Additive identity: 0 + 0*t + 0*t^2.
pub const ZERO: Fp3 = Fp3 {
    c0: 0,
    c1: 0,
    c2: 0,
};

/// Multiplicative identity: 1 + 0*t + 0*t^2.
pub const ONE: Fp3 = Fp3 {
    c0: 1,
    c1: 0,
    c2: 0,
};

// -- Construction ------------------------------------------------------------

/// Build the element c0 + c1*t + c2*t^2 from its three coefficients.
#[pure]
pub fn new(c0: Field, c1: Field, c2: Field) -> Fp3 {
    let elem: Fp3 = Fp3 {
        c0: c0,
        c1: c1,
        c2: c2,
    };
    elem
}

/// Lift a base-field element into Fp3 as the constant polynomial a + 0*t + 0*t^2.
#[pure]
pub fn from_base(a: Field) -> Fp3 {
    new(a, 0, 0)
}

// -- Arithmetic --------------------------------------------------------------

/// Sum of two Fp3 elements, computed coefficient by coefficient.
#[pure]
pub fn add(a: Fp3, b: Fp3) -> Fp3 {
    let x0: Field = a.c0 + b.c0;
    let x1: Field = a.c1 + b.c1;
    let x2: Field = a.c2 + b.c2;
    Fp3 { c0: x0, c1: x1, c2: x2 }
}

/// Difference a - b, computed coefficient by coefficient with the base-field `sub`.
/// (The `fp3_` prefix presumably avoids clashing with that `sub` builtin.)
#[pure]
pub fn fp3_sub(a: Fp3, b: Fp3) -> Fp3 {
    let x0: Field = sub(a.c0, b.c0);
    let x1: Field = sub(a.c1, b.c1);
    let x2: Field = sub(a.c2, b.c2);
    Fp3 { c0: x0, c1: x1, c2: x2 }
}

/// Additive inverse -a, negating each coefficient with the base-field `neg`.
#[pure]
pub fn fp3_neg(a: Fp3) -> Fp3 {
    let x0: Field = neg(a.c0);
    let x1: Field = neg(a.c1);
    let x2: Field = neg(a.c2);
    Fp3 {
        c0: x0,
        c1: x1,
        c2: x2,
    }
}

/// Schoolbook multiplication with t^3 = t + 1 reduction.
/// 9 base muls + 8 adds (4 forming d1..d3, 4 in the reduction).
///
/// Product coefficients before reduction:
///   d0 = a0*b0
///   d1 = a0*b1 + a1*b0
///   d2 = a0*b2 + a1*b1 + a2*b0
///   d3 = a1*b2 + a2*b1
///   d4 = a2*b2
///
/// Reduction (t^3 = t+1, t^4 = t*t^3 = t^2+t):
///   c0 = d0 + d3
///   c1 = d1 + d3 + d4
///   c2 = d2 + d4
#[pure]
pub fn mul(a: Fp3, b: Fp3) -> Fp3 {
    let d0: Field = a.c0 * b.c0;
    let d1: Field = a.c0 * b.c1 + a.c1 * b.c0;
    let d2: Field = a.c0 * b.c2 + a.c1 * b.c1 + a.c2 * b.c0;
    let d3: Field = a.c1 * b.c2 + a.c2 * b.c1;
    let d4: Field = a.c2 * b.c2;

    Fp3 {
        c0: d0 + d3,
        c1: d1 + d3 + d4,
        c2: d2 + d4,
    }
}

/// Multiply every coefficient of `a` by the base-field scalar `s` (3 base muls).
#[pure]
pub fn scale(a: Fp3, s: Field) -> Fp3 {
    let x0: Field = a.c0 * s;
    let x1: Field = a.c1 * s;
    let x2: Field = a.c2 * s;
    Fp3 { c0: x0, c1: x1, c2: x2 }
}

/// Squaring using t^3 = t + 1.
///
/// Same result as `mul(a, a)`, but each cross product a_i*a_j (i != j) is
/// computed once and doubled by addition: 6 base muls instead of 9.
///
/// Unreduced coefficients:
///   s0 = a0^2, s1 = 2*a0*a1, s2 = a1^2 + 2*a0*a2, s3 = 2*a1*a2, s4 = a2^2
/// Reduction (t^3 = t+1, t^4 = t^2+t):
///   c0 = s0 + s3, c1 = s1 + s3 + s4, c2 = s2 + s4
#[pure]
pub fn sqr(a: Fp3) -> Fp3 {
    let a0a1: Field = a.c0 * a.c1;
    let a0a2: Field = a.c0 * a.c2;
    let a1a2: Field = a.c1 * a.c2;

    let s0: Field = a.c0 * a.c0;
    let s1: Field = a0a1 + a0a1;
    let s2: Field = a.c1 * a.c1 + a0a2 + a0a2;
    let s3: Field = a1a2 + a1a2;
    let s4: Field = a.c2 * a.c2;

    // Reduce: t^3 = t+1, t^4 = t^2+t
    Fp3 {
        c0: s0 + s3,
        c1: s1 + s3 + s4,
        c2: s2 + s4,
    }
}

/// Norm: Fp3 -> Fp.
///
/// Equals det(M), where M is the matrix of multiplication by `a` in the
/// basis (1, t, t^2) under t^3 = t + 1. Expanded:
///
/// norm(a) = c0^3 + c1^3 + c2^3 - 3*c0*c1*c2
///         + 2*c0^2*c2 + c0*c2^2 - c1*c2^2 - c0*c1^2
///
/// Positive and negative terms are accumulated separately and combined with a
/// single base-field `sub`, the same builtin `fp3_sub` relies on, instead of a
/// `-` operator.
#[pure]
pub fn norm(a: Fp3) -> Field {
    let c0_2: Field = a.c0 * a.c0;
    let c1_2: Field = a.c1 * a.c1;
    let c2_2: Field = a.c2 * a.c2;
    let c0_3: Field = c0_2 * a.c0;
    let c1_3: Field = c1_2 * a.c1;
    let c2_3: Field = c2_2 * a.c2;
    // Computed once and doubled by addition for the 2*c0^2*c2 term.
    let c0_2c2: Field = c0_2 * a.c2;
    let c0c1c2: Field = a.c0 * a.c1 * a.c2;
    let three: Field = 3;

    let positive: Field = c0_3 + c1_3 + c2_3 + c0_2c2 + c0_2c2 + a.c0 * c2_2;
    let negative: Field = three * c0c1c2 + a.c1 * c2_2 + a.c0 * c1_2;
    sub(positive, negative)
}

/// Inversion via the adjugate of the multiplication matrix.
///
/// With t^3 = t + 1, multiplication by `a` in the basis (1, t, t^2) is
///   M = [ c0   c2      c1    ]
///       [ c1   c0+c2   c1+c2 ]
///       [ c2   c1      c0+c2 ]
/// and inv(a) = M^{-1} * e1 = adj(M)[first column] / det(M).
///
/// The first column of adj(M) is the cofactor vector (m0, m1, m2) of M's first
/// row (c0, c2, c1). Laplace expansion along that row gives
/// det(M) = c0*m0 + c2*m1 + c1*m2, which equals `norm(a)`. It is used here so
/// the cofactors are reused instead of re-evaluating the full norm polynomial.
///
/// The input must be nonzero: for a == 0 the determinant is 0 and the
/// base-field `inv` is applied to 0. Presumably that aborts execution —
/// confirm against the `inv` builtin.
#[pure]
pub fn fp3_inv(a: Fp3) -> Fp3 {
    let c0_2: Field = a.c0 * a.c0;
    let c1_2: Field = a.c1 * a.c1;
    let c2_2: Field = a.c2 * a.c2;
    let c0c1: Field = a.c0 * a.c1;
    let c0c2: Field = a.c0 * a.c2;
    let c1c2: Field = a.c1 * a.c2;

    // Cofactors of the first row of M (= first column of adj(M)).
    let m0: Field = sub(c0_2 + c0c2 + c0c2 + c2_2, c1_2 + c1c2);
    let m1: Field = sub(c2_2, c0c1);
    let m2: Field = sub(c1_2, c0c2 + c2_2);

    // det(M) = norm(a), via first-row Laplace expansion.
    let det: Field = a.c0 * m0 + a.c2 * m1 + a.c1 * m2;
    let det_inv: Field = inv(det);

    Fp3 {
        c0: m0 * det_inv,
        c1: m1 * det_inv,
        c2: m2 * det_inv,
    }
}

/// Check equality: true iff all three coefficients match.
///
/// Each comparison is parenthesized so the result does not depend on the
/// relative precedence of `==` and `&`. In Rust/C-family grammars `&` binds
/// tighter than `==`, which would change how the unparenthesized form parses.
#[pure]
pub fn eq(a: Fp3, b: Fp3) -> bool {
    (a.c0 == b.c0) & (a.c1 == b.c1) & (a.c2 == b.c2)
}
