//! Stage 2: GFlowNet training with Trajectory Balance loss.
//!
//! After supervised pre-training, the model is fine-tuned using GFlowNets
//! to explore the space of valid TASM programs. The reward signal comes
//! from actual clock cycle improvements over the compiler baseline.
use burn::prelude::*;
use crate::neural::model::composite::NeuralCompilerV2;
use crate::neural::model::grammar::StackStateMachine;
use crate::neural::model::vocab::{Vocab, VOCAB_SIZE};
use crate::neural::training::supervised::{graph_to_edges, graph_to_features};
/// GFlowNet training configuration.
///
/// Controls the sampling-temperature schedule, the maximum length of sampled
/// sequences, the reward floor, and the schedule of the partial-credit reward
/// shaping used early in training.
#[derive(Debug, Clone, PartialEq)]
pub struct GFlowNetConfig {
    /// Initial sampling temperature (used at step 0).
    pub tau_start: f32,
    /// Final sampling temperature (used once `anneal_steps` is reached).
    pub tau_end: f32,
    /// Number of steps over which the temperature is linearly annealed from
    /// `tau_start` to `tau_end`.
    pub anneal_steps: usize,
    /// Maximum number of sampling steps (tokens, including a final EOS) per sequence.
    pub max_seq_len: usize,
    /// Lower bound on the reward; keeps `ln(reward)` finite.
    pub reward_epsilon: f32,
    /// Number of initial steps during which the partial-credit shaped reward
    /// may be used instead of the true cycle-based reward.
    pub shaping_steps: usize,
    /// Threshold that ends the shaping phase early.
    ///
    /// NOTE(review): the training step in this module compares this value with
    /// the fraction of `shaping_steps` already elapsed, not with an observed
    /// validity rate as the name suggests. Confirm the intended semantics.
    pub shaping_validity_threshold: f32,
}

impl Default for GFlowNetConfig {
    fn default() -> Self {
        Self {
            // Start hot for exploration, end cooler for exploitation.
            tau_start: 2.0,
            tau_end: 0.5,
            anneal_steps: 10_000,
            max_seq_len: 256,
            reward_epsilon: 1e-3,
            shaping_steps: 1000,
            shaping_validity_threshold: 0.7,
        }
    }
}
/// Reward for a generated TASM sequence.
///
/// ```text
/// R(tasm) = epsilon                                                       if invalid
///         = 1 + max(0, (compiler_cycles - model_cycles) / compiler_cycles)  if valid
/// ```
///
/// A sequence is treated as invalid when `valid` is false or when no cycle
/// count is available (`model_cycles` is `None`). When `compiler_cycles` is
/// zero no relative improvement can be computed, and a valid sequence receives
/// the base reward `1.0`. A model that is slower than the compiler is not
/// penalised below the base reward.
///
/// The cycle difference is taken on the integers (saturating at zero) and the
/// ratio is formed in `f64`. This avoids subtracting two independently rounded
/// `u64 -> f32` conversions.
pub fn compute_reward(
    valid: bool,
    model_cycles: Option<u64>,
    compiler_cycles: u64,
    epsilon: f32,
) -> f32 {
    let model_cycles = match model_cycles {
        Some(mc) if valid => mc,
        _ => return epsilon,
    };
    if compiler_cycles == 0 {
        return 1.0;
    }
    // Regressions saturate to 0 saved cycles, i.e. max(0, ...).
    let saved = compiler_cycles.saturating_sub(model_cycles);
    1.0 + (saved as f64 / compiler_cycles as f64) as f32
}
/// Partial-credit shaped reward for early training.
///
/// ```text
/// R_shaped = epsilon + 1                          if no stack violation occurred
///          = epsilon + min(1, k / total_length)   if the first violation is at step k
///          = epsilon                              if a violation is reported and total_length == 0
/// ```
///
/// Sequences that stay valid longer before their first violation earn more
/// credit. The fraction is clamped to `1.0`, so a violating sequence can never
/// out-score a violation-free one, even if a caller passes `k > total_length`.
pub fn compute_shaped_reward(
    first_violation_step: Option<usize>,
    total_length: usize,
    epsilon: f32,
) -> f32 {
    // Full validity bonus, awarded when no violation was observed.
    const VALIDITY_BONUS: f32 = 1.0;
    match first_violation_step {
        // No violation — full validity bonus.
        None => epsilon + VALIDITY_BONUS,
        // Degenerate empty sequence: no partial credit can be computed.
        Some(_) if total_length == 0 => epsilon,
        Some(k) => {
            let fraction = (k as f32 / total_length as f32).min(1.0);
            epsilon + fraction * VALIDITY_BONUS
        }
    }
}
/// Sampling temperature for training `step` (linear annealing).
///
/// Interpolates linearly from `config.tau_start` at step 0 towards
/// `config.tau_end`. It reaches `tau_end` at `config.anneal_steps` and holds
/// it constant afterwards. With `anneal_steps == 0`, the final temperature is
/// used from the first step.
pub fn temperature_at_step(step: usize, config: &GFlowNetConfig) -> f32 {
    let (start, end, horizon) = (config.tau_start, config.tau_end, config.anneal_steps);
    if step < horizon {
        let progress = step as f32 / horizon as f32;
        start + (end - start) * progress
    } else {
        end
    }
}
/// Sample a sequence from the model using temperature-scaled logits.
///
/// The graph is encoded once. Tokens are then drawn autoregressively from
/// `softmax(logits / tau)` until EOS (token id `0`) is sampled or
/// `config.max_seq_len` sampling steps have elapsed. No grammar mask is applied
/// while sampling, which mirrors training. Instead, a [`StackStateMachine`]
/// follows the accepted tokens so the first grammar violation can be reported
/// for reward shaping.
///
/// Returns `(token_sequence, log_forward_prob, first_violation_step)`:
/// - `token_sequence` excludes the terminating EOS token.
/// - `log_forward_prob` is the sum of `ln p(token | history)` over every
///   sampled token, including a terminating EOS. Each probability is floored
///   at `1e-10`.
/// - `first_violation_step` is the step of the first non-EOS token whose
///   state-machine mask entry is below `-1e8`, if any.
///
/// NOTE(review): `log_forward_prob` is a plain `f32` computed from host-side
/// probabilities, so it carries no autodiff graph. A loss built from it cannot
/// propagate gradients into the model parameters. Confirm this is intended.
///
/// The decoder is re-run on the full prefix every step (no KV cache). The
/// per-position stack depths and type encodings are maintained incrementally
/// by a single state machine instead of being replayed from scratch each step.
pub fn sample_sequence<B: Backend>(
    model: &NeuralCompilerV2<B>,
    graph: &crate::neural::data::tir_graph::TirGraph,
    tau: f32,
    config: &GFlowNetConfig,
    device: &B::Device,
) -> (Vec<u32>, f32, Option<usize>) {
    // Width of the per-position type-state encoding fed to the decoder.
    const TYPE_DIM: usize = 24;

    let node_features = graph_to_features::<B>(graph, device);
    let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(graph, device);

    // Encode the graph once; the node embeddings are the decoder's memory.
    let (node_emb, _global) = model
        .encoder
        .forward(node_features, edge_src, edge_dst, edge_types);
    let memory = node_emb.unsqueeze_dim::<3>(0); // [1, N, d_model]

    let initial_depth = 0i32; // must match training

    // Decoder inputs for the current prefix. Position 0 is the EOS start token
    // (id 0); every accepted token appends exactly one position to each buffer.
    // `sm` is always in the state reached after all accepted tokens.
    let mut sm = StackStateMachine::new(initial_depth);
    let mut tokens = Vec::new();
    let mut token_data: Vec<i32> = vec![0]; // EOS start
    // 65: argument to `depth_for_embedding`, presumably the depth-embedding
    // cap — must match training (TODO confirm).
    let mut depth_data: Vec<i32> = vec![sm.depth_for_embedding(65) as i32];
    let mut type_data = Vec::new();
    type_data.extend(sm.type_encoding());

    let mut log_pf = 0.0f32;
    let mut first_violation: Option<usize> = None;

    for step in 0..config.max_seq_len {
        let cur_len = step + 1;

        let token_ids = Tensor::<B, 2, Int>::from_data(
            TensorData::new(token_data.clone(), [1, cur_len]),
            device,
        );
        let pos_data: Vec<i32> = (0..cur_len as i32).collect();
        let positions =
            Tensor::<B, 2, Int>::from_data(TensorData::new(pos_data, [1, cur_len]), device);
        let stack_depths = Tensor::<B, 2, Int>::from_data(
            TensorData::new(depth_data.clone(), [1, cur_len]),
            device,
        );
        let type_states = Tensor::<B, 3>::from_data(
            TensorData::new(type_data.clone(), [1, cur_len, TYPE_DIM]),
            device,
        );

        // Forward pass over the whole prefix: [1, cur_len, VOCAB_SIZE].
        let logits = model.decoder.forward(
            token_ids,
            positions,
            stack_depths,
            type_states,
            memory.clone(),
        );

        // Next-token logits live at the last position: [VOCAB_SIZE].
        let logits_1d = logits
            .slice([0..1, step..step + 1, 0..VOCAB_SIZE])
            .squeeze_dim::<2>(0)
            .squeeze_dim::<1>(0);
        let logits_data: Vec<f32> = logits_1d.to_data().to_vec().unwrap();

        // Temperature-scaled softmax — no grammar mask (matches training).
        let scaled: Vec<f32> = logits_data.iter().map(|&l| l / tau).collect();
        let max_l = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let probs: Vec<f32> = scaled.iter().map(|&l| (l - max_l).exp()).collect();
        let sum: f32 = probs.iter().sum();
        let probs: Vec<f32> = probs.iter().map(|&p| p / sum).collect();

        let token = sample_categorical(&probs);

        // Record the first grammar violation (EOS excluded); used only for the
        // shaped reward. `sm` reflects every previously accepted token.
        if first_violation.is_none() && token != 0 {
            let mask = sm.valid_mask();
            if mask[token as usize] < -1e8 {
                first_violation = Some(step);
            }
        }

        // Accumulate log forward probability (floored to keep it finite).
        let prob = probs[token as usize].max(1e-10);
        log_pf += prob.ln();

        if token == 0 {
            break; // EOS
        }

        // Accept the token and extend every decoder-input buffer by one position.
        sm.step(token);
        tokens.push(token);
        token_data.push(token as i32);
        depth_data.push(sm.depth_for_embedding(65) as i32);
        type_data.extend(sm.type_encoding());
    }
    (tokens, log_pf, first_violation)
}
/// Sample an index from a categorical distribution using inverse-CDF sampling.
///
/// `probs` holds per-category weights, expected to sum to roughly 1.
///
/// The uniform draw is scaled by the actual total mass, so normalisation
/// rounding (or unnormalised weights) does not push leftover mass onto the last
/// index. Entries that are NaN, infinite, zero or negative get zero weight and
/// are never chosen. The only exception is when every entry is such, in which
/// case the last index is returned.
///
/// Uses a thread-local xorshift64 PRNG — fast, no external deps,
/// non-deterministic (seeded from the system clock on first use per thread).
///
/// # Panics
/// Panics if `probs` is empty.
fn sample_categorical(probs: &[f32]) -> u32 {
    use std::cell::Cell;
    thread_local! {
        static RNG_STATE: Cell<u64> = Cell::new(
            // Seed from system time (nanos) to avoid deterministic argmax trap.
            // `| 1` keeps the state non-zero: xorshift never leaves the zero state.
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos() as u64)
                .unwrap_or(0xDEAD_BEEF_CAFE_1337) | 1
        );
    }
    assert!(!probs.is_empty(), "sample_categorical: empty distribution");
    let last = probs.len() - 1;

    // Only finite, strictly positive entries carry probability mass.
    let weight = |p: f32| if p.is_finite() && p > 0.0 { p } else { 0.0 };
    let total: f32 = probs.iter().map(|&p| weight(p)).sum();
    if !(total > 0.0 && total.is_finite()) {
        // Degenerate distribution: keep the historical "last token" fallback.
        return last as u32;
    }

    let u: f32 = RNG_STATE.with(|s| {
        let mut x = s.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        s.set(x);
        // Top 24 bits -> [0, 1), exactly representable in f32.
        (x >> 40) as f32 / (1u64 << 24) as f32
    });

    // Inverse-CDF sampling against the real accumulated mass.
    let target = u * total;
    let mut cumulative = 0.0f32;
    let mut last_positive = last;
    for (i, &p) in probs.iter().enumerate() {
        let w = weight(p);
        if w > 0.0 {
            cumulative += w;
            last_positive = i;
            if target < cumulative {
                return i as u32;
            }
        }
    }
    // Rounding can leave `target` at or just above the accumulated mass; fall
    // back to the last index that actually carries probability.
    last_positive as u32
}
/// Trajectory Balance loss for a single trajectory.
///
/// ```text
/// L_TB = (log_Z + log_P_F - log_P_B - log_R)^2
/// ```
///
/// - `log_z`: learned log-partition function. It is the only tensor operand,
///   and therefore the only input through which this loss carries gradients.
/// - `log_pf`: summed log forward probability of the sampled trajectory.
/// - `log_pb`: log backward probability (the caller in this module passes a
///   constant `0.0`, i.e. a uniform backward policy).
/// - `log_r`: log reward (the caller in this module floors the reward at
///   epsilon before taking the log).
///
/// The scalar terms are lifted into 1-element tensors on `device` and
/// broadcast against `log_z` (presumably of shape `[1]`; confirm at call sites).
pub fn tb_loss<B: Backend>(
    log_pf: f32,
    log_pb: f32,
    log_r: f32,
    log_z: Tensor<B, 1>,
    device: &B::Device,
) -> Tensor<B, 1> {
    // Lift a host-side scalar into a constant 1-element tensor.
    let lift = |value: f32| Tensor::<B, 1>::from_data(TensorData::new(vec![value], [1]), device);
    let residual = log_z + lift(log_pf) - lift(log_pb) - lift(log_r);
    residual.clone() * residual
}
/// Run one GFlowNet training step for a single graph.
///
/// 1. Samples a token sequence from `model` at the annealed temperature for `step`.
/// 2. Decodes it to TASM. A non-empty decoding is checked for equivalence
///    against `baseline_tasm` (verifier seed `42`); if equivalent, it is
///    profiled for its cycle cost.
/// 3. Chooses the reward: the partial-credit shaped reward while shaping is
///    active, otherwise the cycle-based reward against `compiler_cycles`.
/// 4. Builds the Trajectory Balance loss with a uniform backward policy
///    (`log_P_B = 0`).
///
/// Returns `(loss, reward, valid)`.
///
/// NOTE(review): shaping is active while `step < shaping_steps` and
/// `step / shaping_steps < shaping_validity_threshold`. The threshold is
/// therefore compared with the elapsed-step fraction, not an observed validity
/// rate. Confirm this matches the intent of the config field.
///
/// NOTE(review): `log_pf` reaches `tb_loss` as a plain `f32`. The returned
/// loss therefore appears to depend on trainable state only through `log_z`,
/// and gradients would not reach `model`. Confirm this is intended.
pub fn gflownet_step<B: burn::tensor::backend::AutodiffBackend>(
    model: &NeuralCompilerV2<B>,
    graph: &crate::neural::data::tir_graph::TirGraph,
    baseline_tasm: &[String],
    compiler_cycles: u64,
    log_z: Tensor<B, 1>,
    step: usize,
    config: &GFlowNetConfig,
    vocab: &Vocab,
    device: &B::Device,
) -> (Tensor<B, 1>, f32, bool) {
    let (tokens, log_pf, first_violation) =
        sample_sequence(model, graph, temperature_at_step(step, config), config, device);

    // An empty decoding is never valid; only non-empty ones reach the verifier.
    let tasm_lines = vocab.decode_sequence(&tokens);
    let valid = !tasm_lines.is_empty()
        && crate::cost::stack_verifier::verify_equivalent(baseline_tasm, &tasm_lines, 42);

    // Cycle cost is only measured for programs proven equivalent.
    let model_cycles = valid.then(|| {
        let line_refs: Vec<&str> = tasm_lines.iter().map(String::as_str).collect();
        crate::cost::scorer::profile_tasm(&line_refs).cost()
    });

    let shaping_active = step < config.shaping_steps
        && (step as f32 / config.shaping_steps as f32) < config.shaping_validity_threshold;
    let reward = if shaping_active {
        compute_shaped_reward(first_violation, tokens.len(), config.reward_epsilon)
    } else {
        compute_reward(valid, model_cycles, compiler_cycles, config.reward_epsilon)
    };

    // Floor at epsilon so the log stays finite.
    let log_r = reward.max(config.reward_epsilon).ln();
    let uniform_log_pb = 0.0f32;
    let loss = tb_loss(log_pf, uniform_log_pb, log_r, log_z, device);
    (loss, reward, valid)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reward_valid_improvement() {
        let r = compute_reward(true, Some(5), 10, 1e-3);
        assert!((r - 1.5).abs() < 0.01); // 1 + (10-5)/10 = 1.5
    }

    #[test]
    fn reward_valid_no_improvement() {
        let r = compute_reward(true, Some(10), 10, 1e-3);
        assert!((r - 1.0).abs() < 0.01); // 1 + max(0, 0) = 1.0
    }

    #[test]
    fn reward_valid_regression_is_not_penalised() {
        // Slower than the compiler: max(0, negative) = 0, so only the base reward.
        let r = compute_reward(true, Some(20), 10, 1e-3);
        assert!((r - 1.0).abs() < 1e-6);
    }

    #[test]
    fn reward_invalid() {
        let r = compute_reward(false, None, 10, 1e-3);
        assert!((r - 1e-3).abs() < 1e-6);
    }

    #[test]
    fn reward_missing_cycles_or_invalid_is_epsilon() {
        assert!((compute_reward(true, None, 10, 1e-3) - 1e-3).abs() < 1e-9);
        assert!((compute_reward(false, Some(5), 10, 1e-3) - 1e-3).abs() < 1e-9);
    }

    #[test]
    fn reward_zero_compiler_cycles() {
        let r = compute_reward(true, Some(5), 0, 1e-3);
        assert!((r - 1.0).abs() < 1e-6);
    }

    #[test]
    fn shaped_reward_no_violation() {
        let r = compute_shaped_reward(None, 10, 1e-3);
        assert!((r - 1.001).abs() < 0.01);
    }

    #[test]
    fn shaped_reward_early_violation() {
        let r = compute_shaped_reward(Some(3), 10, 1e-3);
        // epsilon + 3/10 = 0.001 + 0.3 = 0.301
        assert!((r - 0.301).abs() < 0.01);
    }

    #[test]
    fn shaped_reward_zero_length() {
        let r = compute_shaped_reward(Some(0), 0, 1e-3);
        assert!((r - 1e-3).abs() < 1e-9);
    }

    #[test]
    fn temperature_annealing() {
        let config = GFlowNetConfig {
            tau_start: 2.0,
            tau_end: 0.5,
            anneal_steps: 100,
            ..Default::default()
        };
        assert!((temperature_at_step(0, &config) - 2.0).abs() < 0.01);
        assert!((temperature_at_step(50, &config) - 1.25).abs() < 0.01);
        assert!((temperature_at_step(100, &config) - 0.5).abs() < 0.01);
        assert!((temperature_at_step(200, &config) - 0.5).abs() < 0.01);
    }

    #[test]
    fn temperature_zero_anneal_steps_uses_final_tau() {
        let config = GFlowNetConfig {
            anneal_steps: 0,
            ..Default::default()
        };
        assert!((temperature_at_step(0, &config) - config.tau_end).abs() < 1e-6);
    }

    #[test]
    fn categorical_one_hot_is_deterministic() {
        for _ in 0..100 {
            assert_eq!(sample_categorical(&[0.0, 1.0, 0.0]), 1);
        }
    }

    #[test]
    fn tb_loss_zero_residual() {
        use burn::backend::NdArray;
        let device = Default::default();
        // When log_Z + log_PF - log_PB - log_R = 0, loss should be 0
        let log_z = Tensor::<NdArray, 1>::from_data(TensorData::new(vec![1.0f32], [1]), &device);
        let loss = tb_loss::<NdArray>(2.0, 1.0, 2.0, log_z, &device);
        // residual = 1 + 2 - 1 - 2 = 0
        let val: Vec<f32> = loss.to_data().to_vec().unwrap();
        assert!(val[0].abs() < 1e-6);
    }

    #[test]
    fn tb_loss_nonzero_residual() {
        use burn::backend::NdArray;
        let device = Default::default();
        let log_z = Tensor::<NdArray, 1>::from_data(TensorData::new(vec![0.0f32], [1]), &device);
        let loss = tb_loss::<NdArray>(1.0, 0.0, 0.5, log_z, &device);
        // residual = 0 + 1 - 0 - 0.5 = 0.5, loss = 0.25
        let val: Vec<f32> = loss.to_data().to_vec().unwrap();
        assert!((val[0] - 0.25).abs() < 1e-4);
    }
}
trident/src/neural/training/gflownet.rs
ฯ 0.0%
//! Stage 2: GFlowNet training with Trajectory Balance loss.
//!
//! After supervised pre-training, the model is fine-tuned using GFlowNets
//! to explore the space of valid TASM programs. The reward signal comes
//! from actual clock cycle improvements over the compiler baseline.
use *;
use crateNeuralCompilerV2;
use crateStackStateMachine;
use crate;
use crate;
/// GFlowNet training configuration.
/// Reward for a generated TASM sequence.
///
/// R(tasm) = epsilon                                                       if !valid
///         = 1 + max(0, (compiler_cycles - model_cycles) / compiler_cycles)  if valid
/// Partial credit shaped reward for early training.
///
/// R_shaped = epsilon + (k / total_length) * validity_bonus
/// where k = step of first stack violation.
/// Temperature at a given training step (linear annealing).
/// Sample a sequence from the model using temperature-scaled logits.
///
/// Returns (token_sequence, log_forward_prob, first_violation_step).
/// Sample from a categorical distribution using inverse-CDF.
///
/// Uses thread-local xorshift64 PRNG — fast, no external deps, non-deterministic.
/// Trajectory Balance loss.
///
/// L_TB = (log_Z + log_P_F - log_P_B - log_R)^2
///
/// Where:
/// - log_Z: learned log-partition function (trained jointly)
/// - log_P_F: sum of log P(token_t | history) from forward sampling
/// - log_P_B: uniform backward policy (constant)
/// - log_R: log(reward), clipped >= log(epsilon)
/// Run one GFlowNet training step.
///
/// Samples a sequence, computes reward, and returns TB loss.