//! Stage 1: Supervised pre-training with cross-entropy loss.
//!
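//! Teacher forcing on (TirGraph, TASM) pairs: trains the composite model
//! (GNN encoder + Transformer decoder). Grammar state (stack depth, type
//! state) is fed to the decoder as input features; grammar masks are not
//! applied to the logits during supervised training.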
use burn::grad_clipping::GradientClippingConfig;
use burn::optim::{AdamWConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use burn::tensor::activation;
use crate::neural::data::pairs::TrainingPair;
use crate::neural::data::tir_graph::NODE_FEATURE_DIM;
use crate::neural::model::composite::NeuralCompilerV2;
use crate::neural::model::grammar::precompute_sequence_state;
/// Hyperparameters for the supervised (stage 1) training stage.
///
/// `lr` and `lr_min` are the end points of the schedule computed by
/// `cosine_lr`; `weight_decay` and `grad_clip` are read by
/// `create_optimizer`. `max_epochs` and `patience` are not read anywhere in
/// this file; they are presumably consumed by the outer training loop.
#[derive(Debug, Clone, PartialEq)]
pub struct SupervisedConfig {
    /// Initial learning rate (value of the cosine schedule at epoch 0).
    pub lr: f64,
    /// Minimum learning rate (cosine decay target).
    pub lr_min: f64,
    /// Weight decay coefficient passed to AdamW.
    pub weight_decay: f64,
    /// Gradient clipping threshold (applied as a norm clip).
    pub grad_clip: f32,
    /// Maximum number of epochs.
    pub max_epochs: usize,
    /// Early stopping patience (epochs without improvement).
    pub patience: usize,
}
impl Default for SupervisedConfig {
fn default() -> Self {
Self {
lr: 3e-4,
lr_min: 1e-5,
weight_decay: 0.01,
grad_clip: 1.0,
max_epochs: 100,
patience: 3,
}
}
}
/// Cosine-annealed learning rate for `epoch` out of `total_epochs`.
///
/// Computes `lr_min + 0.5 * (lr - lr_min) * (1 + cos(pi * t))` with
/// `t = min(epoch, total_epochs) / total_epochs`, so the rate starts at
/// `config.lr` for epoch 0 and decays towards `config.lr_min`.
///
/// * `config` - supplies `lr` (start) and `lr_min` (decay target).
/// * `epoch` - current epoch index.
/// * `total_epochs` - length of the schedule; when it is 0 or 1 there is
///   nothing to anneal and `config.lr` is returned unchanged.
///
/// Epochs past `total_epochs` are clamped and return `lr_min`. Without the
/// clamp the cosine is periodic, so the rate would climb back up towards
/// `lr` once training ran longer than the schedule.
pub fn cosine_lr(config: &SupervisedConfig, epoch: usize, total_epochs: usize) -> f64 {
    if total_epochs <= 1 {
        return config.lr;
    }
    // Fraction of the schedule completed, held at 1.0 once the schedule ends.
    let progress = epoch.min(total_epochs) as f64 / total_epochs as f64;
    let cosine = (std::f64::consts::PI * progress).cos();
    config.lr_min + 0.5 * (config.lr - config.lr_min) * (1.0 + cosine)
}
/// Summary statistics returned by `train_epoch` for a single epoch.
#[derive(Debug, Clone, PartialEq)]
pub struct EpochResult {
    /// Average cross-entropy loss reported for the epoch.
    pub avg_loss: f32,
    /// Number of training pairs counted for the epoch.
    pub num_pairs: usize,
}
/// Train one epoch of supervised learning on the given pairs.
///
/// Uses teacher forcing: at each step, the ground-truth previous token
/// is provided as input. Grammar masks are applied as logit penalties.
///
/// Returns the model with updated weights and the epoch result.
pub fn train_epoch<B: burn::tensor::backend::AutodiffBackend>(
model: NeuralCompilerV2<B>,
pairs: &[TrainingPair],
optimizer: &mut impl Optimizer<NeuralCompilerV2<B>, B>,
lr: f64,
device: &B::Device,
) -> (NeuralCompilerV2<B>, EpochResult) {
let mut total_loss = 0.0f32;
let mut model = model;
for pair in pairs {
// 1. Prepare graph inputs
let node_features = graph_to_features::<B>(&pair.graph, device);
let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&pair.graph, device);
// 2. Encode graph
let (node_emb, _global) =
model
.encoder
.forward(node_features, edge_src, edge_dst, edge_types);
// node_emb: [N, d_model] โ expand to [1, N, d_model] for batch=1
let d_model = node_emb.dims()[1];
let num_nodes = node_emb.dims()[0];
let memory = node_emb.unsqueeze_dim::<3>(0);
// 3. Prepare decoder inputs (teacher forcing)
// Truncate to max_seq=256 to fit position embedding table
const MAX_SEQ: usize = 256;
let tokens = if pair.target_tokens.len() > MAX_SEQ {
&pair.target_tokens[..MAX_SEQ]
} else {
&pair.target_tokens
};
let seq_len = tokens.len();
if seq_len < 2 {
continue; // Need at least input + one target
}
// Input tokens: [0, t0, t1, ..., t_{n-2}] (shifted right, prepend EOS=0)
let mut input_tokens = vec![0i32]; // Start with EOS
for &t in &tokens[..seq_len - 1] {
input_tokens.push(t as i32);
}
let token_ids =
Tensor::<B, 2, Int>::from_data(TensorData::new(input_tokens, [1, seq_len]), device);
// Positions: [0, 1, 2, ...]
let positions = Tensor::<B, 2, Int>::from_data(
TensorData::new((0..seq_len as i32).collect::<Vec<_>>(), [1, seq_len]),
device,
);
// Precompute grammar state for the (truncated) target sequence
let state = precompute_sequence_state(tokens, 0);
let stack_depths = Tensor::<B, 2, Int>::from_data(
TensorData::new(
state
.depths
.iter()
.map(|&d| (d as i32).min(64))
.collect::<Vec<_>>(),
[1, seq_len],
),
device,
);
let type_data: Vec<f32> = state.type_states.into_iter().flatten().collect();
let type_states =
Tensor::<B, 3>::from_data(TensorData::new(type_data, [1, seq_len, 24]), device);
// 4. Forward pass
let memory_expanded = memory.expand([1, num_nodes, d_model]);
let logits = model.decoder.forward(
token_ids,
positions,
stack_depths,
type_states,
memory_expanded,
);
// logits: [1, seq_len, VOCAB_SIZE]
// 5. Cross-entropy loss (no grammar mask during teacher forcing)
//
// Grammar masks are for inference-time beam search. During supervised
// training with teacher forcing, the target tokens ARE correct โ applying
// -1e9 penalties to them causes loss explosion (~1e8). The stack depth
// and type state features above already provide grammar awareness to the
// decoder as input conditioning.
let targets = Tensor::<B, 2, Int>::from_data(
TensorData::new(
tokens.iter().map(|&t| t as i32).collect::<Vec<_>>(),
[1, seq_len],
),
device,
);
let loss = cross_entropy_loss(logits, targets);
let loss_val: f32 = loss.clone().into_data().to_vec::<f32>().unwrap()[0];
total_loss += loss_val;
// 7. Backward pass + optimizer step
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optimizer.step(lr, model, grads);
}
let avg_loss = if pairs.is_empty() {
0.0
} else {
total_loss / pairs.len() as f32
};
(
model,
EpochResult {
avg_loss,
num_pairs: pairs.len(),
},
)
}
/// Mean negative log-likelihood of `targets` under `logits`.
///
/// `logits` has shape `[batch, seq, vocab]` and `targets` has shape
/// `[batch, seq]`. Every position is weighted equally; the result is a
/// single-element tensor.
fn cross_entropy_loss<B: Backend>(
    logits: Tensor<B, 3>,
    targets: Tensor<B, 2, Int>,
) -> Tensor<B, 1> {
    let [batch, seq, vocab] = logits.dims();
    let rows = batch * seq;

    // One row per (batch, position); normalise over the vocabulary axis.
    let log_probs = activation::log_softmax(logits.reshape([rows, vocab]), 1);

    // Column of target ids, shaped [rows, 1] so it can index axis 1.
    let target_index: Tensor<B, 2, Int> = targets.reshape([rows, 1]);
    let picked = log_probs.gather(1, target_index);

    // `mean` already yields a rank-1 tensor, so only the sign flip remains.
    picked.mean().neg()
}
/// Pack the feature vector of every graph node into one float tensor.
///
/// Returns a `[num_nodes, NODE_FEATURE_DIM]` tensor on `device`, with row
/// `i` holding `graph.nodes[i].feature_vector()`.
///
/// Panics if a node's feature vector is not exactly `NODE_FEATURE_DIM` long.
pub fn graph_to_features<B: Backend>(
    graph: &crate::neural::data::tir_graph::TirGraph,
    device: &B::Device,
) -> Tensor<B, 2> {
    let rows = graph.nodes.len();
    let mut flat = vec![0.0f32; rows * NODE_FEATURE_DIM];

    // Walk the row-major buffer one node-sized row at a time.
    for (row, node) in flat
        .chunks_exact_mut(NODE_FEATURE_DIM)
        .zip(graph.nodes.iter())
    {
        let features = node.feature_vector();
        row.copy_from_slice(&features);
    }

    Tensor::from_data(TensorData::new(flat, [rows, NODE_FEATURE_DIM]), device)
}
/// Turn the graph's edge list into three parallel integer tensors.
///
/// Returns `(sources, destinations, edge_types)`, each of length
/// `max(num_edges, 1)`. Edge kinds are encoded as `DataDep = 0`,
/// `ControlFlow = 1`, `MemOrder = 2`. A graph without edges yields a single
/// all-zero entry, because at least one edge is needed for burn.
pub fn graph_to_edges<B: Backend>(
    graph: &crate::neural::data::tir_graph::TirGraph,
    device: &B::Device,
) -> (Tensor<B, 1, Int>, Tensor<B, 1, Int>, Tensor<B, 1, Int>) {
    use crate::neural::data::tir_graph::EdgeKind;

    let len = graph.edges.len().max(1);
    let mut sources = Vec::<i32>::with_capacity(len);
    let mut targets = Vec::<i32>::with_capacity(len);
    let mut kinds = Vec::<i32>::with_capacity(len);

    for (from, to, kind) in graph.edges.iter() {
        sources.push(*from as i32);
        targets.push(*to as i32);
        kinds.push(match kind {
            EdgeKind::DataDep => 0,
            EdgeKind::ControlFlow => 1,
            EdgeKind::MemOrder => 2,
        });
    }

    // Pad an edgeless graph with one zeroed placeholder entry.
    if sources.is_empty() {
        sources.push(0);
        targets.push(0);
        kinds.push(0);
    }

    (
        Tensor::from_data(TensorData::new(sources, [len]), device),
        Tensor::from_data(TensorData::new(targets, [len]), device),
        Tensor::from_data(TensorData::new(kinds, [len]), device),
    )
}
/// Build the AdamW optimizer used for supervised training.
///
/// Weight decay comes from `config.weight_decay` (narrowed to `f32`) and
/// gradients are clipped by norm at `config.grad_clip`. The learning rate is
/// not configured here; it is supplied on every optimizer step.
pub fn create_optimizer<B: burn::tensor::backend::AutodiffBackend>(
    config: &SupervisedConfig,
) -> impl Optimizer<NeuralCompilerV2<B>, B> {
    let clipping = GradientClippingConfig::Norm(config.grad_clip);
    let weight_decay = config.weight_decay as f32;

    AdamWConfig::new()
        .with_weight_decay(weight_decay)
        .with_grad_clipping(Some(clipping))
        .init()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::tir::TIROp;
    use crate::neural::data::pairs::extract_pairs;
    use crate::neural::model::composite::NeuralCompilerConfig;
    use crate::neural::model::vocab::Vocab;
    use burn::backend::Autodiff;
    use burn::backend::NdArray;

    /// CPU autodiff backend used for the smoke test.
    type TestBackend = Autodiff<NdArray>;

    /// Smoke test: two consecutive epochs on one tiny pair run end to end
    /// and report a positive, finite loss.
    #[test]
    fn train_epoch_runs() {
        let device = Default::default();

        // Deliberately small model so the test stays fast.
        let model_config = NeuralCompilerConfig {
            d_model: 32,
            d_edge: 8,
            gnn_layers: 1,
            decoder_layers: 1,
            n_heads: 4,
            d_ff: 64,
            max_seq: 32,
            dropout: 0.0,
        };
        let initial_model = model_config.init::<TestBackend>(&device);

        // A single three-instruction block: push 1, push 2, add.
        let vocab = Vocab::new();
        let blocks = vec![(
            vec![TIROp::Push(1), TIROp::Push(2), TIROp::Add],
            vec!["push 1".into(), "push 2".into(), "add".into()],
            "test:0..3".into(),
            3u64,
        )];
        let pairs = extract_pairs(&blocks, &vocab);

        let train_config = SupervisedConfig::default();
        let mut optimizer = create_optimizer::<TestBackend>(&train_config);
        let lr = train_config.lr;

        let (trained_once, first) = train_epoch(initial_model, &pairs, &mut optimizer, lr, &device);
        assert_eq!(first.num_pairs, 1);
        assert!(first.avg_loss > 0.0, "loss should be positive");
        assert!(first.avg_loss.is_finite(), "loss should be finite");

        // Train a second epoch on the updated model; only finiteness of the
        // loss is asserted here.
        let (_trained_twice, second) =
            train_epoch(trained_once, &pairs, &mut optimizer, lr, &device);
        assert!(second.avg_loss.is_finite());
    }
}
trident/src/neural/training/supervised.rs
ฯ 0.0%
//! Stage 1: Supervised pre-training with cross-entropy loss.
//!
//! Teacher forcing with grammar mask penalties. Trains the composite
//! model (GNN encoder + Transformer decoder) on (TirGraph, TASM) pairs.
use GradientClippingConfig;
use ;
use *;
use activation;
use crateTrainingPair;
use crateNODE_FEATURE_DIM;
use crateNeuralCompilerV2;
use crateprecompute_sequence_state;
/// Supervised training configuration.
/// Cosine annealing learning rate: lr_min + 0.5*(lr - lr_min)*(1 + cos(pi*t/T))
/// Result of one training epoch.
/// Train one epoch of supervised learning on the given pairs.
///
/// Uses teacher forcing: at each step, the ground-truth previous token
/// is provided as input. Grammar masks are applied as logit penalties.
///
/// Returns the model with updated weights and the epoch result.
/// Cross-entropy loss between logits and targets.
/// logits: [batch, seq, vocab], targets: [batch, seq]
/// Convert TirGraph nodes to a feature tensor.
/// Convert TirGraph edges to index tensors.
/// Create an AdamW optimizer with gradient clipping.