pub mod checkpoint;
pub mod data;
pub mod inference;
pub mod model;
pub mod training;
use burn::backend::Wgpu;
use crate::ir::tir::TIROp;
use data::tir_graph::TirGraph;
use inference::beam::{beam_search, BeamConfig};
use inference::execute::validate_and_rank;
use model::vocab::Vocab;
use training::supervised::{graph_to_edges, graph_to_features};
/// Outcome of a compilation attempt: either a neural candidate or the baseline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompileResult {
    /// Output TASM lines: the candidate chosen by `validate_and_rank`, or a copy
    /// of the baseline TASM when falling back.
    pub tasm_lines: Vec<String>,
    /// Cost of `tasm_lines`: as reported by the ranker for neural output, or by
    /// `profile_tasm(..).cost()` for the baseline fallback.
    pub cost: u64,
    /// Candidate count reported by the ranker (presumably candidates that
    /// validated — confirm against `validate_and_rank`); 0 on fallback.
    pub valid_count: usize,
    /// Total candidate count reported by the ranker; 0 on fallback.
    pub total_count: usize,
    /// `true` when the output came from the neural model, `false` when it is
    /// the unchanged baseline.
    pub neural: bool,
}
/// Compiles `tir_ops` with the neural compiler on the default WGPU device.
///
/// Convenience wrapper around [`compile_with_device`] using the `Wgpu`
/// backend; the fallback behaviour is whatever that function implements.
pub fn compile(tir_ops: &[TIROp], baseline_tasm: &[String]) -> Result<CompileResult, String> {
    compile_with_device::<Wgpu>(
        tir_ops,
        baseline_tasm,
        &burn::backend::wgpu::WgpuDevice::default(),
    )
}
pub fn compile_with_device<B: burn::prelude::Backend>(
tir_ops: &[TIROp],
baseline_tasm: &[String],
device: &B::Device,
) -> Result<CompileResult, String> {
let vocab = Vocab::new();
let graph = TirGraph::from_tir_ops(tir_ops);
if graph.nodes.is_empty() {
return Ok(fallback_result(baseline_tasm));
}
let config = model::composite::NeuralCompilerConfig::new();
let model = config.init::<B>(device);
let model =
match checkpoint::load_checkpoint(model, checkpoint::CheckpointTag::Production, device) {
Ok(Some(loaded)) => loaded,
Ok(None) => {
let model2 = config.init::<B>(device);
match checkpoint::load_checkpoint(
model2,
checkpoint::CheckpointTag::Stage1Best,
device,
) {
Ok(Some(loaded)) => loaded,
_ => return Ok(fallback_result(baseline_tasm)),
}
}
Err(_) => return Ok(fallback_result(baseline_tasm)),
};
let node_features = graph_to_features::<B>(&graph, device);
let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, device);
let beam_config = BeamConfig::default(); let beam_result = beam_search(
&model.encoder,
&model.decoder,
node_features,
edge_src,
edge_dst,
edge_types,
&beam_config,
0, device,
);
match validate_and_rank(&beam_result.sequences, &vocab, baseline_tasm, 0) {
Some(ranked) => Ok(CompileResult {
tasm_lines: ranked.tasm_lines,
cost: ranked.cost,
valid_count: ranked.valid_count,
total_count: ranked.total_count,
neural: true,
}),
None => Ok(fallback_result(baseline_tasm)),
}
}
pub fn load_model<B: burn::prelude::Backend>(
device: &B::Device,
) -> Option<model::composite::NeuralCompilerV2<B>> {
let config = model::composite::NeuralCompilerConfig::new();
let m = config.init::<B>(device);
match checkpoint::load_checkpoint(m, checkpoint::CheckpointTag::Production, device) {
Ok(Some(loaded)) => Some(loaded),
Ok(None) => {
let m2 = config.init::<B>(device);
match checkpoint::load_checkpoint(m2, checkpoint::CheckpointTag::Stage1Best, device) {
Ok(Some(loaded)) => Some(loaded),
_ => None,
}
}
Err(_) => None,
}
}
pub fn compile_with_model<B: burn::prelude::Backend>(
tir_ops: &[TIROp],
baseline_tasm: &[String],
model: &model::composite::NeuralCompilerV2<B>,
device: &B::Device,
) -> Result<CompileResult, String> {
let vocab = Vocab::new();
let graph = TirGraph::from_tir_ops(tir_ops);
if graph.nodes.is_empty() {
return Ok(fallback_result(baseline_tasm));
}
let node_features = graph_to_features::<B>(&graph, device);
let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, device);
let beam_config = BeamConfig::default();
let beam_result = beam_search(
&model.encoder,
&model.decoder,
node_features,
edge_src,
edge_dst,
edge_types,
&beam_config,
0,
device,
);
match validate_and_rank(&beam_result.sequences, &vocab, baseline_tasm, 0) {
Some(ranked) => Ok(CompileResult {
tasm_lines: ranked.tasm_lines,
cost: ranked.cost,
valid_count: ranked.valid_count,
total_count: ranked.total_count,
neural: true,
}),
None => Ok(fallback_result(baseline_tasm)),
}
}
/// Builds the non-neural result: the baseline TASM returned as-is, costed with
/// `profile_tasm`, with both candidate counts set to zero.
fn fallback_result(baseline_tasm: &[String]) -> CompileResult {
    use crate::cost::scorer::profile_tasm;

    // `profile_tasm` takes borrowed lines, so view the owned strings as `&str`.
    let lines: Vec<&str> = baseline_tasm.iter().map(String::as_str).collect();
    let baseline_cost = profile_tasm(&lines).cost();

    let tasm_lines = baseline_tasm.to_vec();
    CompileResult {
        neural: false,
        valid_count: 0,
        total_count: 0,
        cost: baseline_cost,
        tasm_lines,
    }
}