program trinity_bench

// Benchmark program: calls Trinity phases individually.
// All parameters passed via RAM to avoid deep stack issues.
// Compiled to linked TASM with all dependencies resolved.
use vm.io.io

use vm.io.mem

use vm.core.field

use vm.core.convert

use std.math.lut

use std.trinity.inference

use std.nn.tensor

// Copy the 29 configuration words from public input into the RAM config
// block: the k-th word read from input is written to RAM address k (0..28).
//
// Slot layout (address: meaning):
//    0 cts_addr            1 s_addr              2 w_priv_addr
//    3 ct_out_addr         4 tmp_addr            5 result_addr
//    6 delta               7 lwe_n               8 input_dim
//    9 neurons            10 dense_w_addr       11 dense_b_addr
//   12 activated_addr     13 lut_addr           14 expected_class
//   15 rc_addr            16 weights_digest     17 key_digest
//   18 expected_digest    19 domain             20 sponge_rc_addr
//   21 expected_lut_digest                      22 pbs_sample_ct
//   23 pbs_out_addr       24 ring_n             25 pbs_acc_addr
//   26 pbs_test_addr      27 pbs_tmp_addr       28 pbs_expected_m
fn store_params() {
    for slot in 0..29 bounded 29 {
        mem.write(convert.as_field(slot), io.read())
    }
}

// Stream two round-constant tables from public input into RAM, in this order:
//   - 86 Poseidon2 constants, stored contiguously from the address in slot 15
//     (rc_addr);
//   - 112 LUT-sponge constants, stored contiguously from the address in
//     slot 20 (sponge_rc_addr).
fn load_round_constants() {
    let p2_base: Field = mem.read(15)
    for k in 0..86 bounded 86 {
        mem.write(p2_base + convert.as_field(k), io.read())
    }
    let sponge_base: Field = mem.read(20)
    for k in 0..112 bounded 112 {
        mem.write(sponge_base + convert.as_field(k), io.read())
    }
}

// Phase 0: build the ReLU lookup table at the address in slot 13 (lut_addr).
// Calls lut.build_relu with the domain from slot 19 and domain * inv(2) as
// the half-size argument.
// NOTE(review): in a prime field, domain * inv(2) equals the integer domain/2
// only when domain is even; an odd domain gives a large field element.
// Confirm the supplied domain is always even.
fn run_phase0_lut() {
    let dom: Field = mem.read(19)
    let table_addr: Field = mem.read(13)
    let half: Field = dom * field.inv(2)
    lut.build_relu(table_addr, half, dom)
}

// Phase 1: call inference.private_linear using config slots
// 0 (cts_addr), 2 (w_priv_addr), 3 (ct_out_addr), 4 (tmp_addr),
// 7 (lwe_n), 8 (input_dim) and 9 (neurons).
// Its output location is ct_out_addr, which phase 1b reads back.
fn run_phase1_encrypt() {
    let in_cts: Field = mem.read(0)
    let priv_w: Field = mem.read(2)
    let out_cts: Field = mem.read(3)
    let scratch: Field = mem.read(4)
    let n_lwe: Field = mem.read(7)
    let n_in: Field = mem.read(8)
    let n_out: Field = mem.read(9)
    inference.private_linear(
        in_cts, priv_w, out_cts, scratch,
        n_lwe, n_in, n_out
    )
}

// Phase 1b: call inference.decrypt_outputs on the ciphertexts at ct_out_addr
// (slot 3). It uses the key at s_addr (slot 1), delta (slot 6), lwe_n (slot 7)
// and neurons (slot 9), targeting result_addr (slot 5).
fn run_phase1b_decrypt() {
    let key_at: Field = mem.read(1)
    let cts_at: Field = mem.read(3)
    let results_at: Field = mem.read(5)
    let scale: Field = mem.read(6)
    let n_lwe: Field = mem.read(7)
    let n_out: Field = mem.read(9)
    inference.decrypt_outputs(cts_at, key_at, results_at, scale, n_lwe, n_out)
}

// Phase 2: call inference.dense_layer on the values at result_addr (slot 5).
// It uses weights/bias at slots 10/11, scratch at slot 4, the LUT at slot 13
// and neurons (slot 9). Activations go to activated_addr (slot 12).
fn run_phase2_neural() {
    let scratch: Field = mem.read(4)
    let inputs_at: Field = mem.read(5)
    let n_out: Field = mem.read(9)
    let weights_at: Field = mem.read(10)
    let bias_at: Field = mem.read(11)
    let acts_at: Field = mem.read(12)
    let table_at: Field = mem.read(13)
    inference.dense_layer(weights_at, inputs_at, bias_at, acts_at, scratch, table_at, n_out)
}

// Phase 3: pick the predicted class with tensor.argmax over the activations at
// slot 12, then run two commitments over the same data: a LUT-sponge commit
// and a Poseidon2 commit.
// NOTE(review): both returned digests (lut_digest, digest) are discarded.
// The config slots 14 (expected_class), 18 (expected_digest) and
// 21 (expected_lut_digest) are loaded by store_params but never read in this
// file, so nothing here compares the results against them. Confirm whether
// that is intentional for the benchmark or a missing check.
fn run_phase3_hash() {
    let activated_addr: Field = mem.read(12)
    let neurons: Field = mem.read(9)
    let weights_digest: Field = mem.read(16)
    let key_digest: Field = mem.read(17)
    let class: Field = tensor.argmax(activated_addr, neurons)
    // LUT sponge hash (round constants at sponge_rc_addr, slot 20)
    let lut_addr: Field = mem.read(13)
    let domain: Field = mem.read(19)
    let sponge_rc_addr: Field = mem.read(20)
    let lut_digest: Field = inference.lut_hash_commit(
        activated_addr, neurons, weights_digest, key_digest,
        class, lut_addr, domain, sponge_rc_addr
    )
    // Poseidon2 hash (round constants at rc_addr, slot 15)
    let rc_addr: Field = mem.read(15)
    let digest: Field = inference.hash_commit(activated_addr, neurons, weights_digest, key_digest, class, rc_addr)
}

// Phase 4: run inference.pbs_demo on the sample ciphertext from slot 22.
// The twelve arguments come from config slots 1, 6, 7, 13, 19 and 22-28;
// pbs_expected_m (slot 28) is passed through to pbs_demo.
fn run_phase4_pbs() {
    let key_at: Field = mem.read(1)
    let scale: Field = mem.read(6)
    let n_lwe: Field = mem.read(7)
    let table_at: Field = mem.read(13)
    let dom: Field = mem.read(19)
    let sample_ct: Field = mem.read(22)
    let out_at: Field = mem.read(23)
    let n_ring: Field = mem.read(24)
    let acc_at: Field = mem.read(25)
    let test_at: Field = mem.read(26)
    let scratch_at: Field = mem.read(27)
    let want_m: Field = mem.read(28)
    inference.pbs_demo(
        sample_ct, key_at, table_at, out_at,
        scale, n_lwe, n_ring, dom,
        acc_at, test_at, scratch_at,
        want_m
    )
}

// Phase 5: recompute the argmax class over the activations at slot 12 and
// pass it to inference.quantum_commit.
// NOTE(review): the Bool result is bound but never used, and the argmax
// repeats the work of phase 3 (each phase is standalone for benchmarking).
fn run_phase5_quantum() {
    let acts_at: Field = mem.read(12)
    let n_out: Field = mem.read(9)
    let winner: Field = tensor.argmax(acts_at, n_out)
    let committed: Bool = inference.quantum_commit(winner)
}

// Entry point: load configuration and constants from public input, then run
// every benchmark phase in sequence.
// The order matters. Later phases read RAM regions that earlier phases
// populate:
//   - the LUT from phase 0;
//   - ct_out_addr from phase 1;
//   - result_addr from phase 1b;
//   - activated_addr from phase 2.
fn main() {
    store_params()          // config block, RAM 0..28
    load_round_constants()  // needs rc_addr / sponge_rc_addr from the config
    run_phase0_lut()
    run_phase1_encrypt()
    run_phase1b_decrypt()
    run_phase2_neural()
    run_phase3_hash()
    run_phase4_pbs()
    run_phase5_quantum()
    // Writes the constant 1 to public output once all phases have run
    // (presumably a completion marker; confirm with the benchmark harness).
    io.write(1)
}

Dimensions

trident/std/trinity/inference.tri

Local Graph