//! Neural compiler v2: GNN encoder + Transformer decoder.
//!
//! Replaces the v1 MLP evolutionary model with a ~13M parameter
//! architecture trained via supervised learning + GFlowNets.
//!
//! # Public API
//!
//! ```ignore
//! use trident::neural;
//! let result = neural::compile(&tir_ops, &baseline_tasm)?;
//! ```

pub mod checkpoint;
pub mod data;
pub mod inference;
pub mod model;
pub mod training;

use burn::backend::Wgpu;

use crate::ir::tir::TIROp;
use data::tir_graph::TirGraph;
use inference::beam::{beam_search, BeamConfig};
use inference::execute::validate_and_rank;
use model::vocab::Vocab;
use training::supervised::{graph_to_edges, graph_to_features};

/// Result of neural compilation.
///
/// Returned by [`compile`], [`compile_with_device`], and [`compile_with_model`].
/// When `neural` is false the result is the unmodified baseline and the
/// beam-candidate counters are zero.
#[derive(Debug, Clone)]
pub struct CompileResult {
    /// Optimized TASM instructions.
    pub tasm_lines: Vec<String>,
    /// Table cost (clock cycles) of the result.
    pub cost: u64,
    /// How many beam candidates were valid.
    pub valid_count: usize,
    /// Total beam candidates evaluated.
    pub total_count: usize,
    /// Whether this is a neural result (true) or fallback (false).
    pub neural: bool,
}

/// Compile TIR ops to optimized TASM using the neural model.
///
/// Convenience entry point: runs [`compile_with_device`] on the default
/// WGPU device. Loads the production checkpoint, runs beam search
/// (K=32, max_steps=256), validates candidates against baseline TASM,
/// and returns the cheapest valid one.
///
/// Falls back to `baseline_tasm` if no valid candidate is found or if
/// no trained checkpoint exists.
pub fn compile(tir_ops: &[TIROp], baseline_tasm: &[String]) -> Result<CompileResult, String> {
    compile_with_device::<Wgpu>(
        tir_ops,
        baseline_tasm,
        &burn::backend::wgpu::WgpuDevice::default(),
    )
}

/// Compile TIR ops with a specific burn backend device.
pub fn compile_with_device<B: burn::prelude::Backend>(
    tir_ops: &[TIROp],
    baseline_tasm: &[String],
    device: &B::Device,
) -> Result<CompileResult, String> {
    let vocab = Vocab::new();

    // Build graph from TIR
    let graph = TirGraph::from_tir_ops(tir_ops);
    if graph.nodes.is_empty() {
        return Ok(fallback_result(baseline_tasm));
    }

    // Load production checkpoint
    let config = model::composite::NeuralCompilerConfig::new();
    let model = config.init::<B>(device);
    let model =
        match checkpoint::load_checkpoint(model, checkpoint::CheckpointTag::Production, device) {
            Ok(Some(loaded)) => loaded,
            Ok(None) => {
                // No checkpoint โ€” try stage1_best as fallback
                let model2 = config.init::<B>(device);
                match checkpoint::load_checkpoint(
                    model2,
                    checkpoint::CheckpointTag::Stage1Best,
                    device,
                ) {
                    Ok(Some(loaded)) => loaded,
                    _ => return Ok(fallback_result(baseline_tasm)),
                }
            }
            Err(_) => return Ok(fallback_result(baseline_tasm)),
        };

    // Encode graph
    let node_features = graph_to_features::<B>(&graph, device);
    let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, device);

    // Beam search
    let beam_config = BeamConfig::default(); // K=32, max_steps=256
    let beam_result = beam_search(
        &model.encoder,
        &model.decoder,
        node_features,
        edge_src,
        edge_dst,
        edge_types,
        &beam_config,
        0, // must match training initial_stack_depth
        device,
    );

    // Validate and rank
    match validate_and_rank(&beam_result.sequences, &vocab, baseline_tasm, 0) {
        Some(ranked) => Ok(CompileResult {
            tasm_lines: ranked.tasm_lines,
            cost: ranked.cost,
            valid_count: ranked.valid_count,
            total_count: ranked.total_count,
            neural: true,
        }),
        None => Ok(fallback_result(baseline_tasm)),
    }
}

/// Load the trained model once, for use with `compile_with_model`.
///
/// Tries checkpoints in priority order: `Production` first, then
/// `Stage1Best`. Returns `None` if neither exists or loading fails.
pub fn load_model<B: burn::prelude::Backend>(
    device: &B::Device,
) -> Option<model::composite::NeuralCompilerV2<B>> {
    let config = model::composite::NeuralCompilerConfig::new();
    // Each attempt needs a freshly initialized model to load weights into.
    let tags = [
        checkpoint::CheckpointTag::Production,
        checkpoint::CheckpointTag::Stage1Best,
    ];
    for tag in tags {
        let fresh = config.init::<B>(device);
        if let Ok(Some(loaded)) = checkpoint::load_checkpoint(fresh, tag, device) {
            return Some(loaded);
        }
    }
    None
}

/// Compile TIR ops using a pre-loaded model (avoids repeated checkpoint loading).
///
/// Builds the TIR graph, encodes it, runs beam search with the default
/// configuration, then validates candidates against `baseline_tasm` and
/// returns the cheapest valid one — or the baseline as a fallback when
/// the graph is empty or no candidate validates.
pub fn compile_with_model<B: burn::prelude::Backend>(
    tir_ops: &[TIROp],
    baseline_tasm: &[String],
    model: &model::composite::NeuralCompilerV2<B>,
    device: &B::Device,
) -> Result<CompileResult, String> {
    let graph = TirGraph::from_tir_ops(tir_ops);
    if graph.nodes.is_empty() {
        // Nothing to compile — hand back the baseline unchanged.
        return Ok(fallback_result(baseline_tasm));
    }

    // Encode the graph into tensors for the GNN encoder.
    let features = graph_to_features::<B>(&graph, device);
    let (src, dst, edge_kinds) = graph_to_edges::<B>(&graph, device);

    let search = beam_search(
        &model.encoder,
        &model.decoder,
        features,
        src,
        dst,
        edge_kinds,
        &BeamConfig::default(),
        0, // initial stack depth; must match the value used during training
        device,
    );

    // Re-execute candidates, keep only valid ones, rank by table cost.
    let vocab = Vocab::new();
    let outcome = match validate_and_rank(&search.sequences, &vocab, baseline_tasm, 0) {
        Some(best) => CompileResult {
            tasm_lines: best.tasm_lines,
            cost: best.cost,
            valid_count: best.valid_count,
            total_count: best.total_count,
            neural: true,
        },
        None => fallback_result(baseline_tasm),
    };
    Ok(outcome)
}

/// Build a non-neural `CompileResult` that returns `baseline_tasm` as-is,
/// profiled for its table cost.
fn fallback_result(baseline_tasm: &[String]) -> CompileResult {
    use crate::cost::scorer::profile_tasm;
    // `profile_tasm` wants `&[&str]`, so borrow each line.
    let borrowed: Vec<&str> = baseline_tasm.iter().map(String::as_str).collect();
    CompileResult {
        tasm_lines: baseline_tasm.to_vec(),
        cost: profile_tasm(&borrowed).cost(),
        valid_count: 0,
        total_count: 0,
        neural: false,
    }
}

Dimensions

trident/src/diagnostic/mod.rs
trident/src/ir/mod.rs
trident/src/deploy/mod.rs
trident/src/syntax/mod.rs
trident/src/api/mod.rs
nebu/rs/extension/mod.rs
optica/src/render/mod.rs
trident/src/config/mod.rs
trident/src/field/mod.rs
trident/src/cli/mod.rs
optica/src/parser/mod.rs
trident/src/cost/mod.rs
trident/src/typecheck/mod.rs
optica/src/server/mod.rs
trident/src/package/mod.rs
optica/src/scanner/mod.rs
optica/src/output/mod.rs
trident/src/verify/mod.rs
optica/src/graph/mod.rs
trident/src/ast/mod.rs
trident/src/lsp/mod.rs
trident/src/runtime/mod.rs
trident/src/gpu/mod.rs
optica/src/query/mod.rs
trident/src/lsp/semantic/mod.rs
trident/src/verify/equiv/mod.rs
trident/src/package/hash/mod.rs
trident/src/neural/training/mod.rs
trident/src/verify/synthesize/mod.rs
trident/src/ir/tir/mod.rs
rs/macros/src/addressed/mod.rs
trident/src/package/registry/mod.rs
rs/rsc/src/lints/mod.rs
trident/src/verify/report/mod.rs
trident/src/config/resolve/mod.rs
trident/src/verify/solve/mod.rs
rs/macros/src/registers/mod.rs
trident/src/verify/smt/mod.rs
rs/macros/src/cell/mod.rs
rs/core/src/fixed_point/mod.rs
trident/src/neural/data/mod.rs
rs/core/src/bounded/mod.rs
trident/src/lsp/util/mod.rs
trident/src/typecheck/tests/mod.rs
trident/src/neural/model/mod.rs
trident/src/cost/stack_verifier/mod.rs
trident/src/syntax/grammar/mod.rs
trident/src/package/manifest/mod.rs
trident/src/syntax/parser/mod.rs
trident/src/ir/kir/mod.rs
trident/src/neural/inference/mod.rs
trident/src/syntax/lexer/mod.rs
trident/src/cost/model/mod.rs
trident/src/ir/lir/mod.rs
trident/src/syntax/format/mod.rs
trident/src/config/scaffold/mod.rs
trident/src/verify/sym/mod.rs
trident/src/api/tests/mod.rs
trident/src/package/store/mod.rs
trident/src/ir/tree/mod.rs
trident/src/ir/kir/lower/mod.rs
trident/src/ir/lir/lower/mod.rs
trident/src/ir/tir/lower/mod.rs
trident/src/ir/tir/builder/mod.rs
trident/src/ir/tir/neural/mod.rs
trident/src/neural/data/tir_graph/mod.rs
trident/src/syntax/parser/tests/mod.rs
cw-cyber/packages/cyber-std/src/tokenfactory/mod.rs
trident/src/ir/tree/lower/mod.rs
trident/src/ir/tir/stack/mod.rs
cw-cyber/contracts/cybernet/src/tests/mod.rs
trident/src/ir/tir/optimize/mod.rs

Local Graph