module std.crypto.bigint

// Multi-precision arithmetic over U32 limbs.
//
// The native Goldilocks field (p = 2^64 - 2^32 + 1) is too small for
// elliptic curve fields like secp256k1 (256-bit prime). This module
// provides 256-bit integer arithmetic built on arrays of U32 limbs
// in little-endian order (l0 is least significant).
//
// U32 arithmetic constraints:
//   - `+` and `*` are Field-only operators.
//   - U32 has `&`, `^`, `<`, `/%` as native operators.
//   - To add U32 values we widen to Field, add, then split back.
//   - `convert.as_field(u)` widens U32 -> Field losslessly.
//   - `convert.as_u32(f)` narrows Field -> U32 (asserts fits in 32 bits).
//   - `convert.split(f)` splits a Field into (hi: U32, lo: U32).
use vm.core.convert

use vm.core.field

// A 256-bit unsigned integer represented as 8 U32 limbs in little-endian
// limb order, i.e. value = l0 + l1*2^32 + l2*2^64 + ... + l7*2^224.
// l0 is the least significant limb, l7 is the most significant.
// Every function in this module (carry/borrow chains, shifts, comparisons)
// relies on this limb order.
pub struct U256 {
    l0: U32,
    l1: U32,
    l2: U32,
    l3: U32,
    l4: U32,
    l5: U32,
    l6: U32,
    l7: U32,
}

// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
// The 256-bit value 0: a zero low limb with all higher limbs zero.
pub fn zero256() -> U256 {
    from_u32(convert.as_u32(0))
}

// The 256-bit value 1: the low limb is 1 and every higher limb is 0.
pub fn one256() -> U256 {
    let unit: U32 = convert.as_u32(1)
    let empty: U32 = convert.as_u32(0)
    U256 {
        l0: unit,
        l1: empty,
        l2: empty,
        l3: empty,
        l4: empty,
        l5: empty,
        l6: empty,
        l7: empty,
    }
}

// Widen a single U32 into a U256: `x` becomes the least significant limb,
// and limbs l1..l7 are zero-filled.
pub fn from_u32(x: U32) -> U256 {
    let pad: U32 = convert.as_u32(0)
    U256 {
        l0: x,
        l1: pad,
        l2: pad,
        l3: pad,
        l4: pad,
        l5: pad,
        l6: pad,
        l7: pad,
    }
}

// ---------------------------------------------------------------------------
// Helpers: U32-level add / subtract with carry / borrow
// ---------------------------------------------------------------------------
// Single-limb add with carry: returns (sum mod 2^32, carry_out).
// The three operands are widened to Field and added there. With carry_in in
// {0, 1} the exact total is at most 2^33 - 1, well below the field modulus,
// so nothing wraps. The high half of the split is therefore the carry (0 or 1)
// and the low half is the 32-bit sum.
fn add_u32_carry(a: U32, b: U32, carry_in: U32) -> (U32, U32) {
    let total: Field = convert.as_field(a) + convert.as_field(b) + convert.as_field(carry_in)
    let (carry_out, low_word) = convert.split(total)
    (low_word, carry_out)
}

// Single-limb subtract with borrow: returns ((a - b - borrow_in) mod 2^32, borrow_out).
// Field subtraction cannot go negative, so the difference is biased by 2^32:
//   shifted = 2^32 + a - b - borrow_in, which lies in [0, 2^33 - 1].
// The high half of the split of `shifted` is 1 exactly when a >= b + borrow_in
// (no borrow) and 0 otherwise. borrow_out is its complement, 1 - hi.
fn sub_u32_borrow(a: U32, b: U32, borrow_in: U32) -> (U32, U32) {
    // 2^32 is built as 65536 * 65536 so each factor is a valid U32.
    let sixteen_bits: Field = convert.as_field(convert.as_u32(65536))
    let bias: Field = sixteen_bits * sixteen_bits
    let subtrahend: Field = convert.as_field(b) + convert.as_field(borrow_in)
    let shifted: Field = bias + convert.as_field(a) + field.neg(subtrahend)
    let (no_borrow, limb) = convert.split(shifted)
    // no_borrow is 0 or 1, so 1 - no_borrow is also 0 or 1.
    let borrow_out: Field = 1 + field.neg(convert.as_field(no_borrow))
    (limb, convert.as_u32(borrow_out))
}

// ---------------------------------------------------------------------------
// 256-bit Addition
// ---------------------------------------------------------------------------
// 256-bit addition with carry-out.
// Returns (a + b mod 2^256, carry). The carry ripples limb by limb from l0 up
// to l7. The final carry is 1 when the exact sum reaches 2^256, else 0.
pub fn add256(a: U256, b: U256) -> (U256, U32) {
    let (s0, k0) = add_u32_carry(a.l0, b.l0, convert.as_u32(0))
    let (s1, k1) = add_u32_carry(a.l1, b.l1, k0)
    let (s2, k2) = add_u32_carry(a.l2, b.l2, k1)
    let (s3, k3) = add_u32_carry(a.l3, b.l3, k2)
    let (s4, k4) = add_u32_carry(a.l4, b.l4, k3)
    let (s5, k5) = add_u32_carry(a.l5, b.l5, k4)
    let (s6, k6) = add_u32_carry(a.l6, b.l6, k5)
    let (s7, k7) = add_u32_carry(a.l7, b.l7, k6)
    (U256 { l0: s0, l1: s1, l2: s2, l3: s3, l4: s4, l5: s5, l6: s6, l7: s7 }, k7)
}

// ---------------------------------------------------------------------------
// 256-bit Subtraction
// ---------------------------------------------------------------------------
// 256-bit subtraction a - b, intended for a >= b.
// Borrows ripple limb by limb from l0 up to l7. The borrow out of the top
// limb is discarded, so when a < b the result wraps to a - b + 2^256.
pub fn sub256(a: U256, b: U256) -> U256 {
    let no_borrow: U32 = convert.as_u32(0)
    let (d0, w0) = sub_u32_borrow(a.l0, b.l0, no_borrow)
    let (d1, w1) = sub_u32_borrow(a.l1, b.l1, w0)
    let (d2, w2) = sub_u32_borrow(a.l2, b.l2, w1)
    let (d3, w3) = sub_u32_borrow(a.l3, b.l3, w2)
    let (d4, w4) = sub_u32_borrow(a.l4, b.l4, w3)
    let (d5, w5) = sub_u32_borrow(a.l5, b.l5, w4)
    let (d6, w6) = sub_u32_borrow(a.l6, b.l6, w5)
    let (d7, _) = sub_u32_borrow(a.l7, b.l7, w6)
    U256 { l0: d0, l1: d1, l2: d2, l3: d3, l4: d4, l5: d5, l6: d6, l7: d7 }
}

// ---------------------------------------------------------------------------
// 256-bit Comparison
// ---------------------------------------------------------------------------
// Unsigned 256-bit comparison: true exactly when a < b.
// a < b holds precisely when the full-width subtraction a - b must borrow out
// of the most significant limb. So the borrow chain runs across all 8 limbs,
// the limb differences are ignored, and only the final borrow is inspected.
pub fn lt256(a: U256, b: U256) -> Bool {
    let (_, k0) = sub_u32_borrow(a.l0, b.l0, convert.as_u32(0))
    let (_, k1) = sub_u32_borrow(a.l1, b.l1, k0)
    let (_, k2) = sub_u32_borrow(a.l2, b.l2, k1)
    let (_, k3) = sub_u32_borrow(a.l3, b.l3, k2)
    let (_, k4) = sub_u32_borrow(a.l4, b.l4, k3)
    let (_, k5) = sub_u32_borrow(a.l5, b.l5, k4)
    let (_, k6) = sub_u32_borrow(a.l6, b.l6, k5)
    let (_, k7) = sub_u32_borrow(a.l7, b.l7, k6)
    k7 == convert.as_u32(1)
}

// Equality of two U256 values.
// Two limbs are equal iff their XOR is zero. Hence a == b iff the limb-wise
// XOR of a and b is the all-zero U256.
pub fn eq256(a: U256, b: U256) -> Bool {
    is_zero(xor256(a, b))
}

// ---------------------------------------------------------------------------
// 256-bit Multiplication (low 256 bits only)
// ---------------------------------------------------------------------------
// Multiply a U256 by a single U32 limb, returning (U256 result, U32 carry).
// `result` is the low 256 bits of a * b and `carry` is the 32-bit limb that
// would sit above l7 (bits 256..287 of the exact product).
// This is the core building block for schoolbook multiplication.
//
// Overflow argument: each step computes a_i * b + carry_{i-1} in the Field.
// With all values <= 2^32 - 1 that is at most
//   (2^32 - 1)^2 + (2^32 - 1) = 2^64 - 2^32 = p - 1
// for the Goldilocks prime p = 2^64 - 2^32 + 1, so no step ever wraps modulo p.
// The high half of every split is therefore a valid 32-bit carry.
fn mul256_by_u32(a: U256, b: U32) -> (U256, U32) {
    let fb: Field = convert.as_field(b)
    // Limb 0: a.l0 * b (no incoming carry).
    let prod0: Field = convert.as_field(a.l0) * fb
    let (hi0, lo0) = convert.split(prod0)
    // Limbs 1..7: a.li * b plus the carry out of the previous limb.
    let prod1: Field = convert.as_field(a.l1) * fb + convert.as_field(hi0)
    let (hi1, lo1) = convert.split(prod1)
    let prod2: Field = convert.as_field(a.l2) * fb + convert.as_field(hi1)
    let (hi2, lo2) = convert.split(prod2)
    let prod3: Field = convert.as_field(a.l3) * fb + convert.as_field(hi2)
    let (hi3, lo3) = convert.split(prod3)
    let prod4: Field = convert.as_field(a.l4) * fb + convert.as_field(hi3)
    let (hi4, lo4) = convert.split(prod4)
    let prod5: Field = convert.as_field(a.l5) * fb + convert.as_field(hi4)
    let (hi5, lo5) = convert.split(prod5)
    let prod6: Field = convert.as_field(a.l6) * fb + convert.as_field(hi5)
    let (hi6, lo6) = convert.split(prod6)
    let prod7: Field = convert.as_field(a.l7) * fb + convert.as_field(hi6)
    let (hi7, lo7) = convert.split(prod7)
    let result: U256 = U256 { l0: lo0, l1: lo1, l2: lo2, l3: lo3, l4: lo4, l5: lo5, l6: lo6, l7: lo7 }
    (result, hi7)
}

// Multiply by 2^32 modulo 2^256: every limb moves up one slot, the old l7
// falls off the top, and a zero limb fills l0.
fn shl_one_limb(a: U256) -> U256 {
    let fill: U32 = convert.as_u32(0)
    U256 {
        l0: fill,
        l1: a.l0,
        l2: a.l1,
        l3: a.l2,
        l4: a.l3,
        l5: a.l4,
        l6: a.l5,
        l7: a.l6,
    }
}

// Shift left by `n` whole limbs (n * 32 bits), keeping only the low 256 bits.
// Equivalent to multiplying by 2^(32*n) modulo 2^256.
// n = 0..7 moves the limbs up accordingly. Any n >= 8 shifts every limb out
// of the low 256 bits, so the result is zero. (Previously every n >= 7 hit
// the shift-by-7 branch, which wrongly kept a.l0 for n >= 8.)
fn shl_limbs(a: U256, n: U32) -> U256 {
    let z: U32 = convert.as_u32(0)
    if n == convert.as_u32(0) {
        a
    } else if n == convert.as_u32(1) {
        U256 { l0: z, l1: a.l0, l2: a.l1, l3: a.l2, l4: a.l3, l5: a.l4, l6: a.l5, l7: a.l6 }
    } else if n == convert.as_u32(2) {
        U256 { l0: z, l1: z, l2: a.l0, l3: a.l1, l4: a.l2, l5: a.l3, l6: a.l4, l7: a.l5 }
    } else if n == convert.as_u32(3) {
        U256 { l0: z, l1: z, l2: z, l3: a.l0, l4: a.l1, l5: a.l2, l6: a.l3, l7: a.l4 }
    } else if n == convert.as_u32(4) {
        U256 { l0: z, l1: z, l2: z, l3: z, l4: a.l0, l5: a.l1, l6: a.l2, l7: a.l3 }
    } else if n == convert.as_u32(5) {
        U256 { l0: z, l1: z, l2: z, l3: z, l4: z, l5: a.l0, l6: a.l1, l7: a.l2 }
    } else if n == convert.as_u32(6) {
        U256 { l0: z, l1: z, l2: z, l3: z, l4: z, l5: z, l6: a.l0, l7: a.l1 }
    } else if n == convert.as_u32(7) {
        U256 { l0: z, l1: z, l2: z, l3: z, l4: z, l5: z, l6: z, l7: a.l0 }
    } else {
        // n >= 8: a shift of 256 bits or more leaves nothing in the low 256 bits.
        U256 { l0: z, l1: z, l2: z, l3: z, l4: z, l5: z, l6: z, l7: z }
    }
}

// Multiply two 256-bit integers, returning the low 256 bits (a * b mod 2^256).
// Let t_i be the low 256 bits of a * b.l_i. The answer is
// sum_i t_i * 2^(32*i) mod 2^256. It is evaluated with Horner's scheme,
// starting from the most significant limb of b:
//   acc = t7
//   acc = acc * 2^32 + t_i   for i = 6 down to 0   (all mod 2^256)
// Carries that would land at or above 2^256 are dropped at every step,
// which is exactly reduction mod 2^256.
pub fn mul256_low(a: U256, b: U256) -> U256 {
    let (t7, _) = mul256_by_u32(a, b.l7)
    let acc7: U256 = t7

    let (t6, _) = mul256_by_u32(a, b.l6)
    let (acc6, _) = add256(shl_one_limb(acc7), t6)

    let (t5, _) = mul256_by_u32(a, b.l5)
    let (acc5, _) = add256(shl_one_limb(acc6), t5)

    let (t4, _) = mul256_by_u32(a, b.l4)
    let (acc4, _) = add256(shl_one_limb(acc5), t4)

    let (t3, _) = mul256_by_u32(a, b.l3)
    let (acc3, _) = add256(shl_one_limb(acc4), t3)

    let (t2, _) = mul256_by_u32(a, b.l2)
    let (acc2, _) = add256(shl_one_limb(acc3), t2)

    let (t1, _) = mul256_by_u32(a, b.l1)
    let (acc1, _) = add256(shl_one_limb(acc2), t1)

    let (t0, _) = mul256_by_u32(a, b.l0)
    let (acc0, _) = add256(shl_one_limb(acc1), t0)
    acc0
}

// ---------------------------------------------------------------------------
// Modular reduction (conditional subtraction)
// ---------------------------------------------------------------------------
// One step of reduction modulo m: returns a - m when a >= m, otherwise a.
// The result is fully reduced only if a < 2*m. For larger inputs the caller
// must repeat this step or use a different reduction strategy.
pub fn mod_reduce_once(a: U256, m: U256) -> U256 {
    let already_reduced: Bool = lt256(a, m)
    if already_reduced {
        a
    } else {
        sub256(a, m)
    }
}

// Modular addition: (a + b) mod m.
// Assumes a < m and b < m, so the exact sum is below 2*m and one subtraction
// of m suffices.
//
// For moduli close to 2^256 (e.g. the secp256k1 prime) the exact sum can
// reach 2^256. add256 then reports carry = 1 and returns only sum - 2^256.
// Comparing that truncated value against m would wrongly skip the
// subtraction. So when the carry is set we always subtract m. sub256 wraps
// modulo 2^256, giving (sum - 2^256) - m + 2^256 = a + b - m, which is the
// correct result because it lies in [0, m).
pub fn add_mod(a: U256, b: U256, m: U256) -> U256 {
    let (sum, carry) = add256(a, b)
    if carry == convert.as_u32(1) {
        sub256(sum, m)
    } else {
        mod_reduce_once(sum, m)
    }
}

// Modular subtraction: (a - b) mod m, assuming a < m and b < m.
// sub256 always yields a - b modulo 2^256. When a >= b that is already the
// answer. When a < b it equals 2^256 - (b - a), and adding m (again modulo
// 2^256) gives m - (b - a), the representative in [0, m).
pub fn sub_mod(a: U256, b: U256, m: U256) -> U256 {
    let raw: U256 = sub256(a, b)
    if lt256(a, b) {
        let (lifted, _) = add256(raw, m)
        lifted
    } else {
        raw
    }
}

// Modular multiplication: intended to compute (a * b) mod m.
// WARNING: this is NOT a general modular multiplication. It keeps only the
// low 256 bits of a*b (mul256_low discards everything at or above 2^256) and
// then applies a single conditional subtraction (mod_reduce_once).
// The result is guaranteed to equal (a * b) mod m only when the exact product
// a*b is below both 2^256 and 2*m. Otherwise the returned value is generally
// wrong and is not even guaranteed to be < m.
// For full correctness with large operands (e.g. secp256k1 field elements) a
// 512-bit product followed by Barrett or Montgomery reduction is needed.
// TODO: requires full 512-bit multiplication and proper Barrett reduction
pub fn mul_mod(a: U256, b: U256, m: U256) -> U256 {
    let product: U256 = mul256_low(a, b)
    mod_reduce_once(product, m)
}

// ---------------------------------------------------------------------------
// Bitwise operations on U256
// ---------------------------------------------------------------------------
// Limb-wise bitwise AND of two U256 values. U32 `&` is native, so each of
// the 8 limbs is combined independently.
pub fn and256(a: U256, b: U256) -> U256 {
    let m0: U32 = a.l0 & b.l0
    let m1: U32 = a.l1 & b.l1
    let m2: U32 = a.l2 & b.l2
    let m3: U32 = a.l3 & b.l3
    let m4: U32 = a.l4 & b.l4
    let m5: U32 = a.l5 & b.l5
    let m6: U32 = a.l6 & b.l6
    let m7: U32 = a.l7 & b.l7
    U256 { l0: m0, l1: m1, l2: m2, l3: m3, l4: m4, l5: m5, l6: m6, l7: m7 }
}

// Limb-wise bitwise XOR of two U256 values. U32 `^` is native, so each of
// the 8 limbs is combined independently.
pub fn xor256(a: U256, b: U256) -> U256 {
    let x0: U32 = a.l0 ^ b.l0
    let x1: U32 = a.l1 ^ b.l1
    let x2: U32 = a.l2 ^ b.l2
    let x3: U32 = a.l3 ^ b.l3
    let x4: U32 = a.l4 ^ b.l4
    let x5: U32 = a.l5 ^ b.l5
    let x6: U32 = a.l6 ^ b.l6
    let x7: U32 = a.l7 ^ b.l7
    U256 { l0: x0, l1: x1, l2: x2, l3: x3, l4: x4, l5: x5, l6: x6, l7: x7 }
}

// True exactly when every limb of `a` is zero.
// Limbs are non-negative and the sum of all 8 is at most 8 * (2^32 - 1),
// far below the field modulus. So their Field sum is zero iff every limb is
// zero, which is checked without any branching.
pub fn is_zero(a: U256) -> Bool {
    let low_half: Field = convert.as_field(a.l0) + convert.as_field(a.l1) + convert.as_field(a.l2) + convert.as_field(a.l3)
    let high_half: Field = convert.as_field(a.l4) + convert.as_field(a.l5) + convert.as_field(a.l6) + convert.as_field(a.l7)
    let total: Field = low_half + high_half
    total == 0
}
