program trinity_bench
// Benchmark program: calls Trinity phases individually.
// All parameters passed via RAM to avoid deep stack issues.
// Compiled to linked TASM with all dependencies resolved.
use vm.io.io
use vm.io.mem
use vm.core.field
use vm.core.convert
use std.math.lut
use std.trinity.inference
use std.nn.tensor
// Store all 29 parameters from public input into RAM config block.
// RAM addresses 0-28 hold the config.
// Copy the 29-word benchmark configuration from public input into RAM
// words 0..=28, one word per slot, in input order. Each phase later reads
// only the slots it needs (see the header note about avoiding deep stacks).
//
// RAM layout (address: meaning):
//   0 cts_addr          10 dense_w_addr        20 sponge_rc_addr
//   1 s_addr            11 dense_b_addr        21 expected_lut_digest
//   2 w_priv_addr       12 activated_addr      22 pbs_sample_ct
//   3 ct_out_addr       13 lut_addr            23 pbs_out_addr
//   4 tmp_addr          14 expected_class      24 ring_n
//   5 result_addr       15 rc_addr             25 pbs_acc_addr
//   6 delta             16 weights_digest      26 pbs_test_addr
//   7 lwe_n             17 key_digest          27 pbs_tmp_addr
//   8 input_dim         18 expected_digest     28 pbs_expected_m
//   9 neurons           19 domain
fn store_params() {
    for slot in 0..29 bounded 29 {
        mem.write(convert.as_field(slot), io.read())
    }
}
// Load Poseidon2 (86) and LUT sponge (112) round constants from public input
// into RAM at rc_addr and sponge_rc_addr respectively.
// Stream round constants from public input into RAM.
// First come 86 Poseidon2 constants, stored contiguously from the base address
// held in RAM[15] (rc_addr). Then come 112 LUT-sponge constants, stored
// contiguously from the base address held in RAM[20] (sponge_rc_addr).
// Must run after store_params, which fills RAM[15] and RAM[20].
fn load_round_constants() {
    let poseidon_base: Field = mem.read(15)
    for k in 0..86 bounded 86 {
        mem.write(poseidon_base + convert.as_field(k), io.read())
    }
    // Read the sponge base only after the Poseidon2 block is written,
    // matching the original read/write order.
    let sponge_base: Field = mem.read(20)
    for k in 0..112 bounded 112 {
        mem.write(sponge_base + convert.as_field(k), io.read())
    }
}
// Phase 0: build the ReLU lookup table at the address in RAM[13] (lut_addr).
// The domain size comes from RAM[19]. The half-size argument is computed in
// the field as domain * 2^-1.
// NOTE(review): domain * inv(2) equals the integer domain/2 only when domain
// is even; presumably domain is a power of two — confirm against inputs.
fn run_phase0_lut() {
    let table_base: Field = mem.read(13)
    let dom: Field = mem.read(19)
    lut.build_relu(table_base, dom * field.inv(2), dom)
}
// Phase 1: run inference.private_linear over the buffers named in the RAM
// config:
//   inputs     RAM[0] (cts_addr)
//   weights    RAM[2] (w_priv_addr)
//   output     RAM[3] (ct_out_addr)
//   scratch    RAM[4] (tmp_addr)
//   dimensions RAM[7] (lwe_n), RAM[8] (input_dim), RAM[9] (neurons)
// NOTE(review): despite the "encrypt" name, this calls private_linear. The
// inputs are presumably already-encrypted ciphertexts — confirm.
fn run_phase1_encrypt() {
    let in_cts: Field = mem.read(0)
    let priv_weights: Field = mem.read(2)
    let out_cts: Field = mem.read(3)
    let scratch: Field = mem.read(4)
    let n_lwe: Field = mem.read(7)
    let n_in: Field = mem.read(8)
    let n_out: Field = mem.read(9)
    inference.private_linear(in_cts, priv_weights, out_cts, scratch, n_lwe, n_in, n_out)
}
// Phase 1b: run inference.decrypt_outputs on the phase-1 output buffer.
// Arguments, all read from the RAM config:
//   ciphertexts RAM[3]
//   key         RAM[1] (s_addr)
//   destination RAM[5] (result_addr)
//   delta       RAM[6]
//   lwe_n       RAM[7]
//   neurons     RAM[9]
fn run_phase1b_decrypt() {
    let out_cts: Field = mem.read(3)
    let secret: Field = mem.read(1)
    let plain_out: Field = mem.read(5)
    let scale: Field = mem.read(6)
    let n_lwe: Field = mem.read(7)
    let n_out: Field = mem.read(9)
    inference.decrypt_outputs(out_cts, secret, plain_out, scale, n_lwe, n_out)
}
// Phase 2: run inference.dense_layer. Arguments, all read from the RAM config:
//   weights     RAM[10]
//   input       RAM[5] (the phase-1b result buffer)
//   bias        RAM[11]
//   activations RAM[12] (output)
//   scratch     RAM[4]
//   ReLU table  RAM[13]
//   neurons     RAM[9]
fn run_phase2_neural() {
    let weights: Field = mem.read(10)
    let inputs: Field = mem.read(5)
    let bias: Field = mem.read(11)
    let act_out: Field = mem.read(12)
    let scratch: Field = mem.read(4)
    let table: Field = mem.read(13)
    let n_out: Field = mem.read(9)
    inference.dense_layer(weights, inputs, bias, act_out, scratch, table, n_out)
}
// Phase 3: take the argmax class of the activations (RAM[12], RAM[9] entries)
// and compute two commitments over activations, weights_digest (RAM[16]),
// key_digest (RAM[17]) and that class:
//   - a LUT-sponge commitment (table RAM[13], domain RAM[19],
//     sponge constants RAM[20])
//   - a Poseidon2 commitment (round constants RAM[15])
// NOTE(review): both digests are discarded. The expected_digest (RAM[18]) and
// expected_lut_digest (RAM[21]) slots are never read in this file, so no
// result is verified here — confirm whether the inference calls check
// internally, or add assertions.
fn run_phase3_hash() {
    let act: Field = mem.read(12)
    let n_out: Field = mem.read(9)
    let w_digest: Field = mem.read(16)
    let k_digest: Field = mem.read(17)
    let predicted: Field = tensor.argmax(act, n_out)

    // LUT-sponge commitment.
    let table: Field = mem.read(13)
    let dom: Field = mem.read(19)
    let sponge_consts: Field = mem.read(20)
    let lut_digest: Field = inference.lut_hash_commit(
        act, n_out, w_digest, k_digest,
        predicted, table, dom, sponge_consts
    )

    // Poseidon2 commitment. The constants address is read after the sponge
    // call, as in the original ordering.
    let p2_consts: Field = mem.read(15)
    let digest: Field = inference.hash_commit(act, n_out, w_digest, k_digest, predicted, p2_consts)
}
// Phase 4: run inference.pbs_demo. All twelve arguments come from the RAM
// config:
//   sample ciphertext RAM[22]
//   key               RAM[1]
//   table             RAM[13]
//   output            RAM[23]
//   delta             RAM[6]
//   lwe_n             RAM[7]
//   ring_n            RAM[24]
//   domain            RAM[19]
//   accumulator       RAM[25]
//   test buffer       RAM[26]
//   scratch           RAM[27]
//   expected message  RAM[28]
fn run_phase4_pbs() {
    let sample_ct: Field = mem.read(22)
    let secret: Field = mem.read(1)
    let table: Field = mem.read(13)
    let out_buf: Field = mem.read(23)
    let scale: Field = mem.read(6)
    let n_lwe: Field = mem.read(7)
    let n_ring: Field = mem.read(24)
    let dom: Field = mem.read(19)
    let acc_buf: Field = mem.read(25)
    let test_buf: Field = mem.read(26)
    let scratch: Field = mem.read(27)
    let want_m: Field = mem.read(28)
    inference.pbs_demo(
        sample_ct, secret, table, out_buf,
        scale, n_lwe, n_ring, dom,
        acc_buf, test_buf, scratch,
        want_m
    )
}
// Phase 5: recompute the argmax class of the activations (RAM[12], RAM[9]
// entries) and pass it to inference.quantum_commit.
// NOTE(review): the Bool result is not checked, and expected_class (RAM[14])
// is never read in this file.
fn run_phase5_quantum() {
    let act: Field = mem.read(12)
    let n_out: Field = mem.read(9)
    let predicted: Field = tensor.argmax(act, n_out)
    let committed: Bool = inference.quantum_commit(predicted)
}
// Entry point. Loads the RAM config and round constants from public input,
// then runs every phase in order: LUT build, private linear layer,
// decryption, dense layer, hashing, PBS, quantum commit. Finally writes 1
// to public output.
// The order matters: later phases read buffers that earlier phases fill.
fn main() {
store_params()
load_round_constants()
run_phase0_lut()
run_phase1_encrypt()
run_phase1b_decrypt()
run_phase2_neural()
run_phase3_hash()
run_phase4_pbs()
run_phase5_quantum()
// Constant 1 is emitted unconditionally once all phases have returned.
io.write(1)
}
trident/benches/harnesses/std/trinity/inference.tri
ฯ 0.0%