// ---
// tags: nebu, trident
// crystal-type: source
// crystal-domain: comp
// ---
// Square root and Legendre symbol over the Goldilocks field.
//
// Uses Tonelli-Shanks with z = 7 (quadratic non-residue).
// p - 1 = 2^32 * s where s = 2^32 - 1 (epsilon).

module nebu.sqrt

use nebu.field

// -- Constants ---------------------------------------------------------------

// Odd part of p - 1: s = 2^32 - 1 = 4294967295, so that p - 1 = 2^32 * s.
// Used as the exponent passed to field.exp when computing x^s.
const S: U32 = 4294967295;

// Two-adicity of p - 1: the exponent of the largest power of two dividing
// p - 1 (2^32). Seeds the Tonelli-Shanks order bound in sqrt.
const TWO_ADICITY: U32 = 32;

// (p - 1) / 2 = 2^31 * s is the exponent of the Legendre symbol. It does not
// fit in a U32, so it is applied as field.exp(a, S) followed by 31 squarings.

// -- Legendre symbol ---------------------------------------------------------

/// Evaluate the Legendre symbol of `a` via Euler's criterion, a^((p-1)/2).
///
/// The result is 0 when a = 0, 1 when `a` is a nonzero quadratic residue,
/// and p-1 when `a` is a quadratic non-residue.
#[pure]
pub fn legendre(a: Field) -> Field {
    // The exponent factors as (p-1)/2 = 2^31 * (2^32 - 1) = 2^63 - 2^31,
    // so raise to the odd factor first and finish with 31 squarings.
    let mut acc: Field = field.exp(a, S);
    for _round in 0..31 bounded 31 {
        acc = acc * acc;
    }
    acc
}

/// Report whether `a` is a square in the field.
///
/// Returns true when the Legendre symbol of `a` is 1 (nonzero quadratic
/// residue) or 0 (a = 0), and false when it is p-1 (non-residue).
#[pure]
pub fn is_qr(a: Field) -> bool {
    let leg: Field = legendre(a);
    // Parenthesized so the result does not depend on the relative
    // precedence of `==` and `|`.
    (leg == 1) | (leg == 0)
}

// -- Tonelli-Shanks square root ----------------------------------------------

/// Compute a square root of `n` in the Goldilocks field.
///
/// Returns (root, found):
///   - n = 0:            (0, true)
///   - n a non-residue:  (0, false)
///   - otherwise:        (r, true) with r * r = n
///
/// Uses Tonelli-Shanks with non-residue z = 7 and the factorization
/// p - 1 = 2^32 * s, s = 2^32 - 1. Every loop has a fixed bound; iterations
/// with nothing left to do fall through their `if` guards.
pub fn sqrt(n: Field) -> (Field, bool) {
    if n == 0 {
        return (0, true);
    }

    // t = n^s. Computed once and shared by the residuosity test below and
    // the main loop, rather than being recomputed through legendre(n).
    let mut t: Field = field.exp(n, S);

    // Euler's criterion: n^((p-1)/2) = (n^s)^(2^31), i.e. t squared 31 times.
    // For nonzero n the value is 1 (residue) or p-1 (non-residue).
    let mut leg: Field = t;
    for _sq in 0..31 bounded 31 {
        leg = leg * leg;
    }
    if leg == sub(0, 1) {
        return (0, false);
    }

    // Tonelli-Shanks
    // c = 7^s; since 7 is a non-residue, c has order exactly 2^32.
    let mut c: Field = field.exp(7, S);
    // r = n^((s+1)/2) = n^(2^31)  since (s+1)/2 = 2^31, hence r^2 = n * t.
    let mut r: Field = field.exp(n, 2147483648);

    // Loop invariant: r^2 = n * t, c has order 2^big_m, and the order of t
    // divides 2^(big_m - 1). When t reaches 1, r is a square root of n.
    let mut big_m: U32 = TWO_ADICITY;

    // Main loop: at most 32 iterations (bounded by two-adicity), because
    // big_m strictly decreases on every iteration that does work.
    for _iter in 0..32 bounded 32 {
        if t != 1 {
            // Find least i > 0 such that t^(2^i) = 1.
            // tmp tracks t^(2^i); once it hits 1 the remaining rounds idle.
            let mut i: U32 = 1;
            let mut tmp: Field = t * t;
            for _j in 0..31 bounded 31 {
                if tmp != 1 {
                    tmp = tmp * tmp;
                    i = i + 1;
                }
            }

            // b = c^(2^(big_m - i - 1)): square c exactly sq_count times.
            // The loop index is compared against sq_count so the loop bound
            // itself stays constant.
            let mut b: Field = c;
            let sq_count: U32 = big_m - i - 1;
            for _k in 0..32 bounded 32 {
                if _k < sq_count {
                    b = b * b;
                }
            }

            // Restore the invariant with the smaller order bound i.
            big_m = i;
            c = b * b;
            t = t * c;
            r = r * b;
        }
    }

    (r, true)
}

Local Graph