use crate::ir::tir::TIROp;
use crate::neural::data::tir_graph::TirGraph;
use crate::neural::model::vocab::Vocab;
/// One supervised example: a TIR graph paired with its encoded TASM target.
///
/// Instances are assembled by `extract_pairs` from
/// `(tir_ops, tasm_lines, source_id, baseline_cost)` block tuples.
pub struct TrainingPair {
/// Graph built from the block's TIR ops via `TirGraph::from_tir_ops`.
pub graph: TirGraph,
/// Token ids produced by `Vocab::encode_sequence` from the block's TASM lines.
pub target_tokens: Vec<u32>,
/// Identifier of the originating block, copied verbatim from the input tuple.
pub source_id: String,
/// Cost recorded for the block, copied unchanged from the input tuple.
/// NOTE(review): the unit of this cost is not visible here — confirm with the producer.
pub baseline_cost: u64,
}
/// Converts raw compiler blocks into [`TrainingPair`]s, dropping unusable ones.
///
/// Each element of `blocks` is a `(tir_ops, tasm_lines, source_id, baseline_cost)`
/// tuple. A block is skipped when:
///
/// * its TIR op list or its TASM line list is empty,
/// * the graph built from its TIR ops has no nodes, or
/// * its encoded token sequence has at most one token.
///
/// Surviving blocks keep their input order. `source_id` is cloned and
/// `baseline_cost` is copied into the resulting pair.
pub fn extract_pairs(
    blocks: &[(Vec<TIROp>, Vec<String>, String, u64)],
    vocab: &Vocab,
) -> Vec<TrainingPair> {
    blocks
        .iter()
        .filter_map(|(tir_ops, tasm_lines, source_id, baseline_cost)| {
            // Nothing to learn from a block missing either side of the pair.
            if tir_ops.is_empty() || tasm_lines.is_empty() {
                return None;
            }

            let graph = TirGraph::from_tir_ops(tir_ops);
            if graph.nodes.is_empty() {
                return None;
            }

            // A sequence of zero or one token is rejected
            // (presumably a lone token is only a terminator — confirm against `Vocab`).
            let target_tokens = vocab.encode_sequence(tasm_lines);
            if target_tokens.len() < 2 {
                return None;
            }

            Some(TrainingPair {
                graph,
                target_tokens,
                source_id: source_id.clone(),
                baseline_cost: *baseline_cost,
            })
        })
        .collect()
}
/// Splits a flat TIR op stream into named per-function chunks.
///
/// Returns `(name, ops)` pairs in stream order:
///
/// * Ops seen outside any function and followed by a `FnStart` are emitted
///   under the name `"__entry"`.
/// * Each `FnStart(name)` .. `FnEnd` run is emitted under `name`, with both
///   marker ops included in the chunk.
/// * If a `FnStart` arrives while a previous function is still open (its
///   `FnEnd` is missing), the open function is emitted under its own name
///   before the new one begins, so its ops are never silently discarded.
/// * Ops left over at the end of the stream are emitted under the open
///   function's name when there is one and it is non-empty, otherwise under
///   `"__trailing"`.
/// * A `FnEnd` with no matching `FnStart` closes the accumulated ops under an
///   empty name.
pub fn split_tir_by_function(ops: &[TIROp]) -> Vec<(String, Vec<TIROp>)> {
    let mut functions = Vec::new();
    let mut current_name = String::new();
    let mut current_ops = Vec::new();
    let mut in_function = false;

    for op in ops {
        match op {
            TIROp::FnStart(name) => {
                // Flush whatever has accumulated before starting a new chunk.
                // An open function has at least its `FnStart` op, so an
                // unterminated function is always preserved here.
                if !current_ops.is_empty() {
                    let label = if in_function {
                        std::mem::take(&mut current_name)
                    } else {
                        "__entry".to_string()
                    };
                    functions.push((label, std::mem::take(&mut current_ops)));
                }
                current_name = name.clone();
                current_ops = vec![op.clone()];
                in_function = true;
            }
            TIROp::FnEnd => {
                current_ops.push(op.clone());
                functions.push((
                    std::mem::take(&mut current_name),
                    std::mem::take(&mut current_ops),
                ));
                in_function = false;
            }
            _ => {
                current_ops.push(op.clone());
            }
        }
    }

    // Keep anything after the last `FnEnd`, or an unterminated final function.
    if !current_ops.is_empty() {
        let name = if in_function && !current_name.is_empty() {
            current_name
        } else {
            "__trailing".to_string()
        };
        functions.push((name, current_ops));
    }

    functions
}
/// Splits `pairs` into `(train, holdout)`, taking the holdout from the tail.
///
/// The last `holdout_count` pairs become the holdout set and the rest stay in
/// the training set, with order preserved in both. When there are not more
/// than `holdout_count` pairs in total, everything is kept for training and
/// the holdout set is empty.
pub fn train_holdout_split(
    pairs: Vec<TrainingPair>,
    holdout_count: usize,
) -> (Vec<TrainingPair>, Vec<TrainingPair>) {
    let total = pairs.len();
    let mut train = pairs;

    // Only carve out a holdout when at least one training pair would remain.
    let holdout = if total > holdout_count {
        train.split_off(total - holdout_count)
    } else {
        Vec::new()
    };

    (train, holdout)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::tir::TIROp;

    /// Shape of one input block accepted by `extract_pairs`.
    type Block = (Vec<TIROp>, Vec<String>, String, u64);

    /// An entry prelude followed by two well-formed functions yields three chunks.
    #[test]
    fn split_tir_by_function_basic() {
        let stream = vec![
            TIROp::Entry("main".into()),
            TIROp::Call("foo".into()),
            TIROp::Halt,
            TIROp::FnStart("foo".into()),
            TIROp::Push(1),
            TIROp::Push(2),
            TIROp::Add,
            TIROp::Return,
            TIROp::FnEnd,
            TIROp::FnStart("bar".into()),
            TIROp::Push(3),
            TIROp::Return,
            TIROp::FnEnd,
        ];

        let chunks = split_tir_by_function(&stream);

        let names: Vec<&str> = chunks.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names, ["__entry", "foo", "bar"]);

        // Chunk sizes include the FnStart/FnEnd markers.
        assert_eq!(chunks[1].1.len(), 6);
        assert_eq!(chunks[2].1.len(), 4);
    }

    /// A single valid block produces one pair carrying its id and cost.
    #[test]
    fn extract_pairs_basic() {
        let vocab = Vocab::new();
        let tir = vec![TIROp::Push(1), TIROp::Push(2), TIROp::Add];
        let tasm = vec![
            String::from("push 1"),
            String::from("push 2"),
            String::from("add"),
        ];
        let blocks: Vec<Block> = vec![(tir, tasm, String::from("test:0..3"), 3)];

        let pairs = extract_pairs(&blocks, &vocab);

        assert_eq!(pairs.len(), 1);
        let only = &pairs[0];
        assert_eq!(only.source_id, "test:0..3");
        assert_eq!(only.baseline_cost, 3);
        // The encoded sequence is expected to end with token id 0
        // (presumably the end-of-sequence marker — confirm against `Vocab`).
        assert_eq!(only.target_tokens.last().copied(), Some(0));
    }

    /// Blocks with no TIR ops or no TASM lines are filtered out.
    #[test]
    fn extract_pairs_skips_empty() {
        let vocab = Vocab::new();
        let no_tir: Block = (vec![], vec!["push 1".into()], "empty_tir".into(), 0);
        let no_tasm: Block = (vec![TIROp::Push(1)], vec![], "empty_tasm".into(), 0);
        let blocks = vec![no_tir, no_tasm];

        let pairs = extract_pairs(&blocks, &vocab);

        assert!(pairs.is_empty());
    }

    /// Ten pairs with a holdout of three split into seven and three.
    #[test]
    fn train_holdout_split_works() {
        let vocab = Vocab::new();
        let mut blocks: Vec<Block> = Vec::new();
        for i in 0..10 {
            blocks.push((
                vec![TIROp::Push(i as u64)],
                vec![format!("push {}", i)],
                format!("block:{}", i),
                1u64,
            ));
        }

        let pairs = extract_pairs(&blocks, &vocab);
        let (train, holdout) = train_holdout_split(pairs, 3);

        assert_eq!(train.len(), 7);
        assert_eq!(holdout.len(), 3);
    }
}