use criterion::{black_box, criterion_group, criterion_main, Criterion};
use trident::ir::tir::TIROp;
use trident::neural::data::tir_graph::TirGraph;
use trident::neural::inference::beam::{beam_search, BeamConfig};
use trident::neural::inference::execute::validate_and_rank;
use trident::neural::model::composite::NeuralCompilerConfig;
use trident::neural::model::vocab::Vocab;
use trident::neural::training::supervised::{graph_to_edges, graph_to_features};
use burn::backend::NdArray;
type B = NdArray;
fn synthetic_tir(n: usize) -> Vec<TIROp> {
let mut ops = Vec::with_capacity(n);
for i in 0..n {
match i % 5 {
0 => ops.push(TIROp::Push(i as u64)),
1 => ops.push(TIROp::Add),
2 => ops.push(TIROp::Mul),
3 => ops.push(TIROp::Dup(0)),
4 => ops.push(TIROp::Swap(1)),
_ => unreachable!(),
}
}
ops
}
fn bench_graph_build(c: &mut Criterion) {
let ops_50 = synthetic_tir(50);
let ops_100 = synthetic_tir(100);
let mut group = c.benchmark_group("graph_build");
group.bench_function("50_ops", |b| {
b.iter(|| TirGraph::from_tir_ops(black_box(&ops_50)))
});
group.bench_function("100_ops", |b| {
b.iter(|| TirGraph::from_tir_ops(black_box(&ops_100)))
});
group.finish();
}
fn bench_gnn_encoder(c: &mut Criterion) {
let device = Default::default();
let config = NeuralCompilerConfig {
d_model: 256,
d_edge: 32,
gnn_layers: 4,
decoder_layers: 6,
n_heads: 8,
d_ff: 1024,
max_seq: 256,
dropout: 0.0,
};
let model = config.init::<B>(&device);
let ops = synthetic_tir(100);
let graph = TirGraph::from_tir_ops(&ops);
let node_features = graph_to_features::<B>(&graph, &device);
let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, &device);
c.bench_function("gnn_encoder_100_nodes", |b| {
b.iter(|| {
model.encoder.forward(
black_box(node_features.clone()),
black_box(edge_src.clone()),
black_box(edge_dst.clone()),
black_box(edge_types.clone()),
)
})
});
}
fn bench_beam_search(c: &mut Criterion) {
let device = Default::default();
let config = NeuralCompilerConfig {
d_model: 64,
d_edge: 16,
gnn_layers: 2,
decoder_layers: 2,
n_heads: 4,
d_ff: 128,
max_seq: 64,
dropout: 0.0,
};
let model = config.init::<B>(&device);
let ops = synthetic_tir(20);
let graph = TirGraph::from_tir_ops(&ops);
let node_features = graph_to_features::<B>(&graph, &device);
let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, &device);
let beam_config = BeamConfig {
k: 8, max_steps: 16,
};
c.bench_function("beam_search_k8_steps16", |b| {
b.iter(|| {
beam_search(
black_box(&model.encoder),
black_box(&model.decoder),
black_box(node_features.clone()),
black_box(edge_src.clone()),
black_box(edge_dst.clone()),
black_box(edge_types.clone()),
black_box(&beam_config),
0,
&device,
)
})
});
}
fn bench_validation(c: &mut Criterion) {
let vocab = Vocab::new();
let baseline: Vec<String> = vec![
"push 1".into(),
"push 2".into(),
"add".into(),
"push 3".into(),
"mul".into(),
];
let candidates: Vec<Vec<u32>> = (0..32)
.map(|i| {
vec![
(i % 10 + 1) as u32, (i % 5 + 1) as u32,
]
})
.collect();
c.bench_function("validate_32_candidates", |b| {
b.iter(|| {
validate_and_rank(
black_box(&candidates),
black_box(&vocab),
black_box(&baseline),
42,
)
})
});
}
fn bench_end_to_end(c: &mut Criterion) {
let device = Default::default();
let config = NeuralCompilerConfig {
d_model: 64,
d_edge: 16,
gnn_layers: 2,
decoder_layers: 2,
n_heads: 4,
d_ff: 128,
max_seq: 64,
dropout: 0.0,
};
let model = config.init::<B>(&device);
let vocab = Vocab::new();
let ops = synthetic_tir(20);
let baseline: Vec<String> = vec!["push 1".into(), "push 2".into(), "add".into()];
let beam_config = BeamConfig {
k: 8,
max_steps: 16,
};
c.bench_function("end_to_end_20ops_k8", |b| {
b.iter(|| {
let graph = TirGraph::from_tir_ops(black_box(&ops));
let node_features = graph_to_features::<B>(&graph, &device);
let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, &device);
let result = beam_search(
&model.encoder,
&model.decoder,
node_features,
edge_src,
edge_dst,
edge_types,
&beam_config,
0,
&device,
);
let _ranked = validate_and_rank(&result.sequences, &vocab, &baseline, 42);
})
});
}
criterion_group!(
benches,
bench_graph_build,
bench_gnn_encoder,
bench_beam_search,
bench_validation,
bench_end_to_end,
);
criterion_main!(benches);