// ---
// tags: jali, trident
// crystal-type: circuit
// crystal-domain: comp
// ---

//  Deterministic sampling for lattice-based schemes.
//
//  In the proving context, the prover supplies ring element witnesses
//  via divine(). The circuit constrains the samples to lie in the
//  correct distribution. This is the standard ZK-FHE pattern:
//  the prover knows the secret, the proof verifies correctness.
//
//  No PRNG inside the circuit -- randomness is the prover's responsibility.
//  The circuit only enforces distribution membership constraints.

module jali.sample

use jali.ring.{RingElement}

/// Draw a ring element whose 1024 coefficients are arbitrary field values.
/// Every coefficient is read from the prover's non-deterministic input via
/// divine() and receives no constraint at all, so the circuit accepts any
/// polynomial the prover picks; uniformity is the prover's responsibility.
/// Intended for the random `a` polynomial in encryption.
pub fn sample_uniform() -> RingElement {
    let mut out: [Field; 1024] = [0; 1024]
    for j in 0..1024 {
        let c: Field = divine()
        out[j] = c
    }
    RingElement { coeffs: out }
}

/// Draw a ring element whose coefficients are all ternary, i.e. in {-1, 0, 1}
/// (with -1 represented as p-1 in the field).
/// Each coefficient c is supplied by the prover through divine(). The circuit
/// then asserts (c - 1) * c * (c + 1) == 0. A product over a field is zero
/// only when one of its factors is, so this admits exactly c in {1, 0, p-1}.
pub fn sample_ternary() -> RingElement {
    let mut out: [Field; 1024] = [0; 1024]
    // Additive inverse of 1, reused for every coefficient's (c - 1) factor.
    let minus_one: Field = field.neg(1)
    for j in 0..1024 {
        let c: Field = divine()
        // Roots of c^3 - c: the three factors below.
        let below: Field = c + minus_one
        let above: Field = c + 1
        let vanish: Field = below * c * above
        assert(vanish == 0)
        out[j] = c
    }
    RingElement { coeffs: out }
}

/// Sample a ring element from the centered binomial distribution CBD(eta).
///
/// Every coefficient v must lie in {-eta, ..., eta}, with negative values
/// represented as p - |v| in the field. The prover supplies each coefficient
/// via divine(). The circuit shifts it to s = v + eta and asserts the
/// vanishing product
///     s * (s - 1) * ... * (s - 2*eta) == 0,
/// which holds exactly when s is in {0, 1, ..., 2*eta}.
///
/// Only range membership is enforced here; matching the binomial weighting
/// of the distribution is the prover's responsibility.
///
/// Precondition: eta <= 15. The product loop is unrolled 32 times, and
/// 2*eta + 1 factors must fit in it. A larger eta would silently shrink the
/// accepted range (rejecting valid samples) instead of failing loudly, so it
/// is rejected up front.
pub fn sample_cbd(eta: U32) -> RingElement {
    // Enforce the documented limit before doing any U32 arithmetic on eta;
    // this also guarantees eta + eta + 1 below cannot overflow U32.
    assert(eta < 16)
    let mut coeffs: [Field; 1024] = [0; 1024]
    let eta_f: Field = eta
    // Number of vanishing-product factors (2*eta + 1). Kept as U32 so the
    // `k < bound` comparison is between two U32 values; ordering comparisons
    // are not meaningful on Field elements.
    let bound: U32 = eta + eta + 1
    for i in 0..1024 {
        let v: Field = divine()
        // Shift into non-negative range: s = v + eta in {0, ..., 2*eta}
        let s: Field = v + eta_f
        // Vanishing product constraint: s must be one of {0, 1, ..., 2*eta}
        let mut product: Field = 1
        for k in 0..32 {
            if k < bound {
                let kf: Field = k
                product = product * (s + field.neg(kf))
            }
        }
        assert(product == 0)
        coeffs[i] = v
    }
    RingElement { coeffs: coeffs }
}

Local Graph