//! Beam search decoder (default K=32).
//!
//! Runs the Transformer decoder autoregressively with beam search.
//! No grammar mask is applied while decoding; see [`beam_search`].
//! Returns K candidate sequences ranked by length-normalized log-probability.
use burn::prelude::*;
use crate::neural::model::decoder::StackAwareDecoder;
use crate::neural::model::encoder::GnnEncoder;
use crate::neural::model::vocab::VOCAB_SIZE;
/// Beam search configuration.
#[derive(Debug, Clone, PartialEq)]
pub struct BeamConfig {
    /// Number of beams to maintain.
    pub k: usize,
    /// Maximum output sequence length (number of decoding steps).
    pub max_steps: usize,
    /// Minimum tokens before EOS is allowed: EOS (token 0) is not offered as a
    /// candidate while the step index is below this value.
    pub min_tokens: usize,
    /// Length normalization exponent (0=none, 1=full). Prevents short sequences
    /// from dominating by ranking on `log_prob / length^alpha`.
    pub length_alpha: f32,
    /// Repetition penalty applied to tokens seen in the last `rep_window` steps.
    /// 1.0 (or less) = no penalty. Above 1.0, a positive logit is divided by
    /// this value and a non-positive logit is multiplied by it, so the token
    /// becomes less likely in either case.
    pub rep_penalty: f32,
    /// Window size (in generated tokens) for the repetition penalty.
    pub rep_window: usize,
}

impl Default for BeamConfig {
    /// Default search settings: 32 beams, up to 256 steps, at least one token
    /// before EOS, `length_alpha = 0.7`, and a 1.5x repetition penalty over
    /// the last 16 tokens.
    fn default() -> Self {
        Self {
            k: 32,
            max_steps: 256,
            min_tokens: 1,
            length_alpha: 0.7,
            rep_penalty: 1.5,
            rep_window: 16,
        }
    }
}
/// Result of beam search: K candidate sequences with scores.
///
/// `sequences[i]` and `log_probs[i]` describe the same candidate.
#[derive(Debug, Clone, PartialEq)]
pub struct BeamResult {
    /// Token ID sequences (one per beam), best first. The order follows the
    /// length-normalized score stored in `log_probs`.
    pub sequences: Vec<Vec<u32>>,
    /// Score for each sequence: the cumulative log-probability divided by
    /// `max(len, 1)^length_alpha`, in descending order.
    pub log_probs: Vec<f32>,
}
/// Run beam search on a single input graph.
///
/// - `encoder`: GNN encoder (already loaded)
/// - `decoder`: Transformer decoder (already loaded)
/// - `node_features`: [N, 59] node feature matrix
/// - `edge_src`, `edge_dst`: [E] edge endpoint indices
/// - `edge_types`: [E] edge type IDs (0=DataDep, 1=ControlFlow, 2=MemOrder)
/// - `config`: beam search parameters
/// - `initial_stack_depth`: initial stack depth for depth/type features
///   (must match training — typically 0)
/// - `device`: device on which the per-step decoder inputs are created
///
/// Returns `config.k` candidates, best first, ranked by length-normalized
/// log-probability (`log_prob / max(len, 1)^length_alpha`). The values in
/// `BeamResult::log_probs` are those normalized scores. When fewer than
/// `config.k` distinct candidates exist, the remaining slots are empty
/// sequences scored `-inf`. With `config.k == 0` the result is empty.
///
/// Only the first beam starts live. If every beam started from the same empty
/// prefix with the same score, the top-K selection at step 0 would pick the
/// same token K times and the search would collapse to greedy decoding.
///
/// No grammar mask is applied during decoding. The model was trained
/// without grammar masks (teacher forcing with ground truth), so the
/// learned distribution should naturally avoid invalid tokens.
///
/// # Panics
///
/// Panics if the decoder logits cannot be read back as `f32`.
pub fn beam_search<B: Backend>(
    encoder: &GnnEncoder<B>,
    decoder: &StackAwareDecoder<B>,
    node_features: Tensor<B, 2>,
    edge_src: Tensor<B, 1, Int>,
    edge_dst: Tensor<B, 1, Int>,
    edge_types: Tensor<B, 1, Int>,
    config: &BeamConfig,
    initial_stack_depth: i32,
    device: &B::Device,
) -> BeamResult {
    use crate::neural::model::grammar::StackStateMachine;

    // Number of type-state features fed to the decoder per sequence position.
    const TYPE_WIDTH: usize = 24;

    let k = config.k;
    if k == 0 {
        // Nothing to search; also avoids building zero-sized tensors.
        return BeamResult {
            sequences: Vec::new(),
            log_probs: Vec::new(),
        };
    }
    let alpha = config.length_alpha;
    let num_nodes = node_features.dims()[0];

    // 1. Encode graph -> node embeddings + global context.
    let (node_emb, _global) = encoder.forward(node_features, edge_src, edge_dst, edge_types);
    // node_emb: [N, d_model]; expand memory for K beams: [K, N, d_model].
    let d_model = node_emb.dims()[1];
    let memory = node_emb
        .unsqueeze_dim::<3>(0)
        .expand([k, num_nodes, d_model]);

    // 2. Initialize beams. Beam 0 is the single live hypothesis; the others
    //    are finished placeholders at -inf (the same convention used below
    //    when fewer than K candidates exist) and are replaced as soon as
    //    beam 0 yields enough continuations.
    let mut beam_sequences: Vec<Vec<u32>> = vec![Vec::new(); k];
    let mut beam_log_probs: Vec<f32> = vec![f32::NEG_INFINITY; k];
    let mut beam_finished: Vec<bool> = vec![true; k];
    beam_log_probs[0] = 0.0;
    beam_finished[0] = false;

    // 3. Autoregressive decoding — pass full sequence history each step.
    //
    // The decoder uses self-attention over all previous positions, so we must
    // feed the entire generated sequence (not just the last token). At step t,
    // input shape is [K, t+1] and we take logits at position t.
    //
    // Stack depth/type features are replayed from the generated sequence using
    // the same initial_stack_depth as training (typically 0). This ensures
    // the decoder sees depth embeddings consistent with what it learned.
    for step in 0..config.max_steps {
        if beam_finished.iter().all(|&done| done) {
            break;
        }
        let cur_len = step + 1; // sequence length including the start EOS token

        // Build full sequence inputs: [K, cur_len].
        // Each beam's input is: [EOS=0, tok_0, tok_1, ..., tok_{step-1}];
        // beams that finished early are padded up to cur_len.
        let mut token_data: Vec<i32> = Vec::with_capacity(k * cur_len);
        let mut pos_data: Vec<i32> = Vec::with_capacity(k * cur_len);
        let mut depth_data: Vec<i32> = Vec::with_capacity(k * cur_len);
        let mut type_data: Vec<f32> = Vec::with_capacity(k * cur_len * TYPE_WIDTH);
        for (b, seq) in beam_sequences.iter().enumerate() {
            let row_end = (b + 1) * cur_len;

            // Token IDs: [EOS, ...generated tokens], zero-padded.
            token_data.push(0i32);
            token_data.extend(seq.iter().map(|&t| t as i32));
            while token_data.len() < row_end {
                token_data.push(0i32);
            }

            // Positions: [0, 1, 2, ...]
            pos_data.extend((0..cur_len).map(|p| p as i32));

            // Stack depths: replay the state machine over the generated
            // tokens; padding repeats the final depth.
            let mut sm = StackStateMachine::new(initial_stack_depth);
            depth_data.push(sm.depth_for_embedding(65) as i32); // depth at EOS start
            for &t in seq {
                sm.step(t);
                depth_data.push(sm.depth_for_embedding(65) as i32);
            }
            while depth_data.len() < row_end {
                depth_data.push(sm.depth_for_embedding(65) as i32);
            }

            // Type states: replay for each position; padding is all zeros.
            let mut sm2 = StackStateMachine::new(initial_stack_depth);
            type_data.extend(sm2.type_encoding()); // type at EOS start
            for &t in seq {
                sm2.step(t);
                type_data.extend(sm2.type_encoding());
            }
            while type_data.len() < row_end * TYPE_WIDTH {
                type_data.extend(std::iter::repeat(0.0f32).take(TYPE_WIDTH));
            }
        }
        let token_ids =
            Tensor::<B, 2, Int>::from_data(TensorData::new(token_data, [k, cur_len]), device);
        let positions =
            Tensor::<B, 2, Int>::from_data(TensorData::new(pos_data, [k, cur_len]), device);
        let stack_depths =
            Tensor::<B, 2, Int>::from_data(TensorData::new(depth_data, [k, cur_len]), device);
        let type_states = Tensor::<B, 3>::from_data(
            TensorData::new(type_data, [k, cur_len, TYPE_WIDTH]),
            device,
        );

        // Forward pass: [K, cur_len, VOCAB_SIZE]
        let logits = decoder.forward(
            token_ids,
            positions,
            stack_depths,
            type_states,
            memory.clone(),
        );
        // Extract logits at the LAST position only: [K, VOCAB_SIZE]
        let logits_2d = logits
            .slice([0..k, step..step + 1, 0..VOCAB_SIZE])
            .squeeze_dim::<2>(1);
        // Copy to host memory for beam management.
        let logits_flat: Vec<f32> = logits_2d
            .to_data()
            .to_vec()
            .expect("decoder logits must be readable as f32");

        // 4. Score all (beam, token) candidates — no grammar mask.
        // The model was trained without grammar mask penalties, so its learned
        // distribution should naturally prefer valid continuations.
        //
        // Each entry is (beam_idx, token, cumulative_log_prob, normalized_score).
        // The normalized score is computed once here rather than inside the
        // sort comparator.
        let mut candidates: Vec<(usize, u32, f32, f32)> = Vec::with_capacity(k * VOCAB_SIZE);
        for b in 0..k {
            let seq = &beam_sequences[b];
            let base = beam_log_probs[b];
            if beam_finished[b] {
                // Finished beam: only EOS continuation with same score.
                candidates.push((b, 0, base, length_normalized_score(base, seq.len(), alpha)));
                continue;
            }

            // Build repetition set: tokens in the last rep_window steps.
            let rep_start = seq.len().saturating_sub(config.rep_window);
            let mut recent = [false; VOCAB_SIZE];
            for &tok in &seq[rep_start..] {
                if let Some(flag) = recent.get_mut(tok as usize) {
                    *flag = true;
                }
            }

            let row = &logits_flat[b * VOCAB_SIZE..(b + 1) * VOCAB_SIZE];
            let log_probs = penalized_log_probs(row, &recent, config.rep_penalty);
            for (t, &log_prob) in log_probs.iter().enumerate() {
                // Block EOS until we've generated min_tokens.
                if t == 0 && step < config.min_tokens {
                    continue;
                }
                let cumulative = base + log_prob;
                // EOS (token 0) ends the beam without extending the sequence.
                let len = seq.len() + usize::from(t != 0);
                candidates.push((
                    b,
                    t as u32,
                    cumulative,
                    length_normalized_score(cumulative, len, alpha),
                ));
            }
        }

        // Sort by length-normalized score (descending), keep top K.
        // The stable sort keeps candidate order among equal scores.
        candidates.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
        candidates.truncate(k);

        // 5. Update beams.
        let mut new_sequences: Vec<Vec<u32>> = Vec::with_capacity(k);
        let mut new_log_probs: Vec<f32> = Vec::with_capacity(k);
        let mut new_finished: Vec<bool> = Vec::with_capacity(k);
        for &(src_beam, token, log_prob, _) in &candidates {
            let was_finished = beam_finished[src_beam];
            let mut seq = beam_sequences[src_beam].clone();
            if !was_finished && token != 0 {
                seq.push(token);
            }
            new_sequences.push(seq);
            new_log_probs.push(log_prob);
            new_finished.push(was_finished || token == 0);
        }
        // Pad if fewer than K candidates.
        while new_sequences.len() < k {
            new_sequences.push(Vec::new());
            new_log_probs.push(f32::NEG_INFINITY);
            new_finished.push(true);
        }
        beam_sequences = new_sequences;
        beam_log_probs = new_log_probs;
        beam_finished = new_finished;
    }

    // Sort final results by length-normalized log-prob (descending).
    let mut ranked: Vec<(usize, f32)> = beam_sequences
        .iter()
        .zip(&beam_log_probs)
        .enumerate()
        .map(|(i, (seq, &lp))| (i, length_normalized_score(lp, seq.len(), alpha)))
        .collect();
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    let log_probs: Vec<f32> = ranked.iter().map(|&(_, score)| score).collect();
    // Every index occurs exactly once in `ranked`, so each sequence can be
    // moved out instead of cloned.
    let sequences: Vec<Vec<u32>> = ranked
        .into_iter()
        .map(|(i, _)| std::mem::take(&mut beam_sequences[i]))
        .collect();

    BeamResult {
        sequences,
        log_probs,
    }
}

/// Length-normalized beam score: `log_prob / max(len, 1)^alpha`.
fn length_normalized_score(log_prob: f32, len: usize, alpha: f32) -> f32 {
    log_prob / (len.max(1) as f32).powf(alpha)
}

/// Log-softmax of `logits` after applying the repetition penalty.
///
/// When `rep_penalty > 1.0`, a token flagged in `recent` has a positive logit
/// divided by the penalty and a non-positive logit multiplied by it. The
/// log-sum-exp subtracts the maximum logit first for numerical stability.
fn penalized_log_probs(logits: &[f32], recent: &[bool], rep_penalty: f32) -> Vec<f32> {
    let penalize = rep_penalty > 1.0;
    let adjusted: Vec<f32> = logits
        .iter()
        .zip(recent)
        .map(|(&logit, &seen)| {
            if seen && penalize {
                if logit > 0.0 {
                    logit / rep_penalty
                } else {
                    logit * rep_penalty
                }
            } else {
                logit
            }
        })
        .collect();

    let max_logit = adjusted.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = adjusted
        .iter()
        .map(|&l| (l - max_logit).exp())
        .sum::<f32>()
        .ln()
        + max_logit;

    adjusted.into_iter().map(|l| l - log_sum_exp).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::neural::model::decoder::DecoderConfig;
    use crate::neural::model::encoder::GnnEncoderConfig;
    use burn::backend::NdArray;

    type B = NdArray;

    /// Smoke test: a tiny randomly-initialised model must return exactly `k`
    /// candidates whose scores are in non-increasing order.
    #[test]
    fn beam_search_produces_k_sequences() {
        let device = Default::default();

        // Deliberately small encoder/decoder so the test stays fast.
        let enc = GnnEncoderConfig::new()
            .with_d_model(32)
            .with_d_edge(8)
            .with_num_layers(1)
            .init::<B>(&device);
        let dec_cfg = DecoderConfig {
            d_model: 32,
            num_layers: 1,
            n_heads: 4,
            d_ff: 64,
            max_seq: 64,
            max_stack_depth: 65,
            type_window: 8,
            dropout: 0.0,
        };
        let dec = dec_cfg.init::<B>(&device);

        // Tiny graph: 3 nodes joined by the edges 0->1 and 1->2.
        let features = Tensor::<B, 2>::zeros([3, 59], &device);
        let src = Tensor::<B, 1, Int>::from_data(TensorData::new(vec![0i32, 1], [2]), &device);
        let dst = Tensor::<B, 1, Int>::from_data(TensorData::new(vec![1i32, 2], [2]), &device);
        let kinds = Tensor::<B, 1, Int>::from_data(TensorData::new(vec![0i32, 1], [2]), &device);

        // Small K and few steps for test speed.
        let cfg = BeamConfig {
            k: 4,
            max_steps: 5,
            min_tokens: 1,
            ..Default::default()
        };

        let out = beam_search(&enc, &dec, features, src, dst, kinds, &cfg, 0, &device);

        assert_eq!(out.sequences.len(), 4);
        assert_eq!(out.log_probs.len(), 4);

        // Scores must be sorted descending: compare each adjacent pair.
        for pair in out.log_probs.windows(2) {
            assert!(
                pair[1] <= pair[0],
                "log_probs not sorted: {} > {}",
                pair[1],
                pair[0]
            );
        }
    }
}
trident/src/neural/inference/beam.rs
ฯ 0.0%
//! Beam search decoder (default K=32).
//!
//! Runs the Transformer decoder autoregressively with beam search.
//! No grammar mask is applied while decoding.
//! Returns K candidate sequences ranked by length-normalized log-probability.
use *;
use crateStackAwareDecoder;
use crateGnnEncoder;
use crateVOCAB_SIZE;
/// Beam search configuration.
/// Result of beam search: K candidate sequences with scores.
/// Run beam search on a single input graph.
///
/// - `encoder`: GNN encoder (already loaded)
/// - `decoder`: Transformer decoder (already loaded)
/// - `node_features`: [N, 59] node feature matrix
/// - `edge_src`, `edge_dst`: [E] edge endpoint indices
/// - `edge_types`: [E] edge type IDs (0=DataDep, 1=ControlFlow, 2=MemOrder)
/// - `config`: beam search parameters
/// - `initial_stack_depth`: initial stack depth for depth/type features
/// (must match training — typically 0)
///
/// Returns K candidate token sequences ranked by log-probability.
///
/// No grammar mask is applied during decoding. The model was trained
/// without grammar masks (teacher forcing with ground truth), so the
/// learned distribution should naturally avoid invalid tokens.