// ---
// tags: nebu, trident
// crystal-type: source
// crystal-domain: comp
// ---
// Quadratic extension field Fp2 = Fp[u] / (u^2 - 7).
//
// Elements are (re, im) representing re + im*u where u^2 = 7.

module nebu.fp2

// -- Type definition ---------------------------------------------------------

/// An element re + im*u of Fp2, where u^2 = 7.
pub struct Fp2 {
    // Coefficient of 1 (the base-field component).
    re: Field,
    // Coefficient of u.
    im: Field,
}

// -- Constants ---------------------------------------------------------------

/// Additive identity: 0 + 0*u.
pub const ZERO: Fp2 = Fp2 { re: 0, im: 0 };
/// Multiplicative identity: 1 + 0*u.
pub const ONE: Fp2 = Fp2 { re: 1, im: 0 };
// The constant defining the extension: u^2 = SEVEN. Presumably chosen as a
// quadratic non-residue in Fp so that Fp2 is a field — TODO confirm for the modulus.
const SEVEN: Field = 7;
// SEVEN + 1; `sqr` subtracts EIGHT * re * im to cancel the cross terms of
// (re + im) * (re + SEVEN * im).
const EIGHT: Field = 8;

// -- Construction ------------------------------------------------------------

/// Build the element re + im*u from its two base-field coefficients.
#[pure]
pub fn new(re: Field, im: Field) -> Fp2 {
    let elem: Fp2 = Fp2 { re: re, im: im };
    elem
}

/// Lift a base-field element into Fp2 as a + 0*u (the canonical embedding Fp -> Fp2).
#[pure]
pub fn from_base(a: Field) -> Fp2 {
    new(a, 0)
}

// -- Arithmetic --------------------------------------------------------------

/// Component-wise addition: (a0 + a1*u) + (b0 + b1*u) = (a0 + b0) + (a1 + b1)*u.
#[pure]
pub fn add(a: Fp2, b: Fp2) -> Fp2 {
    let sum_re: Field = a.re + b.re;
    let sum_im: Field = a.im + b.im;
    Fp2 { re: sum_re, im: sum_im }
}

/// Component-wise subtraction: (a0 + a1*u) - (b0 + b1*u) = (a0 - b0) + (a1 - b1)*u.
/// Base-field differences are computed with the `sub` builtin.
#[pure]
pub fn fp2_sub(a: Fp2, b: Fp2) -> Fp2 {
    let diff_re: Field = sub(a.re, b.re);
    let diff_im: Field = sub(a.im, b.im);
    Fp2 { re: diff_re, im: diff_im }
}

/// Additive inverse: -(a0 + a1*u) = (-a0) + (-a1)*u, negating each coefficient.
#[pure]
pub fn fp2_neg(a: Fp2) -> Fp2 {
    let neg_re: Field = neg(a.re);
    let neg_im: Field = neg(a.im);
    Fp2 { re: neg_re, im: neg_im }
}

/// Karatsuba multiplication: 3 base muls + 1 mul-by-7 + 5 add/subs.
///
/// (a0 + a1*u)(b0 + b1*u) = (a0*b0 + 7*a1*b1) + (a0*b1 + a1*b0)*u
///
/// The cross term a0*b1 + a1*b0 is recovered as (a0 + a1)(b0 + b1) - a0*b0 - a1*b1,
/// which saves one base multiplication over the schoolbook form. Subtraction uses
/// the `sub` builtin, as `fp2_sub` does, rather than a binary `-` on Field.
#[pure]
pub fn mul(a: Fp2, b: Fp2) -> Fp2 {
    let v0: Field = a.re * b.re;
    let v1: Field = a.im * b.im;
    // u^2 = 7 folds the a1*b1*u^2 term back into the real part.
    let re: Field = v0 + SEVEN * v1;
    let cross: Field = (a.re + a.im) * (b.re + b.im);
    let im: Field = sub(sub(cross, v0), v1);
    Fp2 { re: re, im: im }
}

/// Multiply an Fp2 element by a base-field scalar s: (a0 + a1*u) * s = a0*s + (a1*s)*u.
#[pure]
pub fn scale(a: Fp2, s: Field) -> Fp2 {
    let scaled_re: Field = a.re * s;
    let scaled_im: Field = a.im * s;
    Fp2 { re: scaled_re, im: scaled_im }
}

/// Optimized squaring: 2 muls + small-constant muls.
///
/// (a0 + a1*u)^2 = (a0^2 + 7*a1^2) + 2*a0*a1*u.
/// The real part is computed as (a0 + a1)(a0 + 7*a1) - 8*a0*a1. The expansion
/// yields a0^2 + 8*a0*a1 + 7*a1^2, so subtracting 8*a0*a1 leaves a0^2 + 7*a1^2.
/// Subtraction uses the `sub` builtin, as `fp2_sub` does, rather than a binary `-`.
#[pure]
pub fn sqr(a: Fp2) -> Fp2 {
    let ab: Field = a.re * a.im;
    let t: Field = (a.re + a.im) * (a.re + SEVEN * a.im);
    let re: Field = sub(t, EIGHT * ab);
    // 2*a0*a1 via an addition instead of a multiplication by 2.
    let im: Field = ab + ab;
    Fp2 { re: re, im: im }
}

/// Conjugation: a0 + a1*u maps to a0 - a1*u. The real part is kept and the
/// u-coefficient is negated.
#[pure]
pub fn conj(a: Fp2) -> Fp2 {
    let flipped_im: Field = neg(a.im);
    Fp2 { re: a.re, im: flipped_im }
}

/// Norm map Fp2 -> Fp: N(a0 + a1*u) = (a0 + a1*u)(a0 - a1*u) = a0^2 - 7*a1^2.
///
/// Subtraction uses the `sub` builtin, as `fp2_sub` does, rather than a binary `-`.
#[pure]
pub fn norm(a: Fp2) -> Field {
    sub(a.re * a.re, SEVEN * (a.im * a.im))
}

/// Multiplicative inverse via the norm: (a0 + a1*u)^{-1} = conj(a) / norm(a)
///                                                        = (a0 - a1*u) / (a0^2 - 7*a1^2).
///
/// Precondition: `a` must be nonzero. For a = ZERO the norm is 0 and `inv(0)` is
/// invoked. NOTE(review): `inv` is defined outside this module and presumably
/// aborts on 0 — confirm.
#[pure]
pub fn fp2_inv(a: Fp2) -> Fp2 {
    let n: Field = norm(a);
    scale(conj(a), inv(n))
}

/// Component-wise equality: true iff both the real parts and the u-coefficients match.
///
/// Each comparison is bound separately so the result does not depend on the
/// relative precedence of `==` and `&`. Under Rust-style precedence `&` binds
/// tighter, and the unparenthesized form would not mean component-wise AND.
#[pure]
pub fn eq(a: Fp2, b: Fp2) -> bool {
    let same_re: bool = a.re == b.re;
    let same_im: bool = a.im == b.im;
    same_re & same_im
}

Local Graph