// Montgomery batch inversion for F₂¹²⁸.
//
// Computes N inversions using 1 actual inversion + 3(N-1) multiplications.
// This amortizes the expensive tower-recursive inversion across many elements.
//
// Trident does not support dynamic-length arrays or loops with variable
// bounds. This module provides fixed-size batch inversion for common sizes.
//
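// All inputs must be nonzero. Zero elements are NOT skipped: a zero input
// makes the accumulated product zero, and every other element's output then
// comes out as zero instead of its inverse.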
module kuro.batch
use kuro.tower.*
/// Batch-invert 4 F₂¹²⁸ elements via Montgomery's trick.
///
/// Montgomery's trick:
/// 1. Forward running products: p_i = a0 * a1 * ... * a_i.
/// 2. One inversion of the full product: inv_acc = p3⁻¹.
/// 3. Backward pass keeping r_k = (a0 * ... * a_k)⁻¹:
///    inv(a_{k+1}) = r_{k+1} * p_k, then r_k = r_{k+1} * a_{k+1}.
///    The last step yields r_0 = a0⁻¹ directly.
///
/// Cost: 1 call to `f2_128_inv` + 9 calls to `f2_128_mul` (3(N-1), N = 4).
///
/// Precondition: every input must be nonzero. Zeros are NOT skipped. Output i
/// equals f2_128_inv(p3) times the product of the other three inputs. So if
/// input k is zero:
/// - every output i != k is zero rather than a_i⁻¹;
/// - output k is f2_128_inv(0) times the product of the others.
/// NOTE(review): the value of `f2_128_inv(0)` is defined in kuro.tower — confirm.
///
/// Returns [a0⁻¹, a1⁻¹, a2⁻¹, a3⁻¹] when all inputs are nonzero.
pub fn batch_inv_4(
    a0: F2_128, a1: F2_128, a2: F2_128, a3: F2_128
) -> [F2_128; 4] {
    // Step 1: forward running products p_i = a0 * ... * a_i.
    // No zero check happens here; see the precondition above.
    let p0: F2_128 = a0;
    let p1: F2_128 = f2_128_mul(p0, a1);
    let p2: F2_128 = f2_128_mul(p1, a2);
    let p3: F2_128 = f2_128_mul(p2, a3);
    // Step 2: the single inversion, inv_acc = (a0 * a1 * a2 * a3)⁻¹.
    let inv_acc: F2_128 = f2_128_inv(p3);
    // Step 3: backward pass.
    // inv(a3) = (a0a1a2a3)⁻¹ * (a0a1a2)
    let inv3: F2_128 = f2_128_mul(inv_acc, p2);
    // r2 = (a0a1a2)⁻¹
    let r2: F2_128 = f2_128_mul(inv_acc, a3);
    let inv2: F2_128 = f2_128_mul(r2, p1);
    // r1 = (a0a1)⁻¹
    let r1: F2_128 = f2_128_mul(r2, a2);
    let inv1: F2_128 = f2_128_mul(r1, p0);
    // r0 = a0⁻¹ is itself the last inverse, so no extra multiply by p is needed.
    let inv0: F2_128 = f2_128_mul(r1, a1);
    [inv0, inv1, inv2, inv3]
}
/// Batch-invert 8 F₂¹²⁸ elements via Montgomery's trick.
///
/// 1. Forward running products: p_i = a0 * a1 * ... * a_i.
/// 2. One inversion of the full product: inv_acc = p7⁻¹.
/// 3. Backward pass keeping r_k = (a0 * ... * a_k)⁻¹:
///    inv(a_{k+1}) = r_{k+1} * p_k, then r_k = r_{k+1} * a_{k+1}.
///    The last step yields r_0 = a0⁻¹ directly.
///
/// Cost: 1 call to `f2_128_inv` + 21 calls to `f2_128_mul` (3(N-1), N = 8).
///
/// Precondition: every input must be nonzero. Zeros are NOT skipped. Output i
/// equals f2_128_inv(p7) times the product of the other seven inputs. So if
/// input k is zero:
/// - every output i != k is zero rather than a_i⁻¹;
/// - output k is f2_128_inv(0) times the product of the others.
/// NOTE(review): the value of `f2_128_inv(0)` is defined in kuro.tower — confirm.
///
/// Returns [a0⁻¹, a1⁻¹, ..., a7⁻¹] when all inputs are nonzero.
pub fn batch_inv_8(
    a0: F2_128, a1: F2_128, a2: F2_128, a3: F2_128,
    a4: F2_128, a5: F2_128, a6: F2_128, a7: F2_128
) -> [F2_128; 8] {
    // Step 1: forward running products p_i = a0 * ... * a_i.
    // No zero check happens here; see the precondition above.
    let p0: F2_128 = a0;
    let p1: F2_128 = f2_128_mul(p0, a1);
    let p2: F2_128 = f2_128_mul(p1, a2);
    let p3: F2_128 = f2_128_mul(p2, a3);
    let p4: F2_128 = f2_128_mul(p3, a4);
    let p5: F2_128 = f2_128_mul(p4, a5);
    let p6: F2_128 = f2_128_mul(p5, a6);
    let p7: F2_128 = f2_128_mul(p6, a7);
    // Step 2: the single inversion, inv_acc = (a0 * ... * a7)⁻¹.
    let inv_acc: F2_128 = f2_128_inv(p7);
    // Step 3: backward pass; r_k = (a0 * ... * a_k)⁻¹ and inv(a_k) = r_k * p_{k-1}.
    let inv7: F2_128 = f2_128_mul(inv_acc, p6);
    let r6: F2_128 = f2_128_mul(inv_acc, a7);
    let inv6: F2_128 = f2_128_mul(r6, p5);
    let r5: F2_128 = f2_128_mul(r6, a6);
    let inv5: F2_128 = f2_128_mul(r5, p4);
    let r4: F2_128 = f2_128_mul(r5, a5);
    let inv4: F2_128 = f2_128_mul(r4, p3);
    let r3: F2_128 = f2_128_mul(r4, a4);
    let inv3: F2_128 = f2_128_mul(r3, p2);
    let r2: F2_128 = f2_128_mul(r3, a3);
    let inv2: F2_128 = f2_128_mul(r2, p1);
    let r1: F2_128 = f2_128_mul(r2, a2);
    let inv1: F2_128 = f2_128_mul(r1, p0);
    // r0 = a0⁻¹ is itself the last inverse, so no extra multiply by p is needed.
    let inv0: F2_128 = f2_128_mul(r1, a1);
    [inv0, inv1, inv2, inv3, inv4, inv5, inv6, inv7]
}
kuro/tri/batch.tri
π 0.0%