use std::path::{Path, PathBuf};
use std::process;

use clap::Args;

use super::trisha::{
    generate_program_harness, generate_test_harness, run_trisha, trisha_available, Harness,
};

use burn::backend::wgpu::{Wgpu, WgpuDevice};
use trident::neural::model::composite::NeuralCompilerV2;

// CLI arguments for the `bench` subcommand.
//
// NOTE: the `///` doc comments on the fields below are surfaced verbatim by
// clap as `--help` text, so they are user-facing strings — edit with care.
#[derive(Args)]
pub struct BenchArgs {
    /// Directory containing baseline .tasm files (mirrors source tree)
    #[arg(default_value = "baselines/triton")]
    pub dir: PathBuf,
    /// Run all checks: compile, execute, prove, verify
    #[arg(long)]
    pub full: bool,
    /// Show per-function instruction breakdown
    #[arg(long)]
    pub functions: bool,
    /// Skip neural model compilation (faster)
    #[arg(long)]
    pub skip_neural: bool,
}

/// Timing triplet for a single dimension: execute, prove, verify (ms).
///
/// Each field is `None` until the corresponding pass succeeds.
#[derive(Default)]
struct DimTiming {
    /// Wall-clock execution time in ms; `None` if the run failed or was skipped.
    exec_ms: Option<f64>,
    /// Proving time in ms; `None` on failure/skip.
    prove_ms: Option<f64>,
    /// Verification time in ms; set by `verify_dimension` only when a proof
    /// file exists and verifies successfully.
    verify_ms: Option<f64>,
    /// Proof artifact written by the prove pass; consumed by the verify pass
    /// and deleted during final cleanup.
    proof_path: Option<PathBuf>,
}

/// Per-module benchmark data across all dimensions.
struct ModuleBench {
    /// Module name derived from the source path (e.g. `std::crypto::poseidon2`).
    name: String,
    /// Instruction count of the Trident-compiled ("classic") TASM.
    classic_insn: usize,
    /// Instruction count of the hand-written baseline TASM.
    hand_insn: usize,
    /// Instruction count of the neurally-compiled TASM (0 when unavailable).
    neural_insn: usize,
    /// Rust-native compilation time (ms)
    compile_ms: f64,
    /// Rust reference execution time (nanoseconds per op), if available
    rust_ns: Option<u64>,
    /// Timings for the Trident-compiled dimension.
    classic: DimTiming,
    /// Timings for the hand-written baseline dimension.
    hand: DimTiming,
    /// Timings for the neural dimension.
    neural: DimTiming,
    /// Per-function breakdown (only collected with --functions)
    functions: Vec<trident::FunctionBenchmark>,
}

/// Entry point for `trident bench`.
///
/// Discovers `.tasm` baselines under the bench directory, compiles the
/// matching `.tri` sources, optionally runs execute/prove/verify passes
/// (`--full`), and renders either an instruction-count table or a full
/// timing table.
pub fn cmd_bench(args: BenchArgs) {
    let bench_dir = resolve_bench_dir(&args.dir);
    if !bench_dir.is_dir() {
        eprintln!("error: '{}' is not a directory", args.dir.display());
        process::exit(1);
    }

    let project_root = find_project_root(&bench_dir);
    let baselines_root = project_root.join("baselines/triton");

    let mut baselines = find_baseline_files(&bench_dir, 0);
    baselines.sort();

    if baselines.is_empty() {
        eprintln!("No .tasm baselines found in '{}'", bench_dir.display());
        process::exit(1);
    }

    let options = trident::CompileOptions::default();
    let has_trisha = args.full && trisha_available();

    // Load neural model once for all modules (unless --skip-neural)
    let wgpu_device = WgpuDevice::default();
    let neural_model: Option<NeuralCompilerV2<Wgpu>> = if args.skip_neural {
        None
    } else {
        let m = trident::neural::load_model::<Wgpu>(&wgpu_device);
        if m.is_some() {
            eprintln!("  Neural model loaded.");
        }
        m
    };

    // Collect data for each module
    let mut modules: Vec<ModuleBench> = Vec::new();

    for baseline_path in &baselines {
        // The baselines tree mirrors the source tree, so
        // `baselines/triton/a/b.tasm` pairs with `a/b.tri`.
        let rel = baseline_path
            .strip_prefix(&baselines_root)
            .unwrap_or(baseline_path);
        let rel_str = rel.to_string_lossy();
        let source_rel = rel_str.replace(".tasm", ".tri");
        let source_path = project_root.join(&source_rel);
        let module_name = source_rel.trim_end_matches(".tri").replace('/', "::");

        if !source_path.exists() {
            continue;
        }

        // Read baseline TASM
        let baseline_tasm = match std::fs::read_to_string(baseline_path) {
            Ok(s) => s,
            Err(_) => continue,
        };

        // Compile module (instruction count) + time it
        let compile_start = std::time::Instant::now();
        let _guard = trident::diagnostic::suppress_warnings();
        let compiled_tasm = match trident::compile_module(&source_path, &options) {
            Ok(t) => t,
            Err(_) => continue,
        };
        drop(_guard);
        let compile_ms = compile_start.elapsed().as_secs_f64() * 1000.0;

        // Parse per-function instruction counts
        let compiled_fns = trident::parse_tasm_functions(&compiled_tasm);
        let baseline_fns = trident::parse_tasm_functions(&baseline_tasm);

        let mut fn_results: Vec<trident::FunctionBenchmark> = Vec::new();
        let mut total_compiled: usize = 0;
        let mut total_baseline: usize = 0;

        for (name, &baseline_count) in &baseline_fns {
            let compiled_count = compiled_fns.get(name).copied().unwrap_or(0);
            total_compiled += compiled_count;
            total_baseline += baseline_count;
            if args.functions {
                fn_results.push(trident::FunctionBenchmark {
                    name: name.clone(),
                    compiled_instructions: compiled_count,
                    baseline_instructions: baseline_count,
                });
            }
        }

        // Run Rust reference benchmark if available
        let ref_rs = project_root
            .join("benches/references")
            .join(rel_str.replace(".tasm", ".rs"));
        let rust_ns = if args.full && ref_rs.exists() {
            let rel = ref_rs.strip_prefix(project_root).unwrap_or(&ref_rs);
            run_rust_reference(&rel.to_string_lossy())
        } else {
            None
        };

        // Neural: compile per-function via neural model
        let neural_tasm_opt = neural_model.as_ref().and_then(|model| {
            compile_neural_tasm_inline(
                &source_path,
                &compiled_tasm,
                &options,
                model,
                &wgpu_device,
            )
        });
        // Count only real instructions: skip blanks, comments, labels, and
        // the trailing `halt` appended for executability.
        let neural_insn_count = neural_tasm_opt
            .as_ref()
            .map(|t| {
                t.lines()
                    .filter(|l| {
                        let s = l.trim();
                        !s.is_empty() && !s.starts_with("//") && !s.ends_with(':') && s != "halt"
                    })
                    .count()
            })
            .unwrap_or(0);

        let mut mb = ModuleBench {
            name: module_name.clone(),
            classic_insn: total_compiled,
            hand_insn: total_baseline,
            neural_insn: neural_insn_count,
            compile_ms,
            rust_ns,
            classic: DimTiming::default(),
            hand: DimTiming::default(),
            neural: DimTiming::default(),
            functions: fn_results,
        };

        // Run trisha passes for --full
        if has_trisha {
            // Check for .inputs file for live harness mode
            let harness_dir = project_root.join("benches/harnesses");
            let inputs_path = harness_dir.join(rel_str.replace(".tasm", ".inputs"));
            let live_inputs = if inputs_path.exists() {
                parse_inputs_file(&inputs_path)
            } else {
                None
            };

            if let Some(ref li) = live_inputs {
                // Live mode: compile harness .tri as linked program, transform for execution.
                // This resolves all cross-module calls (unlike compile_module).
                let bench_tri = harness_dir.join(rel_str.replace(".tasm", ".tri"));
                // Also check for pre-compiled harness .tasm
                let bench_tasm_path = harness_dir.join(rel_str.replace(".tasm", "_bench.tasm"));

                let linked_tasm = if bench_tasm_path.exists() {
                    std::fs::read_to_string(&bench_tasm_path).ok()
                } else if bench_tri.exists() {
                    let _guard2 = trident::diagnostic::suppress_warnings();
                    let result =
                        trident::compile_project_with_options(&bench_tri, &options).ok();
                    drop(_guard2);
                    result
                } else {
                    None
                };

                if let Some(ref tasm) = linked_tasm {
                    let harness = generate_program_harness(tasm, &li.values, &li.divine);
                    run_dimension(&mut mb.classic, &module_name, "classic", &harness);
                }
                // Hand baseline: use test harness (hand TASM is single-module)
                let hand_harness = generate_test_harness(&baseline_tasm);
                run_dimension(&mut mb.hand, &module_name, "hand", &hand_harness);
            } else {
                // Standard mode: single-module test harness (loops execute once).
                // Reuse the TASM compiled above for instruction counting —
                // recompiling the module here would duplicate that work.
                let classic_harness = generate_test_harness(&compiled_tasm);
                run_dimension(&mut mb.classic, &module_name, "classic", &classic_harness);

                let hand_harness = generate_test_harness(&baseline_tasm);
                run_dimension(&mut mb.hand, &module_name, "hand", &hand_harness);

                // Neural: use inline-compiled neural TASM
                if let Some(ref neural_tasm) = neural_tasm_opt {
                    if !neural_tasm.is_empty() {
                        let neural_harness = generate_test_harness(neural_tasm);
                        run_dimension(&mut mb.neural, &module_name, "neural", &neural_harness);
                    }
                }
            }
        }

        // Show progress (carriage return keeps it on one line; trailing
        // padding overwrites any longer previous module name).
        eprint!("\r  collecting {}...{}", module_name, " ".repeat(30));
        use std::io::Write;
        let _ = std::io::stderr().flush();

        modules.push(mb);
    }

    // Verify pass (needs proof files from prove pass)
    if has_trisha {
        for mb in &mut modules {
            verify_dimension(&mut mb.classic);
            verify_dimension(&mut mb.hand);
            verify_dimension(&mut mb.neural);
        }
    }

    // Clear progress line
    eprint!("\r{}\r", " ".repeat(80));

    if modules.is_empty() {
        eprintln!("No benchmarks could be compiled.");
        process::exit(1);
    }

    // Render unified table
    eprintln!();
    if args.full {
        render_full_table(&modules, args.functions);
    } else {
        render_insn_table(&modules, args.functions);
    }

    // Clean up proof files
    for mb in &modules {
        for dim in [&mb.classic, &mb.hand, &mb.neural] {
            if let Some(ref path) = dim.proof_path {
                let _ = std::fs::remove_file(path);
            }
        }
    }

    eprintln!();
}

/// Render instruction-count-only table (default, no --full).
fn render_insn_table(modules: &[ModuleBench], show_functions: bool) {
    // Name-column width: widest module name (min 6), plus padding.
    let name_w = modules
        .iter()
        .map(|m| m.name.len())
        .max()
        .unwrap_or(40)
        .max(6)
        + 2;

    // "compiled/baseline" as a ratio, or "-" when there is no baseline.
    let ratio_of = |compiled: usize, baseline: usize| -> String {
        if baseline > 0 {
            format!("{:.2}x", compiled as f64 / baseline as f64)
        } else {
            "-".to_string()
        }
    };
    // A count rendered as text, mapping 0 to "-".
    let count_or_dash = |n: usize| -> String {
        if n > 0 {
            n.to_string()
        } else {
            "-".to_string()
        }
    };

    eprintln!(
        "{:<w$} {:>6} {:>6} {:>6} {:>7}",
        "Module",
        "Tri",
        "Hand",
        "Neural",
        "Ratio",
        w = name_w,
    );
    eprintln!("{}", "-".repeat(name_w + 30));

    for mb in modules {
        eprintln!(
            "{:<w$} {:>6} {:>6} {:>6} {:>7}",
            mb.name,
            mb.classic_insn,
            mb.hand_insn,
            count_or_dash(mb.neural_insn),
            ratio_of(mb.classic_insn, mb.hand_insn),
            w = name_w,
        );
        if show_functions {
            for f in &mb.functions {
                eprintln!(
                    "  {:<fw$} {:>6} {:>6} {:>6} {:>7}",
                    f.name,
                    count_or_dash(f.compiled_instructions),
                    f.baseline_instructions,
                    "", // per-function neural not tracked here
                    ratio_of(f.compiled_instructions, f.baseline_instructions),
                    fw = name_w - 2,
                );
            }
        }
    }

    // Totals row across all modules.
    eprintln!("{}", "-".repeat(name_w + 30));
    let sum_classic: usize = modules.iter().map(|m| m.classic_insn).sum();
    let sum_hand: usize = modules.iter().map(|m| m.hand_insn).sum();
    let sum_neural: usize = modules.iter().map(|m| m.neural_insn).sum();
    eprintln!(
        "{:<w$} {:>6} {:>6} {:>6} {:>7}",
        format!("TOTAL ({} modules)", modules.len()),
        sum_classic,
        sum_hand,
        count_or_dash(sum_neural),
        ratio_of(sum_classic, sum_hand),
        w = name_w,
    );
}

/// Format a millisecond value, or "-" if None.
fn fmt_ms(ms: Option<f64>) -> String {
    match ms {
        Some(v) => format!("{:.0}ms", v),
        None => "-".into(),
    }
}

/// Compact verify status for a row: shows PASS/FAIL based on best result across dimensions.
fn fmt_verify_row(classic: &DimTiming, hand: &DimTiming, neural: &DimTiming) -> &'static str {
    let dims = [classic, hand, neural];
    if dims.iter().any(|d| d.verify_ms.is_some()) {
        // At least one dimension verified successfully.
        "PASS"
    } else if dims.iter().any(|d| d.proof_path.is_some()) {
        // Proofs were produced but none verified.
        "FAIL"
    } else {
        // Nothing was even proven.
        "-"
    }
}

/// Render full 4D table: grouped by step (Exec | Prove | Verify), sub-columns C/H/N.
fn render_full_table(modules: &[ModuleBench], show_functions: bool) {
    // Name-column width: widest module name (min 6), plus padding.
    let w = modules
        .iter()
        .map(|m| m.name.len())
        .max()
        .unwrap_or(40)
        .max(6)
        + 2;
    // Fixed width of everything to the right of the name column (used for
    // the horizontal rules only).
    let total_w = w + 104;
    // Header: Module  Compile  Rust | Exec (C H N) | Prove (C H N) | Verify (C H N) | Ratio
    eprintln!(
        "{:<w$} {:>7} {:>7}  | {:>5} {:>5} {:>5} | {:>7} {:>7} {:>7} | {:>5} {:>5} {:>5} {:>4} | {:>5}",
        "Module", "Compile", "Rust",
        "C", "H", "N",
        "C", "H", "N",
        "C", "H", "N", "",
        "Ratio",
        w = w,
    );
    // Second header row labels the three grouped sections under the C/H/N
    // sub-columns above.
    eprintln!(
        "{:<w$} {:>7} {:>7}  | {:<17} | {:<23} | {:<22} | {:>5}",
        "",
        "",
        "",
        "Exec",
        "Prove",
        "Verify",
        "",
        w = w,
    );
    eprintln!("{}", "-".repeat(total_w));

    for mb in modules {
        // Instruction-count ratio compiled/hand; "-" when no hand baseline.
        let ratio = if mb.hand_insn > 0 {
            format!("{:.2}x", mb.classic_insn as f64 / mb.hand_insn as f64)
        } else {
            "-".to_string()
        };

        eprintln!(
            "{:<w$} {:>7} {:>7}  | {:>5} {:>5} {:>5} | {:>7} {:>7} {:>7} | {:>5} {:>5} {:>5} {:>4} | {:>5}",
            mb.name,
            format!("{:.1}ms", mb.compile_ms),
            fmt_rust(mb.rust_ns),
            fmt_ms(mb.classic.exec_ms), fmt_ms(mb.hand.exec_ms), fmt_ms(mb.neural.exec_ms),
            fmt_ms(mb.classic.prove_ms), fmt_ms(mb.hand.prove_ms), fmt_ms(mb.neural.prove_ms),
            fmt_ms(mb.classic.verify_ms), fmt_ms(mb.hand.verify_ms), fmt_ms(mb.neural.verify_ms),
            fmt_verify_row(&mb.classic, &mb.hand, &mb.neural),
            ratio,
            w = w,
        );

        if show_functions {
            // Optional per-function breakdown (instruction counts only;
            // timings are not collected per function).
            for f in &mb.functions {
                let fr = if f.baseline_instructions > 0 {
                    format!(
                        "{:.2}x",
                        f.compiled_instructions as f64 / f.baseline_instructions as f64
                    )
                } else {
                    "-".to_string()
                };
                eprintln!(
                    "  {:<fw$} {:>5}/{:<5} {}",
                    f.name,
                    if f.compiled_instructions > 0 {
                        f.compiled_instructions.to_string()
                    } else {
                        "-".to_string()
                    },
                    f.baseline_instructions,
                    fr,
                    fw = w - 2,
                );
            }
        }
    }

    eprintln!("{}", "-".repeat(total_w));

    // Summary row
    let sum_classic: usize = modules.iter().map(|m| m.classic_insn).sum();
    let sum_hand: usize = modules.iter().map(|m| m.hand_insn).sum();
    let avg_ratio = if sum_hand > 0 {
        format!("{:.2}x", sum_classic as f64 / sum_hand as f64)
    } else {
        "-".to_string()
    };

    let total_compile: f64 = modules.iter().map(|m| m.compile_ms).sum();
    let total_rust_ns: u64 = modules.iter().filter_map(|m| m.rust_ns).sum();
    let rust_count = modules.iter().filter(|m| m.rust_ns.is_some()).count();
    // Sums one timing field of one dimension across all modules, skipping
    // modules where that timing is absent.
    let sum_dim_col = |modules: &[ModuleBench],
                       dim: fn(&ModuleBench) -> &DimTiming,
                       get: fn(&DimTiming) -> Option<f64>|
     -> f64 { modules.iter().filter_map(|m| get(dim(m))).sum() };
    let classic_exec: f64 = sum_dim_col(modules, |m| &m.classic, |d| d.exec_ms);
    let classic_prove: f64 = sum_dim_col(modules, |m| &m.classic, |d| d.prove_ms);
    let classic_verify: f64 = sum_dim_col(modules, |m| &m.classic, |d| d.verify_ms);
    let hand_exec: f64 = sum_dim_col(modules, |m| &m.hand, |d| d.exec_ms);
    let hand_prove: f64 = sum_dim_col(modules, |m| &m.hand, |d| d.prove_ms);
    let hand_verify: f64 = sum_dim_col(modules, |m| &m.hand, |d| d.verify_ms);
    let neural_exec: f64 = sum_dim_col(modules, |m| &m.neural, |d| d.exec_ms);
    let neural_prove: f64 = sum_dim_col(modules, |m| &m.neural, |d| d.prove_ms);
    let neural_verify: f64 = sum_dim_col(modules, |m| &m.neural, |d| d.verify_ms);

    // How many modules fully verified per dimension.
    let classic_verified = modules
        .iter()
        .filter(|m| m.classic.verify_ms.is_some())
        .count();
    let hand_verified = modules
        .iter()
        .filter(|m| m.hand.verify_ms.is_some())
        .count();
    let neural_verified = modules
        .iter()
        .filter(|m| m.neural.verify_ms.is_some())
        .count();
    let n = modules.len();

    let rust_total_str = if rust_count > 0 {
        fmt_rust(Some(total_rust_ns))
    } else {
        "-".into()
    };

    // Formats a summed timing, or "-" when the dimension has no data.
    let fmt_t = |v: f64, has: bool| -> String {
        if has {
            format!("{:.0}ms", v)
        } else {
            "-".into()
        }
    };

    // NOTE(review): the exec/prove totals below are displayed only when at
    // least one module *verified* in that dimension (`*_verified > 0`), not
    // when it merely executed or proved — confirm this gating is intended.
    eprintln!(
        "{:<w$} {:>7} {:>7}  | {:>5} {:>5} {:>5} | {:>7} {:>7} {:>7} | {:>5} {:>5} {:>5} {:>4} | {:>5}",
        format!("TOTAL ({} modules)", n),
        format!("{:.0}ms", total_compile),
        rust_total_str,
        fmt_t(classic_exec, classic_verified > 0), fmt_t(hand_exec, hand_verified > 0), fmt_t(neural_exec, neural_verified > 0),
        fmt_t(classic_prove, classic_verified > 0), fmt_t(hand_prove, hand_verified > 0), fmt_t(neural_prove, neural_verified > 0),
        fmt_t(classic_verify, classic_verified > 0), fmt_t(hand_verify, hand_verified > 0), fmt_t(neural_verify, neural_verified > 0),
        format!("{}/{}", classic_verified, n),
        avg_ratio,
        w = w,
    );
    // Footer: total instruction counts for the compiled and hand dimensions.
    eprintln!(
        "{:<w$} {:>7} {:>7}  | insn: {:<11} |",
        "",
        "",
        "",
        format!("{}C / {}H", sum_classic, sum_hand),
        w = w,
    );
}

/// Run trisha with a timeout. Kills the process if it exceeds the deadline.
///
/// stdout is discarded (the caller never consumes it) and stderr is drained
/// on a background thread. Both measures prevent the child from blocking
/// forever on a full OS pipe buffer while we poll `try_wait` — previously
/// stdout was piped but never read, so a chatty child could stall until the
/// timeout killed it.
fn run_trisha_timed(
    base_args: &[&str],
    harness: &Harness,
    timeout: std::time::Duration,
) -> Result<super::trisha::TrishaResult, String> {
    use super::trisha::trisha_args_with_inputs;
    use std::io::Read;

    let args = trisha_args_with_inputs(base_args, harness);
    let start = std::time::Instant::now();
    let mut child = std::process::Command::new("trisha")
        .args(&args)
        // Output is never read; piping it would risk a pipe-buffer deadlock.
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| format!("failed to spawn trisha: {}", e))?;

    // Drain stderr concurrently so the child cannot block writing to it.
    let mut stderr_reader = child.stderr.take().map(|mut pipe| {
        std::thread::spawn(move || {
            let mut buf = String::new();
            let _ = pipe.read_to_string(&mut buf);
            buf
        })
    });

    loop {
        match child.try_wait() {
            Ok(Some(status)) => {
                let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
                // Child has exited; the drain thread finishes promptly.
                let stderr_buf = stderr_reader
                    .take()
                    .and_then(|h| h.join().ok())
                    .unwrap_or_default();
                if !status.success() {
                    // Surface a trimmed stderr excerpt for debugging,
                    // filtering boilerplate banner lines.
                    let msg: String = stderr_buf
                        .lines()
                        .filter(|l| !l.starts_with("GPU:") && !l.starts_with("Backend:"))
                        .take(5)
                        .collect::<Vec<_>>()
                        .join("\n");
                    if !msg.is_empty() {
                        return Err(msg);
                    }
                    return Err("failed".to_string());
                }
                return Ok(super::trisha::TrishaResult {
                    output: Vec::new(),
                    cycle_count: 0,
                    elapsed_ms,
                });
            }
            Ok(None) => {
                // Still running: enforce the deadline, then poll again.
                if start.elapsed() > timeout {
                    let _ = child.kill();
                    let _ = child.wait();
                    return Err("timed out".to_string());
                }
                std::thread::sleep(std::time::Duration::from_millis(5));
            }
            Err(e) => return Err(format!("wait error: {}", e)),
        }
    }
}

/// Run execute + prove for a single dimension, writing results into DimTiming.
fn run_dimension(dim: &mut DimTiming, module_name: &str, label: &str, harness: &Harness) {
    let sanitized = module_name.replace("::", "_");
    let timeout = std::time::Duration::from_secs(300);

    // Write the harness TASM to a temp file; bail silently if that fails.
    let tmp_path = std::env::temp_dir().join(format!(
        "trident_bench_{}_{}.tasm",
        sanitized, label,
    ));
    if std::fs::write(&tmp_path, &harness.tasm).is_err() {
        return;
    }
    let tmp_str = tmp_path.to_string_lossy().to_string();

    // Execute pass (5min timeout).
    let exec_args = ["run", "--tasm", tmp_str.as_str()];
    match run_trisha_timed(&exec_args, harness, timeout) {
        Ok(r) => dim.exec_ms = Some(r.elapsed_ms),
        Err(e) => eprintln!("\n  [{}:{}] exec error: {}", module_name, label, e),
    }

    // Prove pass (5min timeout); on success remember the proof artifact so
    // the later verify pass can consume it.
    let proof_path = std::env::temp_dir().join(format!(
        "trident_bench_{}_{}.proof.toml",
        sanitized, label,
    ));
    let proof_str = proof_path.to_string_lossy().to_string();
    let prove_args = [
        "prove",
        "--tasm",
        tmp_str.as_str(),
        "--output",
        proof_str.as_str(),
    ];
    if let Ok(r) = run_trisha_timed(&prove_args, harness, timeout) {
        dim.prove_ms = Some(r.elapsed_ms);
        if proof_path.exists() {
            dim.proof_path = Some(proof_path);
        }
    }

    let _ = std::fs::remove_file(&tmp_path);
}

/// Run verify for a dimension (requires proof_path from prove pass).
fn verify_dimension(dim: &mut DimTiming) {
    // Nothing to verify if the prove pass produced no proof file.
    let Some(ref proof_path) = dim.proof_path else {
        return;
    };
    let proof_arg = proof_path.to_string_lossy();
    if let Ok(r) = run_trisha(&["verify", &proof_arg]) {
        dim.verify_ms = Some(r.elapsed_ms);
    }
}

/// Format nanoseconds for display: ms for >= 1e6 ns, µs for >= 1000 ns,
/// ns otherwise, "-" if None.
fn fmt_rust(ns: Option<u64>) -> String {
    match ns {
        None => "-".into(),
        Some(n) if n >= 1_000_000 => format!("{:.1}ms", n as f64 / 1_000_000.0),
        // The micro sign was previously mis-encoded as "ยต" (UTF-8 "µ"
        // decoded with the wrong charset); fixed to the proper "µs".
        Some(n) if n >= 1_000 => format!("{:.1}µs", n as f64 / 1_000.0),
        Some(n) => format!("{}ns", n),
    }
}

/// Run a Rust reference benchmark. Expects a `.rs` file in benches/references/
/// that is registered as a cargo example. Builds with --release, runs, parses
/// `rust_ns: <N>` from stdout.
fn run_rust_reference(ref_path: &str) -> Option<u64> {
    // Example name mirrors the path:
    // benches/references/std/crypto/poseidon2.rs -> ref_std_crypto_poseidon2
    let example_name = format!(
        "ref_{}",
        ref_path
            .trim_start_matches("benches/references/")
            .trim_end_matches(".rs")
            .replace('/', "_")
    );

    // Build first (near-instant when already cached); bail on any failure.
    let build_ok = std::process::Command::new("cargo")
        .args(["build", "--example", &example_name, "--release", "--quiet"])
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status()
        .ok()?
        .success();
    if !build_ok {
        return None;
    }

    // Run the example and capture its stdout.
    let output = std::process::Command::new("cargo")
        .args(["run", "--example", &example_name, "--release", "--quiet"])
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }

    // The example reports its timing as a stdout line `rust_ns: <N>`.
    String::from_utf8_lossy(&output.stdout)
        .lines()
        .find_map(|line| line.strip_prefix("rust_ns: ")?.trim().parse().ok())
}

/// Compile a module with neural optimization using a pre-loaded model.
///
/// Splits TIR into per-function blocks (matching training), runs neural
/// beam search on each, and assembles the result. For each function,
/// picks neural output if valid and cost <= compiler, else keeps compiler output.
///
/// Returns `None` when the module fails to build, has no functions, or when
/// no function actually won with a distinct neural result — callers then
/// fall back to the classical TASM. On success the assembled TASM is also
/// written to disk as a `.neural.tasm` cache.
fn compile_neural_tasm_inline(
    source_path: &Path,
    _classical_tasm: &str,
    options: &trident::CompileOptions,
    model: &NeuralCompilerV2<Wgpu>,
    device: &WgpuDevice,
) -> Option<String> {
    use trident::neural::data::pairs::split_tir_by_function;

    // Build TIR (with compiler warnings suppressed while the guard lives)
    let _guard = trident::diagnostic::suppress_warnings();
    let ir = match trident::build_tir_project(source_path, options) {
        Ok(ir) => ir,
        Err(_) => return None,
    };
    drop(_guard);

    let functions = split_tir_by_function(&ir);
    if functions.is_empty() {
        return None;
    }

    let lowering = trident::ir::tir::lower::create_stack_lowering(&options.target_config.name);
    let mut result_lines: Vec<String> = Vec::new();
    // Tracks whether at least one function got a genuinely neural result.
    let mut any_neural = false;

    for (fn_name, fn_tir) in &functions {
        if fn_name.starts_with("__") || fn_tir.is_empty() {
            // Keep compiler output for internal functions
            let fn_baseline = lowering.lower(fn_tir);
            result_lines.extend(fn_baseline);
            continue;
        }

        // Lower this function's TIR to get full compiler output (with labels)
        let fn_full = lowering.lower(fn_tir);

        // Extract just instructions (no labels/comments) for neural comparison
        let fn_insns: Vec<String> = fn_full
            .iter()
            .filter(|l| {
                let t = l.trim();
                !t.is_empty() && !t.ends_with(':') && !t.starts_with("//")
            })
            .map(|l| l.trim().to_string())
            .collect();

        if fn_insns.is_empty() {
            result_lines.extend(fn_full);
            continue;
        }

        // Classical cost baseline; floored at 1 so a zero-cost function
        // can still be tied by a neural candidate.
        let compiler_cost = trident::cost::scorer::profile_tasm(
            &fn_insns.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
        )
        .cost()
        .max(1);

        // Try neural compilation
        match trident::neural::compile_with_model(fn_tir, &fn_insns, model, device) {
            Ok(r) if r.neural && r.cost <= compiler_cost => {
                // Check if neural output is identical to compiler
                let identical = r.tasm_lines == fn_insns;
                if identical {
                    // Memorized compiler output — not a real neural win
                    result_lines.push(format!(
                        "// {}: compiler (cost {}) [neural=identical]",
                        fn_name, compiler_cost
                    ));
                    result_lines.extend(fn_full);
                } else {
                    any_neural = true;
                    result_lines.push(format!(
                        "// {}: NEURAL (cost {}, compiler {})",
                        fn_name, r.cost, compiler_cost
                    ));
                    result_lines.push(format!("__{}:", fn_name));
                    // Neural output may omit the trailing `return`; append
                    // one so control flow stays well-formed.
                    let needs_return = !r.tasm_lines.last().is_some_and(|l| l.trim() == "return");
                    result_lines.extend(r.tasm_lines);
                    if needs_return {
                        result_lines.push("return".to_string());
                    }
                }
            }
            _ => {
                // Use full compiler output (already has labels)
                result_lines.push(format!("// {}: compiler (cost {})", fn_name, compiler_cost));
                result_lines.extend(fn_full);
            }
        }
    }

    if !any_neural {
        return None;
    }

    // Add halt at end for executability
    result_lines.push("halt".to_string());

    // Also write .neural.tasm to disk as cache for future runs
    let classical_path = source_path.to_string_lossy();
    if let Some(bench_path) = derive_neural_tasm_path(&classical_path) {
        let _ = std::fs::write(&bench_path, result_lines.join("\n"));
    }

    Some(result_lines.join("\n"))
}

/// Derive the .neural.tasm path from a source .tri path.
/// E.g. std/crypto/poseidon2.tri -> baselines/triton/std/crypto/poseidon2.neural.tasm
fn derive_neural_tasm_path(source_path: &str) -> Option<PathBuf> {
    let source = Path::new(source_path);
    let file_stem = source.file_stem()?.to_string_lossy();
    let parent = source.parent()?;

    // Collect path components innermost-first while walking toward the
    // project root, stopping at the first ancestor that has a
    // `baselines/triton/` child directory.
    let mut rel_parts = vec![file_stem.to_string()];
    let mut current = parent;
    while let Some(name) = current.file_name() {
        rel_parts.push(name.to_string_lossy().to_string());
        current = current.parent()?;
        let baselines = current.join("baselines").join("triton");
        if baselines.is_dir() {
            // Rebuild the source-relative path outermost-first.
            rel_parts.reverse();
            let rel = rel_parts.join("/");
            return Some(baselines.join(format!("{}.neural.tasm", rel)));
        }
    }
    // Ran out of named ancestors without finding a baselines tree.
    None
}

/// Parsed live inputs from a `.inputs` file.
struct LiveInputs {
    values: Vec<u64>,
    divine: Vec<u64>,
}

/// Parse a `.inputs` file for live harness generation.
///
/// Format:
/// ```text
/// values: 1000, 2000, 3000, 8, 16, 64, ...
/// divine: 42, 17, 0, 3, ...
/// ```
///
/// Lines starting with `#` are comments. Blank lines ignored.
/// The `divine:` section is optional — when present, divine values
/// are inlined into the harness TASM (replacing `divine N` instructions
/// with `push <val>`), enabling phases that need prover hints.
/// Returns `None` on read failure or when no `values:` line is present.
fn parse_inputs_file(path: &Path) -> Option<LiveInputs> {
    // Parses a comma-separated list of u64s; unparsable entries are
    // silently skipped.
    fn parse_list(csv: &str) -> Vec<u64> {
        csv.split(',')
            .filter_map(|item| item.trim().parse().ok())
            .collect()
    }

    let content = std::fs::read_to_string(path).ok()?;
    let mut values: Option<Vec<u64>> = None;
    let mut divine: Vec<u64> = Vec::new();
    for raw in content.lines() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        if let Some(rest) = line.strip_prefix("values:") {
            values = Some(parse_list(rest));
        } else if let Some(rest) = line.strip_prefix("divine:") {
            divine = parse_list(rest);
        }
    }
    // `values:` is mandatory; `divine:` defaults to empty.
    Some(LiveInputs {
        values: values?,
        divine,
    })
}

/// Recursively find all .tasm baseline files in a directory (depth-limited).
/// Skips .neural.tasm and .formal.tasm (generated artifacts).
fn find_baseline_files(dir: &std::path::Path, depth: usize) -> Vec<PathBuf> {
    // Guard against pathological nesting / symlink cycles.
    if depth >= 64 {
        return Vec::new();
    }
    let entries = match std::fs::read_dir(dir) {
        Ok(e) => e,
        Err(_) => return Vec::new(),
    };
    let mut found = Vec::new();
    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            // Recurse into subdirectories, tracking depth.
            found.append(&mut find_baseline_files(&path, depth + 1));
            continue;
        }
        let is_baseline = path
            .file_name()
            .map(|n| {
                let n = n.to_string_lossy();
                n.ends_with(".tasm")
                    && !n.ends_with(".neural.tasm")
                    && !n.ends_with(".formal.tasm")
            })
            .unwrap_or(false);
        if is_baseline {
            found.push(path);
        }
    }
    found
}

/// Find the project root from a baselines directory.
///
/// Walks up from `bench_dir` looking for a parent that contains a `baselines/`
/// child. This handles both `trident bench` (bench_dir = `baselines/triton/`)
/// and `trident bench baselines/triton/std/crypto` (subdirectory).
fn find_project_root(bench_dir: &Path) -> &Path {
    let mut current = bench_dir;
    loop {
        // The component named `baselines` marks the boundary: its parent
        // is the project root.
        if matches!(current.file_name(), Some(n) if n == "baselines") {
            return current.parent().unwrap_or(Path::new("."));
        }
        match current.parent() {
            Some(parent) if parent != current => current = parent,
            // Ran out of ancestors without finding `baselines`; fall back
            // to the bench dir's own parent.
            _ => return bench_dir.parent().unwrap_or(Path::new(".")),
        }
    }
}

/// Resolve the bench directory by searching ancestor directories.
fn resolve_bench_dir(dir: &std::path::Path) -> PathBuf {
    // Fast path: the path already points at an existing directory.
    if dir.is_dir() {
        return dir.to_path_buf();
    }
    // For relative paths, try joining against each ancestor of the cwd —
    // this lets the command be run from anywhere inside the project.
    if dir.is_relative() {
        if let Ok(cwd) = std::env::current_dir() {
            let mut base = Some(cwd.as_path());
            while let Some(ancestor) = base {
                let candidate = ancestor.join(dir);
                if candidate.is_dir() {
                    return candidate;
                }
                base = ancestor.parent();
            }
        }
    }
    // Nothing matched; hand the path back unchanged.
    dir.to_path_buf()
}

Local Graph