// Montgomery batch inversion for F₂¹²⁸.
//
// Computes N inversions using 1 actual inversion + 3(N-1) multiplications.
// This amortizes the expensive tower-recursive inversion across many elements.
//
// Trident does not support dynamic-length arrays or loops with variable
// bounds. This module provides fixed-size batch inversion for common sizes.
//
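// NOTE: the routines in this module do not special-case zero inputs. Every
// input is multiplied into one shared product, so all inputs must be nonzero;
// a zero input makes the outputs for the nonzero inputs incorrect.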

module kuro.batch

use kuro.tower.*

/// Invert four F₂¹²⁸ elements with a single field inversion (Montgomery's trick).
///
/// Method:
///   1. Forward pass: running products a0, a0*a1, a0*a1*a2, a0*a1*a2*a3.
///   2. Invert the full product once.
///   3. Backward pass: peel off one input per step. Multiplying the running
///      inverse by the prefix product just below it yields that input's
///      inverse; multiplying the running inverse by the input itself removes
///      that factor for the next step.
///
/// Cost: 1 call to `f2_128_inv` and 9 calls to `f2_128_mul`.
///
/// Returns `[a0⁻¹, a1⁻¹, a2⁻¹, a3⁻¹]` when every input is nonzero.
///
/// NOTE(review): zero inputs are NOT special-cased. All four inputs are
/// folded unconditionally into the shared product, so a single zero input
/// makes that product zero and the outputs for the nonzero inputs are then
/// products with a zero factor rather than their inverses. What is returned
/// for the zero input itself depends on what `f2_128_inv` yields for zero
/// (not visible here — confirm in kuro.tower). Pass only nonzero elements.
pub fn batch_inv_4(
    a0: F2_128, a1: F2_128, a2: F2_128, a3: F2_128
) -> [F2_128; 4] {
    // Forward pass: prefix products. The first prefix is a0 itself, so it
    // needs no binding of its own.
    let pre01: F2_128 = f2_128_mul(a0, a1);
    let pre012: F2_128 = f2_128_mul(pre01, a2);
    let pre0123: F2_128 = f2_128_mul(pre012, a3);

    // The only inversion: run3 = (a0*a1*a2*a3)⁻¹.
    let run3: F2_128 = f2_128_inv(pre0123);

    // Backward pass. Invariant: run_k = (a0*...*a_k)⁻¹.
    // a3⁻¹ = run3 * (a0*a1*a2); then cancel a3 out of the running inverse.
    let inv3: F2_128 = f2_128_mul(run3, pre012);
    let run2: F2_128 = f2_128_mul(run3, a3);

    // a2⁻¹ = run2 * (a0*a1); then cancel a2.
    let inv2: F2_128 = f2_128_mul(run2, pre01);
    let run1: F2_128 = f2_128_mul(run2, a2);

    // a1⁻¹ = run1 * a0, and cancelling a1 from run1 leaves exactly a0⁻¹.
    let inv1: F2_128 = f2_128_mul(run1, a0);
    let inv0: F2_128 = f2_128_mul(run1, a1);

    [inv0, inv1, inv2, inv3]
}

/// Invert eight F₂¹²⁸ elements with a single field inversion (Montgomery's trick).
///
/// Same scheme as the 4-element variant: build the prefix products of the
/// inputs, invert the full product once, then walk backward, producing one
/// inverse per step and cancelling that input out of the running inverse.
///
/// Cost: 1 call to `f2_128_inv` and 21 calls to `f2_128_mul`.
///
/// Returns `[a0⁻¹, ..., a7⁻¹]`, in input order, when every input is nonzero.
///
/// NOTE(review): there is no special handling of zero inputs here; every
/// input is a factor of the product that gets inverted. Confirm how
/// `f2_128_inv` treats zero before passing inputs that may be zero.
pub fn batch_inv_8(
    a0: F2_128, a1: F2_128, a2: F2_128, a3: F2_128,
    a4: F2_128, a5: F2_128, a6: F2_128, a7: F2_128
) -> [F2_128; 8] {
    // Forward pass: pre_k = a0 * a1 * ... * a_k (pre_0 is just a0).
    let pre1: F2_128 = f2_128_mul(a0, a1);
    let pre2: F2_128 = f2_128_mul(pre1, a2);
    let pre3: F2_128 = f2_128_mul(pre2, a3);
    let pre4: F2_128 = f2_128_mul(pre3, a4);
    let pre5: F2_128 = f2_128_mul(pre4, a5);
    let pre6: F2_128 = f2_128_mul(pre5, a6);
    let pre7: F2_128 = f2_128_mul(pre6, a7);

    // The only inversion: run7 = (a0 * ... * a7)⁻¹.
    let run7: F2_128 = f2_128_inv(pre7);

    // Backward pass. Invariant: run_k = (a0 * ... * a_k)⁻¹, hence
    // a_k⁻¹ = run_k * pre_(k-1) and run_(k-1) = run_k * a_k.
    let inv7: F2_128 = f2_128_mul(run7, pre6);
    let run6: F2_128 = f2_128_mul(run7, a7);

    let inv6: F2_128 = f2_128_mul(run6, pre5);
    let run5: F2_128 = f2_128_mul(run6, a6);

    let inv5: F2_128 = f2_128_mul(run5, pre4);
    let run4: F2_128 = f2_128_mul(run5, a5);

    let inv4: F2_128 = f2_128_mul(run4, pre3);
    let run3: F2_128 = f2_128_mul(run4, a4);

    let inv3: F2_128 = f2_128_mul(run3, pre2);
    let run2: F2_128 = f2_128_mul(run3, a3);

    let inv2: F2_128 = f2_128_mul(run2, pre1);
    let run1: F2_128 = f2_128_mul(run2, a2);

    // Last step: run1 = (a0*a1)⁻¹, so one multiply by each of a0 and a1
    // yields the two remaining inverses.
    let inv1: F2_128 = f2_128_mul(run1, a0);
    let inv0: F2_128 = f2_128_mul(run1, a1);

    [inv0, inv1, inv2, inv3, inv4, inv5, inv6, inv7]
}

Dimensions

nebu/tri/batch.tri

Local Graph