module genies.isogeny

// Velu's formulas for l-isogeny computation on Montgomery curves.
//
// Given a curve E_A and a kernel point K of order l, compute:
//   - the codomain curve E_{A'} = E_A / <K>
//   - the image of an arbitrary point Q under the isogeny
//
// Uses the Velu formulas adapted for Montgomery curves:
//   For each half-kernel point with affine x-coordinate x_i:
//     g_x = 3*x_i^2 + 2*A*x_i + 1
//     t_i = 2*g_x (for order > 2)
//     u_i = 4*(x_i^3 + A*x_i^2 + x_i)
//   Accumulate V = sum(t_i), W = sum(u_i + x_i * t_i)
//   Then recover the codomain Montgomery coefficient A'.
use genies.fq
use genies.curve
use vm.core.convert

// Maximum half-kernel size: (587 - 1) / 2 = 293.
// Of the 74 small primes used by CSIDH-512, the largest is 587, so a
// general ell-isogeny must handle up to 293 kernel multiples.

// ---------------------------------------------------------------------------
// Kernel multiple computation
// ---------------------------------------------------------------------------

// Double a kernel point: returns [2]K via x-only doubling (curve.xdbl).
// `a24` is forwarded unchanged to curve.xdbl; elsewhere in this module it is
// the value returned by curve.precompute_a24 for the current curve.
fn kernel_double(k: curve.MontPoint, a24: fq.Fq) -> curve.MontPoint {
    let doubled: curve.MontPoint = curve.xdbl(k, a24)
    doubled
}

// One step of the differential addition chain over the kernel:
// given [i]K (prev), K (k) and [i-1]K (prevprev), return
// [i+1]K = xadd([i]K, K, [i-1]K).
fn kernel_next(prev: curve.MontPoint, k: curve.MontPoint, prevprev: curve.MontPoint) -> curve.MontPoint {
    let stepped: curve.MontPoint = curve.xadd(prev, k, prevprev)
    stepped
}

// ---------------------------------------------------------------------------
// Isogeny codomain: compute A' from kernel point of order ell
// ---------------------------------------------------------------------------

// Per-point Velu quantities for one half-kernel point with affine
// x-coordinate `xi` on the Montgomery curve y^2 = x^3 + A*x^2 + x
// (A = `a_coeff`).
//
// Returns the tuple (t_i, u_i, x_i * t_i) where
//   g_x = 3*x_i^2 + 2*A*x_i + 1      (derivative of the curve RHS in x)
//   t_i = 2 * g_x                     (Velu's t for a point of odd order)
//   u_i = 4 * (x_i^3 + A*x_i^2 + x_i) (4 times the curve RHS at x_i)
fn velu_terms(xi: fq.Fq, a_coeff: fq.Fq) -> (fq.Fq, fq.Fq, fq.Fq) {
    // Small field constants.
    let c2: fq.Fq = fq.fq_from_u32(convert.as_u32(2))
    let c3: fq.Fq = fq.fq_from_u32(convert.as_u32(3))
    let c4: fq.Fq = fq.fq_from_u32(convert.as_u32(4))
    // Powers of x_i.
    let x_sq: fq.Fq = fq.fq_square(xi)
    let x_cu: fq.Fq = fq.fq_mul(x_sq, xi)
    // g_x = 3*x^2 + 2*A*x + 1
    let three_x_sq: fq.Fq = fq.fq_mul(c3, x_sq)
    let two_a_x: fq.Fq = fq.fq_mul(fq.fq_mul(c2, a_coeff), xi)
    let g_x: fq.Fq = fq.fq_add(fq.fq_add(three_x_sq, two_a_x), fq.fq_one())
    // t_i = 2 * g_x
    let t_i: fq.Fq = fq.fq_mul(c2, g_x)
    // u_i = 4 * (x^3 + A*x^2 + x)
    let rhs: fq.Fq = fq.fq_add(fq.fq_add(x_cu, fq.fq_mul(a_coeff, x_sq)), xi)
    let u_i: fq.Fq = fq.fq_mul(c4, rhs)
    // x_i * t_i, needed for the W sum.
    let xt_i: fq.Fq = fq.fq_mul(xi, t_i)
    (t_i, u_i, xt_i)
}

// Fold one point's Velu terms into the running sums.
// Returns (v_acc + t_i, w_acc + (u_i + x_i * t_i)).
fn accum_velu(v_acc: fq.Fq, w_acc: fq.Fq, ti: fq.Fq, ui: fq.Fq, xi_ti: fq.Fq) -> (fq.Fq, fq.Fq) {
    let w_term: fq.Fq = fq.fq_add(ui, xi_ti)
    let v_next: fq.Fq = fq.fq_add(v_acc, ti)
    let w_next: fq.Fq = fq.fq_add(w_acc, w_term)
    (v_next, w_next)
}

// Contribution of one half-kernel point to the x-coordinate of the image of
// the 2-torsion point (0, 0). This is Velu's term
// t_i/(x - x_i) + u_i/(x - x_i)^2 evaluated at x = 0:
//   -t_i/x_i + u_i/x_i^2
// `xi` is inverted, so it must be nonzero.
fn origin_image_term(xi: fq.Fq, ti: fq.Fq, ui: fq.Fq) -> fq.Fq {
    let inv_x: fq.Fq = fq.fq_inv(xi)
    let inv_x_sq: fq.Fq = fq.fq_square(inv_x)
    let neg_t_over_x: fq.Fq = fq.fq_neg(fq.fq_mul(ti, inv_x))
    let u_over_x_sq: fq.Fq = fq.fq_mul(ui, inv_x_sq)
    fq.fq_add(neg_t_over_x, u_over_x_sq)
}

// Recover the codomain Montgomery coefficient A' from the Velu sum V
// (`v_sum`) and the image x0 of the 2-torsion point (0, 0).
//
// For y^2 = x^3 + A*x^2 + x (a1 = a3 = 0, a4 = 1, a6 = 0), Velu's codomain is
//   y^2 = x^3 + A*x^2 + (1 - 5V)*x + (-4A*V - 7W)
// x0 is a root of that cubic. Substituting x = X + x0 removes the constant
// term, so W is never needed:
//   y^2 = X^3 + B*X^2 + c*X,  B = 3*x0 + A,  c = 3*x0^2 + 2*A*x0 + (1 - 5V)
// Scaling X by alpha = sqrt(c) makes the linear coefficient 1, which gives
// A' = B / alpha.
//
// NOTE(review): no twist / non-residue branch is implemented here. The result
// depends on which square root fq.fq_sqrt returns; the opposite root yields
// -A', the quadratic twist. Presumably fq_sqrt returns the root that is itself
// a square (e.g. c^((p+1)/4) for p = 3 mod 4). TODO confirm against genies.fq.
fn recover_a_prime(a_coeff: fq.Fq, v_sum: fq.Fq, x0: fq.Fq) -> fq.Fq {
    let c2: fq.Fq = fq.fq_from_u32(convert.as_u32(2))
    let c3: fq.Fq = fq.fq_from_u32(convert.as_u32(3))
    let c5: fq.Fq = fq.fq_from_u32(convert.as_u32(5))
    // Velu's new linear coefficient: 1 - 5V.
    let lin: fq.Fq = fq.fq_sub(fq.fq_one(), fq.fq_mul(c5, v_sum))
    let x0_squared: fq.Fq = fq.fq_square(x0)
    // Quadratic coefficient after translating x0 to the origin: 3*x0 + A.
    let quad: fq.Fq = fq.fq_add(fq.fq_mul(c3, x0), a_coeff)
    // Linear coefficient after translation: 3*x0^2 + 2*A*x0 + (1 - 5V).
    let three_x0_sq: fq.Fq = fq.fq_mul(c3, x0_squared)
    let two_a_x0: fq.Fq = fq.fq_mul(fq.fq_mul(c2, a_coeff), x0)
    let lin_shifted: fq.Fq = fq.fq_add(fq.fq_add(three_x0_sq, two_a_x0), lin)
    // A' = quad / sqrt(lin_shifted)
    let alpha: fq.Fq = fq.fq_sqrt(lin_shifted)
    let alpha_inv: fq.Fq = fq.fq_inv(alpha)
    fq.fq_mul(quad, alpha_inv)
}

// ---------------------------------------------------------------------------
// Isogeny codomain for small primes
// ---------------------------------------------------------------------------

// Codomain E_A / <K> of a 3-isogeny with kernel generated by `kernel`.
// For ell = 3 the half-kernel is {K} (h = 1), so the Velu sums and the
// (0, 0)-image come from that single point. The order of `kernel` is not
// checked here.
pub fn isogeny_codomain_3(c: curve.MontCurve, kernel: curve.MontPoint) -> curve.MontCurve {
    let xk: fq.Fq = curve.to_affine(kernel)
    let (t, u, xt) = velu_terms(xk, c.a)
    let (v, _w) = accum_velu(fq.fq_zero(), fq.fq_zero(), t, u, xt)
    let x0: fq.Fq = origin_image_term(xk, t, u)
    let a_next: fq.Fq = recover_a_prime(c.a, v, x0)
    curve.curve_from_a(a_next)
}

// Codomain E_A / <K> of a 5-isogeny with kernel generated by `kernel`.
// Half-kernel (h = 2): K and [2]K, where [2]K is computed with curve.xdbl
// using the curve's a24 constant. Both points contribute to the Velu sums
// and to the x-coordinate of the image of (0, 0) before A' is recovered.
pub fn isogeny_codomain_5(c: curve.MontCurve, kernel: curve.MontPoint) -> curve.MontCurve {
    let a24: fq.Fq = curve.precompute_a24(c)
    let kernel2: curve.MontPoint = curve.xdbl(kernel, a24)
    // Affine x-coordinates of the two half-kernel points.
    let xa: fq.Fq = curve.to_affine(kernel)
    let xb: fq.Fq = curve.to_affine(kernel2)
    let (ta, ua, xta) = velu_terms(xa, c.a)
    let (tb, ub, xtb) = velu_terms(xb, c.a)
    // Running Velu sums; only V is used to recover A'.
    let (v_part, w_part) = accum_velu(fq.fq_zero(), fq.fq_zero(), ta, ua, xta)
    let (v_total, _w_total) = accum_velu(v_part, w_part, tb, ub, xtb)
    // x-coordinate of the image of (0, 0).
    let x0: fq.Fq = fq.fq_add(origin_image_term(xa, ta, ua), origin_image_term(xb, tb, ub))
    let a_next: fq.Fq = recover_a_prime(c.a, v_total, x0)
    curve.curve_from_a(a_next)
}

// ---------------------------------------------------------------------------
// General isogeny codomain (for any odd prime ell)
// ---------------------------------------------------------------------------
// For the general case, the half-kernel [1]K, [2]K, ..., [(ell-1)/2]K
// must be computed via differential addition chain, and Velu sums
// accumulated for each point.
//
// Since Trident has no dynamic loops, the general case is handled by:
//   1. The prover computes all kernel multiples externally.
//   2. Each kernel multiple is provided as a witness.
//   3. The circuit verifies the differential addition chain and Velu sums.
//
// For CSIDH-512 with 74 primes up to 587, the maximum half-kernel is 293.
// The prover unrolls up to 293 steps.

// Planned: a general isogeny with the kernel multiples provided as witnesses.
// It would take the half-kernel points as a list (MontPoints plus a count),
// accumulate the Velu sums, and recover A'.
//
// For the ZK circuit, the prover provides:
//   - the kernel multiples [1]K, ..., [h]K as witnesses;
// and the circuit:
//   - verifies [i+1]K = xadd([i]K, K, [i-1]K) for each i,
//   - accumulates the Velu sums and recovers A'.
//
// That general function is not defined here; only ell = 3 and ell = 5 are
// implemented (isogeny_codomain_3 / isogeny_codomain_5). An isogeny of prime
// degree ell cannot be factored into 3- and 5-isogenies, so every larger
// prime needs the general witness-based approach.

// ---------------------------------------------------------------------------
// Isogeny evaluation: push a point Q through the isogeny
// ---------------------------------------------------------------------------

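// Evaluate the ell-isogeny at point Q.
// Projective formula (Costello-Hisil):
//   phi(X_Q : Z_Q) = ( X_Q * prod((X_Q*X_i - Z_Q*Z_i)^2),
//                       Z_Q * prod((X_Q*Z_i - Z_Q*X_i)^2) )
// where (X_i : Z_i) ranges over the half-kernel points [1]K, ..., [(ell-1)/2]K.
// The second product vanishes at kernel points, so they map to infinity.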

// Multiply one half-kernel point (X_i : Z_i) = (kx : kz) into the running
// projective image (num : den) of Q = (qx : qz), returning the updated pair:
//   num' = num * (X_Q*X_i - Z_Q*Z_i)^2
//   den' = den * (X_Q*Z_i - Z_Q*X_i)^2
// The factor (X_Q*Z_i - Z_Q*X_i) vanishes exactly when x(Q) = x(K_i), so it
// belongs in the denominator. That way kernel points are sent to the point
// at infinity (Z = 0), as an isogeny with kernel <K> requires. With the
// factors the other way round, kernel points would land on (0 : 1).
fn eval_factor(qx: fq.Fq, qz: fq.Fq, kx: fq.Fq, kz: fq.Fq, num: fq.Fq, den: fq.Fq) -> (fq.Fq, fq.Fq) {
    // X_Q*X_i - Z_Q*Z_i: zero when x(Q) * x(K_i) = 1.
    let num_factor: fq.Fq = fq.fq_sub(fq.fq_mul(qx, kx), fq.fq_mul(qz, kz))
    // X_Q*Z_i - Z_Q*X_i: zero when x(Q) = x(K_i).
    let den_factor: fq.Fq = fq.fq_sub(fq.fq_mul(qx, kz), fq.fq_mul(qz, kx))
    let new_num: fq.Fq = fq.fq_mul(num, fq.fq_square(num_factor))
    let new_den: fq.Fq = fq.fq_mul(den, fq.fq_square(den_factor))
    (new_num, new_den)
}

// Push Q through the 3-isogeny with kernel <K>.
// The half-kernel is {K}, so one eval_factor step seeded with
// (num : den) = (X_Q : Z_Q) yields the image.
pub fn isogeny_eval_3(kernel: curve.MontPoint, q_pt: curve.MontPoint) -> curve.MontPoint {
    let qx: fq.Fq = q_pt.x
    let qz: fq.Fq = q_pt.z
    let (img_x, img_z) = eval_factor(qx, qz, kernel.x, kernel.z, qx, qz)
    curve.MontPoint { x: img_x, z: img_z }
}

// Push Q through the 5-isogeny with kernel <K>.
// Half-kernel {K, [2]K}; [2]K is computed with curve.xdbl using `a24`
// (apply_isogeny_5 passes curve.precompute_a24 of the domain curve).
pub fn isogeny_eval_5(kernel: curve.MontPoint, a24: fq.Fq, q_pt: curve.MontPoint) -> curve.MontPoint {
    let qx: fq.Fq = q_pt.x
    let qz: fq.Fq = q_pt.z
    // Factor for K, seeded with (X_Q : Z_Q).
    let (acc_x, acc_z) = eval_factor(qx, qz, kernel.x, kernel.z, qx, qz)
    // Factor for [2]K.
    let doubled: curve.MontPoint = curve.xdbl(kernel, a24)
    let (img_x, img_z) = eval_factor(qx, qz, doubled.x, doubled.z, acc_x, acc_z)
    curve.MontPoint { x: img_x, z: img_z }
}

// ---------------------------------------------------------------------------
// Combined isogeny: compute codomain + push point
// ---------------------------------------------------------------------------
// Apply a 3-isogeny with kernel <K>: returns (codomain curve, image of `push`).
pub fn apply_isogeny_3(c: curve.MontCurve, kernel: curve.MontPoint, push: curve.MontPoint) -> (curve.MontCurve, curve.MontPoint) {
    let pushed: curve.MontPoint = isogeny_eval_3(kernel, push)
    let codomain: curve.MontCurve = isogeny_codomain_3(c, kernel)
    (codomain, pushed)
}

// Apply a 5-isogeny with kernel <K>: returns (codomain curve, image of `push`).
// The domain curve's a24 constant is computed here for the evaluation step.
pub fn apply_isogeny_5(c: curve.MontCurve, kernel: curve.MontPoint, push: curve.MontPoint) -> (curve.MontCurve, curve.MontPoint) {
    let codomain: curve.MontCurve = isogeny_codomain_5(c, kernel)
    let a24_dom: fq.Fq = curve.precompute_a24(c)
    let pushed: curve.MontPoint = isogeny_eval_5(kernel, a24_dom, push)
    (codomain, pushed)
}

Local Graph