use std::path::Path;
use std::process;
use std::cell::RefCell;
use clap::{Args, Subcommand};
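// Thread-local scratch slot: eval_files() stashes a short beam-search preview here
// and display_epoch_table() drains it on the next redraw.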
thread_local! {
static BEAM_DIAGNOSTIC: RefCell<Option<String>> = RefCell::new(None);
}
#[derive(Args)]
pub struct TrainArgs {
#[command(subcommand)]
pub action: Option<TrainAction>,
/// Epochs over the full corpus (default: 10)
#[arg(short, long, default_value = "10")]
pub epochs: u64,
/// Disable GPU (use CPU for training)
#[arg(long)]
pub cpu: bool,
    /// Force a specific stage (1=supervised, 2=gflownet, 3=online)
#[arg(long)]
pub stage: Option<u32>,
}
#[derive(Subcommand)]
pub enum TrainAction {
/// Delete all neural weights and generated .neural.tasm files
Reset,
}
/// Pre-compiled file data: TIR + baselines, computed once.
struct CompiledFile {
path: String,
tir_ops: Vec<trident::tir::TIROp>,
tasm_lines: Vec<String>,
baseline_cost: u64,
}
/// Per-file eval result after beam search.
struct FileEval {
total_blocks: usize,
decoded: usize,
checked: usize,
proven: usize,
wins: usize,
checked_cost: u64,
checked_baseline: u64,
}
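/// Entry point for `trident train`: compiles the corpus, extracts and augments
/// training pairs, then runs the selected training stage on CPU or GPU.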
pub fn cmd_train(args: TrainArgs) {
if let Some(TrainAction::Reset) = args.action {
cmd_train_reset();
return;
}
use trident::neural::checkpoint::{self, TrainingStage};
use trident::neural::data::pairs::extract_pairs;
use trident::neural::model::composite::NeuralCompilerConfig;
use trident::neural::model::vocab::Vocab;
let corpus = discover_corpus();
if corpus.is_empty() {
eprintln!("error: no .tri files found in vm/, std/, os/");
process::exit(1);
}
eprintln!("trident train");
eprintln!(" compiling corpus...");
let _guard = trident::diagnostic::suppress_warnings();
let compiled = compile_corpus(&corpus);
drop(_guard);
let total_baseline: u64 = compiled.iter().map(|c| c.baseline_cost).sum();
let config = NeuralCompilerConfig::new();
let vocab = Vocab::new();
let blocks: Vec<(Vec<trident::tir::TIROp>, Vec<String>, String, u64)> = compiled
.iter()
.map(|cf| {
(
cf.tir_ops.clone(),
cf.tasm_lines.clone(),
cf.path.clone(),
cf.baseline_cost,
)
})
.collect();
let raw_pairs = extract_pairs(&blocks, &vocab);
let total_blocks: usize = compiled.len();
if raw_pairs.is_empty() {
eprintln!("error: no training pairs extracted from corpus");
process::exit(1);
}
// Holdout split: reserve ~5% (min 5) of raw pairs for validation.
// Split before augmentation so holdout is unseen during training.
let raw_count = raw_pairs.len();
let holdout_count = (raw_count / 20).max(5).min(raw_count);
let (train_raw, holdout) =
trident::neural::data::pairs::train_holdout_split(raw_pairs, holdout_count);
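    // Worked example (illustrative numbers): raw_count = 200 gives
    // holdout_count = min(max(200 / 20, 5), 200) = 10, leaving ~190 pairs for training.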
// Data augmentation: expand training set via TASM random walk,
// equivalent substitutions, and dead code insertion.
eprintln!(" augmenting {} train pairs...", train_raw.len());
let augment_config = trident::neural::training::augment::AugmentConfig {
tir_reorder_variants: 3,
tasm_walk_variants: 5,
max_swap_attempts: 10,
seed: 0xDEAD_BEEF_A097,
};
let pairs =
trident::neural::training::augment::augment_pairs(&train_raw, &vocab, &augment_config);
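    // Rough expectation, assuming each variant kind contributes at most its configured
    // count per pair: one raw pair may expand into up to ~1 + 3 + 5 = 9 pairs; the exact
    // multiplier depends on augment_pairs and max_swap_attempts.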
eprintln!(" holdout: {} pairs", holdout.len());
// Stage selection: explicit --stage flag, or default to Stage 1.
    // Auto-detection caused problems (a stale checkpoint from a broken run would
    // skip Stage 1 forever), so the user controls stage transitions explicitly.
let stage = match args.stage {
Some(1) => TrainingStage::Stage1Supervised,
Some(2) => TrainingStage::Stage2GFlowNet,
Some(3) => TrainingStage::Stage3Online,
_ => TrainingStage::Stage1Supervised,
};
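    // e.g. `trident train` or `trident train --stage 1` runs supervised pre-training,
    // while `trident train --stage 2` switches to GFlowNet fine-tuning.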
    // Show target sequence length stats. Guard first: the train split can end up
    // empty when the holdout takes every raw pair, which would panic on indexing below.
    if pairs.is_empty() {
        eprintln!("error: no training pairs left after holdout split");
        process::exit(1);
    }
    let mut lens: Vec<usize> = pairs.iter().map(|p| p.target_tokens.len()).collect();
lens.sort();
let median = lens[lens.len() / 2];
let min_len = lens[0];
let max_len = lens[lens.len() - 1];
let device_tag = if args.cpu { "CPU" } else { "GPU" };
eprintln!(
" corpus {} files, {} raw pairs, {} augmented, {} blocks",
corpus.len(),
raw_count,
pairs.len(),
total_blocks,
);
eprintln!(
" targets len min={} median={} max={} (tokens incl EOS)",
min_len, median, max_len,
);
eprintln!(" baseline {} total cost", total_baseline);
eprintln!(
" model ~{}M params | v2 GNN+Transformer | {}",
config.param_estimate() / 1_000_000,
device_tag,
);
eprintln!(
" schedule {} epochs, {} (use --stage 2 for GFlowNet)",
args.epochs, stage,
);
// Show existing checkpoints
let ckpts = checkpoint::available_checkpoints();
if !ckpts.is_empty() {
for (tag, path) in &ckpts {
eprintln!(" checkpoint {:?} -> {}", tag, path.display());
}
}
eprintln!();
if args.cpu {
use burn::backend::Autodiff;
use burn::backend::NdArray;
type B = Autodiff<NdArray>;
let device = Default::default();
let model = config.init::<B>(&device);
run_training_loop::<B>(
model,
&pairs,
&holdout,
&compiled,
&vocab,
args.epochs,
&device,
stage,
);
} else {
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use burn::backend::Autodiff;
type B = Autodiff<Wgpu>;
let device = WgpuDevice::default();
let model = config.init::<B>(&device);
run_training_loop::<B>(
model,
&pairs,
&holdout,
&compiled,
&vocab,
args.epochs,
&device,
stage,
);
}
}
/// Main training loop, generic over backend.
fn run_training_loop<B: burn::tensor::backend::AutodiffBackend>(
model: trident::neural::model::composite::NeuralCompilerV2<B>,
pairs: &[trident::neural::data::pairs::TrainingPair],
holdout: &[trident::neural::data::pairs::TrainingPair],
compiled: &[CompiledFile],
vocab: &trident::neural::model::vocab::Vocab,
epochs: u64,
device: &B::Device,
stage: trident::neural::checkpoint::TrainingStage,
) where
<B as burn::tensor::backend::Backend>::FloatElem: From<f32>,
{
use trident::neural::checkpoint::{self, CheckpointTag, TrainingStage};
// Try loading existing checkpoint
let load_tag = match stage {
TrainingStage::Stage1Supervised => CheckpointTag::Stage1Best,
TrainingStage::Stage2GFlowNet => CheckpointTag::Stage1Best,
TrainingStage::Stage3Online => CheckpointTag::Stage2Latest,
};
let model = match checkpoint::load_checkpoint(model, load_tag, device) {
Ok(Some(loaded)) => {
eprintln!(" loaded checkpoint {:?}", load_tag);
loaded
}
Ok(None) => {
eprintln!(" no checkpoint found, training from scratch");
trident::neural::model::composite::NeuralCompilerConfig::new().init::<B>(device)
}
Err(e) => {
eprintln!(" warning: checkpoint load failed: {}", e);
trident::neural::model::composite::NeuralCompilerConfig::new().init::<B>(device)
}
};
match stage {
TrainingStage::Stage1Supervised => {
run_stage1(model, pairs, holdout, compiled, vocab, epochs, device);
}
TrainingStage::Stage2GFlowNet => {
run_stage2(model, compiled, vocab, epochs, device);
}
TrainingStage::Stage3Online => {
run_stage2(model, compiled, vocab, epochs, device);
}
}
}
/// Stage 1: Supervised pre-training with cosine LR decay.
fn run_stage1<B: burn::tensor::backend::AutodiffBackend>(
model: trident::neural::model::composite::NeuralCompilerV2<B>,
pairs: &[trident::neural::data::pairs::TrainingPair],
holdout: &[trident::neural::data::pairs::TrainingPair],
compiled: &[CompiledFile],
vocab: &trident::neural::model::vocab::Vocab,
epochs: u64,
device: &B::Device,
) where
<B as burn::tensor::backend::Backend>::FloatElem: From<f32>,
{
use burn::module::AutodiffModule;
use trident::neural::checkpoint::{self, CheckpointTag};
use trident::neural::inference::beam::BeamConfig;
use trident::neural::training::supervised;
use trident::neural::training::supervised::cosine_lr;
let sup_config = supervised::SupervisedConfig::default();
let mut optimizer = supervised::create_optimizer::<B>(&sup_config);
let start = std::time::Instant::now();
let mut model = model;
let mut best_loss = f32::INFINITY;
let mut stale_epochs = 0usize;
let mut prev_table_lines: usize = 0;
let mut phase_a_reached = false;
    // Small beam for eval during training; full K=32 is too slow (197s/epoch).
// K=4, max_steps=64 gives fast feedback; production inference uses K=32.
// min_tokens=1: many target functions are 1-2 instructions; forcing 5+ tokens
// prevents the model from outputting correct short sequences.
let beam_config = BeamConfig {
k: 4,
max_steps: 64,
min_tokens: 1,
..Default::default()
};
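    // For contrast, a production-sized config (illustrative; only K=32 is stated above)
    // would be roughly `BeamConfig { k: 32, ..Default::default() }`.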
for epoch in 0..epochs {
let epoch_start = std::time::Instant::now();
// Cosine LR decay
let lr = cosine_lr(&sup_config, epoch as usize, epochs as usize);
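        // Standard cosine decay shape (assuming cosine_lr implements it):
        // lr(e) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * e / E)).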
// Train one epoch
let (updated, result) = supervised::train_epoch(model, pairs, &mut optimizer, lr, device);
model = updated;
let improved = result.avg_loss < best_loss;
if improved {
best_loss = result.avg_loss;
stale_epochs = 0;
// Save best checkpoint
if let Err(e) = checkpoint::save_checkpoint(&model, CheckpointTag::Stage1Best, device) {
eprintln!(" warning: checkpoint save failed: {}", e);
}
} else {
stale_epochs += 1;
}
// Evaluate via beam search on compiled files
let inner = model.valid();
let inner_device = device.clone();
let file_evals = eval_files(&inner, compiled, vocab, &beam_config, &inner_device);
// Holdout validity: beam search on unseen pairs, check via stack verifier
let holdout_valid = if !holdout.is_empty() {
eval_holdout_validity(&inner, holdout, vocab, &beam_config, &inner_device)
} else {
(0, 0)
};
let epoch_elapsed = epoch_start.elapsed();
// Display table
prev_table_lines = display_epoch_table(
epoch,
epochs,
&result,
improved,
lr,
&file_evals,
compiled,
epoch_elapsed,
prev_table_lines,
);
// Show holdout validity
if holdout_valid.1 > 0 {
let validity_pct = holdout_valid.0 as f64 / holdout_valid.1 as f64 * 100.0;
eprintln!(
" holdout validity: {}/{} ({:.0}%)\x1B[K",
holdout_valid.0, holdout_valid.1, validity_pct,
);
prev_table_lines += 1;
// Phase A gate: validity >= 80% on holdout
if validity_pct >= 80.0 && !phase_a_reached {
phase_a_reached = true;
eprintln!(
" PHASE A REACHED: validity={:.0}% on holdout (>= 80%)\x1B[K",
validity_pct,
);
prev_table_lines += 1;
// Save production checkpoint
if let Err(e) =
checkpoint::save_checkpoint(&model, CheckpointTag::Production, device)
{
eprintln!(" warning: production checkpoint save failed: {}", e);
} else {
eprintln!(" saved production checkpoint\x1B[K");
prev_table_lines += 1;
}
}
}
// Early stopping
if stale_epochs >= sup_config.patience {
eprintln!(
" early stopping: no improvement for {} epochs",
sup_config.patience,
);
break;
}
}
let elapsed = start.elapsed();
eprintln!();
eprintln!(
" Stage 1 done in {:.1}s, best loss: {:.4}",
elapsed.as_secs_f64(),
best_loss
);
}
/// Evaluate holdout validity: beam search on each holdout pair, check if
/// the best candidate is equivalent to the target TASM via stack verifier.
/// Returns (valid_count, total_count).
fn eval_holdout_validity<B: burn::prelude::Backend>(
model: &trident::neural::model::composite::NeuralCompilerV2<B>,
holdout: &[trident::neural::data::pairs::TrainingPair],
vocab: &trident::neural::model::vocab::Vocab,
beam_config: &trident::neural::inference::beam::BeamConfig,
device: &B::Device,
) -> (usize, usize) {
use trident::neural::inference::beam::beam_search;
use trident::neural::inference::execute::validate_and_rank;
use trident::neural::training::supervised::{graph_to_edges, graph_to_features};
let mut valid = 0usize;
let mut total = 0usize;
for pair in holdout {
if pair.graph.nodes.is_empty() {
continue;
}
// Decode target tokens to TASM for equivalence check
let target_tasm: Vec<String> = pair
.target_tokens
.iter()
.filter(|&&t| t != 0) // skip EOS
.filter_map(|&t| vocab.decode(t).map(|s| s.to_string()))
.collect();
if target_tasm.is_empty() {
continue;
}
total += 1;
let node_features = graph_to_features::<B>(&pair.graph, device);
let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&pair.graph, device);
let beam_result = beam_search(
&model.encoder,
&model.decoder,
node_features,
edge_src,
edge_dst,
edge_types,
beam_config,
0,
device,
);
// Check if any beam candidate validates against the target
if validate_and_rank(&beam_result.sequences, vocab, &target_tasm, 0).is_some() {
valid += 1;
}
}
(valid, total)
}
/// Stage 2: GFlowNet fine-tuning.
fn run_stage2<B: burn::tensor::backend::AutodiffBackend>(
model: trident::neural::model::composite::NeuralCompilerV2<B>,
compiled: &[CompiledFile],
vocab: &trident::neural::model::vocab::Vocab,
epochs: u64,
device: &B::Device,
) where
<B as burn::tensor::backend::Backend>::FloatElem: From<f32>,
{
use burn::grad_clipping::GradientClippingConfig;
use burn::module::AutodiffModule;
use burn::optim::{AdamWConfig, GradientsParams, Optimizer};
use trident::neural::checkpoint::{self, CheckpointTag};
use trident::neural::inference::beam::BeamConfig;
use trident::neural::training::gflownet::{self, GFlowNetConfig};
let gf_config = GFlowNetConfig::default();
let mut optimizer = AdamWConfig::new()
.with_weight_decay(0.01)
.with_grad_clipping(Some(GradientClippingConfig::Norm(1.0)))
.init();
let beam_config = BeamConfig {
k: 4,
max_steps: 64,
..Default::default()
};
let start = std::time::Instant::now();
let mut model = model;
let mut global_step = 0usize;
let mut prev_table_lines: usize = 0;
let mut total_valid = 0usize;
let mut total_sampled = 0usize;
for epoch in 0..epochs {
let epoch_start = std::time::Instant::now();
let mut epoch_loss = 0.0f32;
let mut epoch_valid = 0usize;
let mut epoch_reward = 0.0f32;
// For each compiled file, sample a sequence and compute TB loss
for cf in compiled.iter() {
let graph = trident::neural::data::tir_graph::TirGraph::from_tir_ops(&cf.tir_ops);
if graph.nodes.is_empty() {
continue;
}
let (loss, reward, valid) = gflownet::gflownet_step(
&model,
&graph,
&cf.tasm_lines,
cf.baseline_cost,
burn::tensor::Tensor::<B, 1>::zeros([1], device), // log_z (simplified)
global_step,
&gf_config,
vocab,
device,
);
let loss_val: f32 = loss.clone().into_data().to_vec::<f32>().unwrap()[0];
epoch_loss += loss_val;
epoch_reward += reward;
            total_sampled += 1;
            if valid {
                epoch_valid += 1;
                total_valid += 1;
            }
// Backward + step
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
let lr = 1e-4; // Lower LR for fine-tuning
model = optimizer.step(lr, model, grads);
global_step += 1;
}
// Save checkpoint periodically
if (epoch + 1) % 5 == 0 || epoch == epochs - 1 {
if let Err(e) = checkpoint::save_checkpoint(&model, CheckpointTag::Stage2Latest, device)
{
eprintln!(" warning: checkpoint save failed: {}", e);
}
}
// Evaluate
let inner = model.valid();
let inner_device = device.clone();
let file_evals = eval_files(&inner, compiled, vocab, &beam_config, &inner_device);
let epoch_elapsed = epoch_start.elapsed();
// Aggregate
let epoch_decoded: usize = file_evals.iter().map(|e| e.decoded).sum();
let epoch_checked: usize = file_evals.iter().map(|e| e.checked).sum();
let epoch_wins: usize = file_evals.iter().map(|e| e.wins).sum();
let total_blk: usize = file_evals.iter().map(|e| e.total_blocks).sum();
let epoch_sampled = compiled.len();
let validity_rate = if epoch_sampled > 0 {
epoch_valid as f32 / epoch_sampled as f32 * 100.0
} else {
0.0
};
let num_files = compiled.len().max(1) as f32;
let avg_loss = epoch_loss / num_files;
let avg_reward = epoch_reward / num_files;
let tau = gflownet::temperature_at_step(global_step, &gf_config);
// Display
let active_count2 = file_evals
.iter()
.filter(|e| e.decoded > 0 || e.checked > 0)
.count();
let skipped2 = file_evals.len() - active_count2;
let skipped_line2 = if skipped2 > 0 { 1 } else { 0 };
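        // Lines drawn this epoch: summary(1) + table header(1) + separator(1)
        // + active rows + optional "skipped" line + closing separator(1).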
let table_lines = 1 + 1 + 1 + active_count2 + skipped_line2 + 1;
if epoch > 0 && prev_table_lines > 0 {
eprint!("\x1B[{}A", prev_table_lines);
}
eprintln!(
"\r epoch {}/{} | TB loss: {:.4} | reward: {:.3} | tau: {:.2} | valid: {:.0}% | decoded {}/{} checked {} won {} | {:.1}s\x1B[K",
epoch + 1, epochs, avg_loss, avg_reward, tau, validity_rate,
epoch_decoded, total_blk, epoch_checked, epoch_wins,
epoch_elapsed.as_secs_f64(),
);
display_file_table(&file_evals, compiled);
prev_table_lines = table_lines;
}
let elapsed = start.elapsed();
eprintln!();
eprintln!(
" Stage 2 done in {:.1}s, validity: {:.0}%",
elapsed.as_secs_f64(),
if total_sampled > 0 {
total_valid as f64 / total_sampled as f64 * 100.0
} else {
0.0
},
);
}
/// Evaluate a subset of compiled files via beam search (no grads).
///
/// Only evaluates functions whose TASM target length is <= beam max_steps
/// (longer targets can't be reproduced by the beam). Caps at 50 functions
/// to keep eval under ~30s. Returns one FileEval per compiled file; skipped
/// files get zeroed evals.
fn eval_files<B: burn::prelude::Backend>(
model: &trident::neural::model::composite::NeuralCompilerV2<B>,
compiled: &[CompiledFile],
vocab: &trident::neural::model::vocab::Vocab,
beam_config: &trident::neural::inference::beam::BeamConfig,
device: &B::Device,
) -> Vec<FileEval> {
use trident::neural::inference::beam::beam_search;
use trident::neural::inference::execute::validate_and_rank;
use trident::neural::training::supervised::{graph_to_edges, graph_to_features};
// Pre-filter: only evaluate functions that the beam could realistically match.
// Sort by target length (shortest first) and cap at 50 to keep eval fast.
let max_eval = 50;
let mut eval_indices: Vec<usize> = (0..compiled.len())
.filter(|&i| compiled[i].tasm_lines.len() <= beam_config.max_steps)
.collect();
eval_indices.sort_by_key(|&i| compiled[i].tasm_lines.len());
eval_indices.truncate(max_eval);
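    // Example (illustrative corpus size): with max_steps = 64 (as the training call sites
    // use) and max_eval = 50, a corpus of 300 functions keeps only the 50 shortest whose
    // targets fit within 64 instructions.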
let eval_set: std::collections::HashSet<usize> = eval_indices.iter().copied().collect();
let mut evals = Vec::with_capacity(compiled.len());
for (file_idx, cf) in compiled.iter().enumerate() {
// Skip functions too long for beam or beyond eval cap
if !eval_set.contains(&file_idx) {
evals.push(FileEval {
total_blocks: cf.tasm_lines.len(),
decoded: 0,
checked: 0,
proven: 0,
wins: 0,
checked_cost: 0,
checked_baseline: 0,
});
continue;
}
let graph = trident::neural::data::tir_graph::TirGraph::from_tir_ops(&cf.tir_ops);
if graph.nodes.is_empty() {
evals.push(FileEval {
total_blocks: cf.tasm_lines.len(),
decoded: 0,
checked: 0,
proven: 0,
wins: 0,
checked_cost: 0,
checked_baseline: 0,
});
continue;
}
let node_features = graph_to_features::<B>(&graph, device);
let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, device);
let beam_result = beam_search(
&model.encoder,
&model.decoder,
node_features,
edge_src,
edge_dst,
edge_types,
beam_config,
0, // must match training initial_stack_depth
device,
);
let ranked = validate_and_rank(
&beam_result.sequences,
vocab,
&cf.tasm_lines,
file_idx as u64,
);
let decoded = beam_result
.sequences
.iter()
.filter(|s| !s.is_empty())
.count();
        // Diagnostic: save the top beam for the first 3 evaluated files to see whether
        // the model differentiates between inputs.
let first_3: Vec<usize> = eval_indices.iter().take(3).copied().collect();
if first_3.contains(&file_idx) && !beam_result.sequences.is_empty() {
let top = &beam_result.sequences[0];
let tasm = vocab.decode_sequence(top);
let preview: Vec<&str> = tasm.iter().map(|s| s.as_str()).take(10).collect();
let target_preview: Vec<&str> =
cf.tasm_lines.iter().map(|s| s.as_str()).take(5).collect();
let line = format!(
"{}: beam={} [{}] target=[{}]",
cf.path,
top.len(),
preview.join(", "),
target_preview.join(", "),
);
BEAM_DIAGNOSTIC.with(|d| {
let mut val = d.borrow_mut();
if let Some(existing) = val.as_mut() {
existing.push_str(&format!("\n {}", line));
} else {
*val = Some(line);
}
});
}
if let Some(r) = ranked {
let wins = if r.cost < cf.baseline_cost { 1 } else { 0 };
evals.push(FileEval {
total_blocks: cf.tasm_lines.len(),
decoded,
checked: r.valid_count,
proven: r.valid_count,
wins,
checked_cost: r.cost,
checked_baseline: cf.baseline_cost,
});
} else {
evals.push(FileEval {
total_blocks: cf.tasm_lines.len(),
decoded,
checked: 0,
proven: 0,
wins: 0,
checked_cost: 0,
checked_baseline: 0,
});
}
}
evals
}
/// Display epoch summary + per-file table. Returns number of lines for overwriting.
fn display_epoch_table(
epoch: u64,
total_epochs: u64,
result: &trident::neural::training::supervised::EpochResult,
improved: bool,
lr: f64,
file_evals: &[FileEval],
compiled: &[CompiledFile],
elapsed: std::time::Duration,
prev_table_lines: usize,
) -> usize {
let epoch_decoded: usize = file_evals.iter().map(|e| e.decoded).sum();
let epoch_checked: usize = file_evals.iter().map(|e| e.checked).sum();
let epoch_proven: usize = file_evals.iter().map(|e| e.proven).sum();
let epoch_wins: usize = file_evals.iter().map(|e| e.wins).sum();
let total_blk: usize = file_evals.iter().map(|e| e.total_blocks).sum();
let loss_marker = if improved { " *" } else { "" };
// Pick up beam diagnostic if available
let diag = BEAM_DIAGNOSTIC.with(|d| d.borrow_mut().take());
let diag_lines = diag.as_ref().map_or(0, |d| d.lines().count());
let active_count = file_evals
.iter()
.filter(|e| e.decoded > 0 || e.checked > 0)
.count();
let skipped = file_evals.len() - active_count;
let skipped_line = if skipped > 0 { 1 } else { 0 };
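    // Lines drawn this epoch: summary(1) + diagnostic lines + table header(1)
    // + separator(1) + active rows + optional "skipped" line + closing separator(1).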
let table_lines = 1 + diag_lines + 1 + 1 + active_count + skipped_line + 1;
if epoch > 0 && prev_table_lines > 0 {
eprint!("\x1B[{}A", prev_table_lines);
}
eprintln!(
"\r epoch {}/{} | loss: {:.4}{} | lr: {:.1e} | decoded {}/{} | checked {} proven {} won {} | {:.1}s\x1B[K",
epoch + 1, total_epochs, result.avg_loss, loss_marker, lr,
epoch_decoded, total_blk, epoch_checked, epoch_proven, epoch_wins,
elapsed.as_secs_f64(),
);
if let Some(d) = diag {
for line in d.lines() {
eprintln!(" {}\x1B[K", line);
}
}
display_file_table(file_evals, compiled);
table_lines
}
/// Per-file table rows (shared between Stage 1 and Stage 2).
fn display_file_table(file_evals: &[FileEval], compiled: &[CompiledFile]) {
let mut sorted: Vec<_> = file_evals
.iter()
.enumerate()
.map(|(i, e)| {
(
compiled[i].path.as_str(),
e.total_blocks,
e.decoded,
e.checked,
e.proven,
e.wins,
e.checked_cost,
e.checked_baseline,
)
})
.collect();
sorted.sort_by(|a, b| b.5.cmp(&a.5).then(b.3.cmp(&a.3)).then(b.2.cmp(&a.2)));
eprintln!(
" {:<60} | {:>9} | {:>7} {:>8} {:>7} {:>5} | {:>15}\x1B[K",
"Module", "Blocks", "Decoded", "Checked", "Proven", "Won", "Cost (ratio)"
);
eprintln!(" {}\x1B[K", "-".repeat(122));
// Only show files that were actually evaluated (decoded > 0 or checked > 0)
let active: Vec<_> = sorted
.iter()
.filter(|s| s.2 > 0 || s.3 > 0)
.copied()
.collect::<Vec<_>>();
for &(path, total, decoded, checked, proven, wins, vc, vb) in &active {
let blocks_col = format!("{}/{}", total, total);
let cost_col = if checked > 0 {
let vr = vc as f64 / vb.max(1) as f64;
format!("{}/{} ({:.2}x)", vc, vb, vr)
} else {
"\u{2013}".to_string()
};
eprintln!(
" {:<60} | {:>9} | {:>7} {:>8} {:>7} {:>5} | {:>15}\x1B[K",
path, blocks_col, decoded, checked, proven, wins, cost_col,
);
}
let skipped = sorted.len() - active.len();
if skipped > 0 {
eprintln!(
" ... {} functions skipped (target > max_steps)\x1B[K",
skipped
);
}
eprintln!(" {}\x1B[K", "-".repeat(122));
}
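/// `trident train reset`: deletes the model weight directories (`model/general`,
/// `model/general/v2`) and any generated `.neural.tasm` files under `baselines/`.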
fn cmd_train_reset() {
let repo_root = find_repo_root();
let mut deleted = 0usize;
// Delete neural weights
for subdir in &["model/general", "model/general/v2"] {
let dir = repo_root.join(subdir);
if dir.exists() {
if let Err(e) = std::fs::remove_dir_all(&dir) {
eprintln!("error: failed to delete {}: {}", dir.display(), e);
process::exit(1);
}
eprintln!(
" deleted {}",
dir.strip_prefix(&repo_root).unwrap_or(&dir).display()
);
deleted += 1;
}
}
// Delete all .neural.tasm files under baselines/
let benches_dir = repo_root.join("baselines");
if benches_dir.exists() {
for entry in walkdir(&benches_dir) {
if entry.extension().and_then(|e| e.to_str()) == Some("tasm") {
let name = entry.file_name().unwrap_or_default().to_string_lossy();
if name.ends_with(".neural.tasm") {
if let Err(e) = std::fs::remove_file(&entry) {
eprintln!(" warning: failed to delete {}: {}", entry.display(), e);
} else {
eprintln!(
" deleted {}",
entry.strip_prefix(&repo_root).unwrap_or(&entry).display()
);
deleted += 1;
}
}
}
}
}
if deleted == 0 {
eprintln!("trident train reset: nothing to delete (already clean)");
} else {
eprintln!("trident train reset: deleted {} artifacts", deleted);
}
}
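/// Recursively collect all file paths under `dir` (recursion depth capped at 32).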
fn walkdir(dir: &Path) -> Vec<std::path::PathBuf> {
let mut result = Vec::new();
walkdir_recursive(dir, &mut result, 0);
result
}
fn walkdir_recursive(dir: &Path, result: &mut Vec<std::path::PathBuf>, depth: usize) {
if depth >= 32 {
return;
}
let entries = match std::fs::read_dir(dir) {
Ok(e) => e,
Err(_) => return,
};
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
walkdir_recursive(&path, result, depth + 1);
} else {
result.push(path);
}
}
}
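/// Compile each corpus file to TIR, split it per function, lower each function to
/// TASM, and record a baseline cost for every function kept.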
fn compile_corpus(files: &[std::path::PathBuf]) -> Vec<CompiledFile> {
use trident::neural::data::pairs::split_tir_by_function;
let options = super::resolve_options("triton", "debug", None);
let mut compiled = Vec::new();
for file in files {
let ir = match trident::build_tir_project(file, &options) {
Ok(ir) => ir,
Err(_) => continue,
};
if ir.is_empty() {
continue;
}
let file_path = short_path(file);
// Split TIR into per-function chunks and lower each independently.
// This produces many shorter training pairs (50-300 tokens) instead of
// one huge per-file pair (500-2000 tokens).
let functions = split_tir_by_function(&ir);
for (fn_name, fn_tir) in &functions {
            // Skip entry/trailing scaffolding; not useful training data
if fn_name.starts_with("__") {
continue;
}
if fn_tir.is_empty() {
continue;
}
// Lower this function's TIR to TASM
let lowering =
trident::ir::tir::lower::create_stack_lowering(&options.target_config.name);
let tasm_lines = lowering.lower(fn_tir);
            // Filter out labels and empty lines; keep only instructions
let tasm_lines: Vec<String> = tasm_lines
.into_iter()
.filter(|l| {
let t = l.trim();
!t.is_empty() && !t.ends_with(':') && !t.starts_with("//")
})
.map(|l| l.trim().to_string())
.collect();
if tasm_lines.is_empty() {
continue;
}
let profile = trident::cost::scorer::profile_tasm(
&tasm_lines.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
);
let baseline_cost = profile.cost().max(1);
compiled.push(CompiledFile {
path: format!("{}:{}", file_path, fn_name),
tir_ops: fn_tir.clone(),
tasm_lines,
baseline_cost,
});
}
}
compiled
}
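/// Collect all `.tri` files under `vm/`, `std/`, and `os/` relative to the repo root.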
fn discover_corpus() -> Vec<std::path::PathBuf> {
let root = find_repo_root();
let mut files = Vec::new();
for dir in &["vm", "std", "os"] {
let dir_path = root.join(dir);
if dir_path.is_dir() {
files.extend(super::resolve_tri_files(&dir_path));
}
}
files.sort();
files
}
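/// Walk up from the current directory until a directory containing `Cargo.toml` and a
/// `vm/` subdirectory is found; fall back to the current directory on failure.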
fn find_repo_root() -> std::path::PathBuf {
let mut dir = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
loop {
if dir.join("Cargo.toml").exists() && dir.join("vm").is_dir() {
return dir;
}
if !dir.pop() {
return std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
}
}
}
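/// Trim a path down to its `vm/`, `std/`, or `os/` suffix for display, e.g.
/// `/home/user/repo/std/math.tri` -> `std/math.tri` (illustrative path).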
fn short_path(path: &Path) -> String {
let s = path.to_string_lossy();
for prefix in &["vm/", "std/", "os/"] {
if let Some(pos) = s.find(prefix) {
return s[pos..].to_string();
}
}
s.to_string()
}