//! SMT-LIB2 encoder for Trident constraint systems.
//!
//! Encodes the `ConstraintSystem` from `sym.rs` as SMT-LIB2 queries
//! compatible with Z3, CVC5, and other SMT solvers.
//!
//! Goldilocks field arithmetic is encoded using bitvector operations:
//! - Field elements as 128-bit bitvectors (to handle multiplication overflow)
//! - All operations mod p (p = 2^64 - 2^32 + 1)
//! - Equality checks, range constraints, conditional assertions
//!
//! The encoder produces two kinds of queries:
//! 1. **Safety check**: Is there an assignment that violates any constraint?
//!    (check-sat on negation of all constraints)
//! 2. **Witness existence**: For divine inputs, does a valid witness exist?
//!    (check-sat on all constraints)

use crate::sym::{Constraint, ConstraintSystem, SymValue, GOLDILOCKS_P};
use std::collections::BTreeSet;

/// Generate SMT-LIB2 encoding of a constraint system.
///
/// Returns the complete SMT-LIB2 script as a string.
pub fn encode_system(system: &ConstraintSystem, mode: QueryMode) -> String {
    let mut encoder = SmtEncoder::new(mode);
    encoder.encode(system);
    encoder.output
}

/// What kind of SMT query to generate.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueryMode {
    /// Check if constraints can be violated (negate and check-sat).
    /// SAT means a counterexample was found (bug); UNSAT means safe.
    SafetyCheck,
    /// Check if a valid witness exists for divine inputs.
    /// SAT means a witness was found; UNSAT means no valid witness exists.
    WitnessExistence,
}

/// Result of running an SMT solver.
///
/// NOTE(review): nothing in this file constructs this type; presumably the
/// `runner` submodule fills it in — confirm there.
#[derive(Clone, Debug)]
pub struct SmtResult {
    /// Raw solver output.
    pub output: String,
    /// Parsed result.
    pub status: SmtStatus,
    /// Model (variable assignments) if SAT.
    pub model: Option<String>,
}

/// Parsed verdict of a solver run, stored in [`SmtResult::status`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SmtStatus {
    /// The query is satisfiable.
    Sat,
    /// The query is unsatisfiable.
    Unsat,
    /// The solver could not decide the query.
    Unknown,
    /// The run failed; the payload is presumably a diagnostic message
    /// (TODO confirm against the runner).
    Error(String),
}

/// Stateful builder that renders one [`ConstraintSystem`] as SMT-LIB2 text.
struct SmtEncoder {
    /// SMT-LIB2 text accumulated so far.
    output: String,
    /// Query flavour selected by the caller.
    mode: QueryMode,
    /// Sanitized name of every SMT constant referenced so far.
    declared_vars: BTreeSet<String>,
    /// `(encoded operand, inverse constant)` for every distinct inverted
    /// term; used both to reuse constants and to emit their definitions.
    inverses: Vec<(String, String)>,
    /// Encoded argument lists of the distinct hash invocations seen so far.
    /// The position of an entry is the invocation's id.
    hash_calls: Vec<String>,
}

impl SmtEncoder {
    /// Create an empty encoder for the given query mode.
    fn new(mode: QueryMode) -> Self {
        Self {
            output: String::new(),
            mode,
            declared_vars: BTreeSet::new(),
            inverses: Vec::new(),
            hash_calls: Vec::new(),
        }
    }

    /// Append one line, followed by a newline, to the script.
    fn emit(&mut self, s: &str) {
        self.output.push_str(s);
        self.output.push('\n');
    }

    /// Emit the complete script for `system`.
    ///
    /// Layout: header, logic, prime and `field_mod` helpers, constant
    /// declarations, `< p` range assertions, definitions of field inverses
    /// (if any), the mode-specific query, and finally
    /// `check-sat` / `get-model` / `exit`.
    fn encode(&mut self, system: &ConstraintSystem) {
        // Header
        self.emit("; Trident SMT-LIB2 encoding");
        self.emit("; Generated by trident audit");
        self.emit(&format!("; Mode: {:?}", self.mode));
        self.emit(&format!("; Variables: {}", system.num_variables));
        self.emit(&format!("; Constraints: {}", system.constraints.len()));
        self.emit("");
        self.emit("(set-logic QF_BV)");
        self.emit("");

        // Define the Goldilocks prime as a constant
        self.emit("; Goldilocks prime: p = 2^64 - 2^32 + 1");
        self.emit(&format!(
            "(define-fun GOLDILOCKS_P () (_ BitVec 128) (_ bv{} 128))",
            GOLDILOCKS_P
        ));
        self.emit("");

        // Helper: field_mod(x) = x mod p (for 128-bit intermediate results)
        self.emit("; Field modular reduction");
        self.emit("(define-fun field_mod ((x (_ BitVec 128))) (_ BitVec 128)");
        self.emit("  (bvurem x GOLDILOCKS_P))");
        self.emit("");

        // Declare the variables the system itself lists.
        self.emit("; Variable declarations");
        self.declare_variables(system);

        // Encode the query into a side buffer *before* closing the
        // declaration section: encoding may discover further constants
        // (divine/public inputs, hashes, field accesses, inverses) and every
        // one of them must be declared and range-constrained ahead of the
        // assertions that mention it.
        let declared_up_front = self.declared_vars.clone();
        let prelude = std::mem::take(&mut self.output);
        self.encode_query(system);
        let query = std::mem::replace(&mut self.output, prelude);

        let late_vars: Vec<String> = self
            .declared_vars
            .difference(&declared_up_front)
            .cloned()
            .collect();
        for name in &late_vars {
            self.emit(&format!("(declare-fun {} () (_ BitVec 128))", name));
        }
        self.emit("");

        // Field range constraints: all constants < p. This also keeps the
        // 128-bit products of two field elements from wrapping.
        self.emit("; Field range constraints (all values < p)");
        let range_asserts: Vec<String> = self
            .declared_vars
            .iter()
            .map(|name| format!("(assert (bvult {} GOLDILOCKS_P))", name))
            .collect();
        for line in &range_asserts {
            self.emit(line);
        }
        self.emit("");

        // Tie each inverse constant to its operand. The definition is
        // guarded by `x != 0` so that an inverse appearing only under a
        // conditional does not silently rule out `x = 0` everywhere else.
        if !self.inverses.is_empty() {
            let inverse_defs: Vec<String> = self
                .inverses
                .iter()
                .map(|(operand, inv)| {
                    format!(
                        "(assert (=> (not (= {0} (_ bv0 128))) (= (field_mod (bvmul {1} {0})) (_ bv1 128))))",
                        operand, inv
                    )
                })
                .collect();
            self.emit("; Field inverse definitions (inv * x = 1 mod p when x != 0)");
            for line in &inverse_defs {
                self.emit(line);
            }
            self.emit("");
        }

        self.output.push_str(&query);

        self.emit("");
        self.emit("(check-sat)");
        self.emit("(get-model)");
        self.emit("(exit)");
    }

    /// Emit the mode-specific assertions for `system` into `self.output`.
    fn encode_query(&mut self, system: &ConstraintSystem) {
        match self.mode {
            QueryMode::SafetyCheck => {
                // We are looking for a COUNTEREXAMPLE, so we assert that at
                // least one constraint is violated:
                //   NOT (c1 AND ... AND cn)  ==  (not c1) OR ... OR (not cn)
                self.emit("; Safety check: can any constraint be violated?");
                self.emit("; (SAT = counterexample found, UNSAT = all constraints hold)");
                self.emit("");

                if system.constraints.is_empty() {
                    // The negation of an empty conjunction is `false`:
                    // with nothing to violate the query must be UNSAT.
                    self.emit("; No constraints to check");
                    self.emit("(assert false)");
                } else {
                    let mut disjuncts = Vec::with_capacity(system.constraints.len());
                    for (i, constraint) in system.constraints.iter().enumerate() {
                        let smt = self.encode_constraint(constraint);
                        self.emit(&format!("; Constraint #{}", i));
                        disjuncts.push(format!("(not {})", smt));
                    }

                    if disjuncts.len() == 1 {
                        self.emit(&format!("(assert {})", disjuncts[0]));
                    } else {
                        self.emit(&format!("(assert (or {}))", disjuncts.join(" ")));
                    }
                }
            }
            QueryMode::WitnessExistence => {
                // For witness: assert all constraints and check satisfiability.
                // SAT means valid values exist for the divine inputs.
                self.emit("; Witness existence: do valid divine values exist?");
                self.emit("; (SAT = witness found, UNSAT = no valid witness)");
                self.emit("");

                for (i, constraint) in system.constraints.iter().enumerate() {
                    let smt = self.encode_constraint(constraint);
                    self.emit(&format!("; Constraint #{}", i));
                    self.emit(&format!("(assert {})", smt));
                }
            }
        }
    }

    /// Register and declare every constant the system lists explicitly:
    /// each version of each variable, then public and divine inputs.
    ///
    /// Constants that only show up while encoding constraints are declared
    /// later by [`SmtEncoder::encode`].
    fn declare_variables(&mut self, system: &ConstraintSystem) {
        // Names in first-seen order, so declarations keep a stable layout.
        let mut var_names: Vec<String> = Vec::new();

        for (name, max_version) in &system.variables {
            for v in 0..=*max_version {
                // Version 0 uses the bare name, later versions `name_v`.
                let var_name = if v == 0 {
                    name.clone()
                } else {
                    format!("{}_{}", name, v)
                };
                let smt_name = sanitize_smt_name(&var_name);
                if self.declared_vars.insert(smt_name.clone()) {
                    var_names.push(smt_name);
                }
            }
        }

        // Also declare pub_input and divine variables
        for pi in &system.pub_inputs {
            let smt_name = sanitize_smt_name(&pi.to_string());
            if self.declared_vars.insert(smt_name.clone()) {
                var_names.push(smt_name);
            }
        }
        for di in &system.divine_inputs {
            let smt_name = sanitize_smt_name(&di.to_string());
            if self.declared_vars.insert(smt_name.clone()) {
                var_names.push(smt_name);
            }
        }

        for name in &var_names {
            self.emit(&format!("(declare-fun {} () (_ BitVec 128))", name));
        }
    }

    /// Translate one constraint into a Boolean SMT-LIB2 term.
    ///
    /// Nothing is emitted; constants met on the way are recorded in
    /// `declared_vars`.
    fn encode_constraint(&mut self, constraint: &Constraint) -> String {
        match constraint {
            Constraint::Equal(a, b) => {
                let sa = self.encode_value(a);
                let sb = self.encode_value(b);
                format!("(= {} {})", sa, sb)
            }
            Constraint::AssertTrue(v) => {
                let sv = self.encode_value(v);
                // In Trident, true = 1, false = 0. Assert v != 0.
                format!("(not (= {} (_ bv0 128)))", sv)
            }
            Constraint::Conditional(cond, inner) => {
                let sc = self.encode_value(cond);
                let si = self.encode_constraint(inner);
                // If cond != 0 then inner must hold
                format!("(=> (not (= {} (_ bv0 128))) {})", sc, si)
            }
            Constraint::RangeU32(v) => {
                let sv = self.encode_value(v);
                format!("(bvule {} (_ bv{} 128))", sv, u32::MAX)
            }
            Constraint::DigestEqual(a, b) => {
                let mut conjuncts = Vec::new();
                for (x, y) in a.iter().zip(b.iter()) {
                    let sx = self.encode_value(x);
                    let sy = self.encode_value(y);
                    conjuncts.push(format!("(= {} {})", sx, sy));
                }
                match conjuncts.len() {
                    // `(and )` is not valid SMT-LIB; an empty comparison
                    // holds trivially.
                    0 => String::from("true"),
                    1 => conjuncts.remove(0),
                    _ => format!("(and {})", conjuncts.join(" ")),
                }
            }
        }
    }

    /// Translate a symbolic value into a 128-bit bitvector term.
    ///
    /// Encoding is deterministic: the same value always yields the same
    /// term, including the names of auxiliary constants.
    fn encode_value(&mut self, value: &SymValue) -> String {
        match value {
            SymValue::Const(c) => {
                format!("(_ bv{} 128)", c % GOLDILOCKS_P)
            }
            SymValue::Var(var) => {
                let name = sanitize_smt_name(&var.to_string());
                // Record the name; `encode` declares anything that
                // `declare_variables` did not already cover.
                self.declared_vars.insert(name.clone());
                name
            }
            SymValue::Add(a, b) => {
                let sa = self.encode_value(a);
                let sb = self.encode_value(b);
                format!("(field_mod (bvadd {} {}))", sa, sb)
            }
            SymValue::Mul(a, b) => {
                let sa = self.encode_value(a);
                let sb = self.encode_value(b);
                format!("(field_mod (bvmul {} {}))", sa, sb)
            }
            SymValue::Sub(a, b) => {
                let sa = self.encode_value(a);
                let sb = self.encode_value(b);
                // a - b mod p = (a + p - b) mod p
                format!("(field_mod (bvadd {} (bvsub GOLDILOCKS_P {})))", sa, sb)
            }
            SymValue::Neg(a) => {
                let sa = self.encode_value(a);
                format!("(field_mod (bvsub GOLDILOCKS_P {}))", sa)
            }
            SymValue::Inv(a) => {
                // An inverse is hard to express directly in BV, so it is
                // modelled as a constant `inv` whose defining assertion
                // (inv * x = 1 mod p) is emitted by `encode`.
                let operand = self.encode_value(a);
                let cached = self
                    .inverses
                    .iter()
                    .find(|(op, _)| *op == operand)
                    .map(|(_, name)| name.clone());
                match cached {
                    Some(name) => name,
                    None => {
                        // Pick a name that is not taken yet.
                        let mut id = self.declared_vars.len();
                        let mut inv_name = format!("__inv_{}", id);
                        while !self.declared_vars.insert(inv_name.clone()) {
                            id += 1;
                            inv_name = format!("__inv_{}", id);
                        }
                        self.inverses.push((operand, inv_name.clone()));
                        inv_name
                    }
                }
            }
            SymValue::Eq(a, b) => {
                let sa = self.encode_value(a);
                let sb = self.encode_value(b);
                // Returns 1 if equal, 0 otherwise
                format!("(ite (= {} {}) (_ bv1 128) (_ bv0 128))", sa, sb)
            }
            SymValue::Lt(a, b) => {
                let sa = self.encode_value(a);
                let sb = self.encode_value(b);
                format!("(ite (bvult {} {}) (_ bv1 128) (_ bv0 128))", sa, sb)
            }
            SymValue::Hash(inputs, index) => {
                // Hash is opaque: each output of each distinct invocation
                // becomes its own unconstrained constant. The invocation is
                // identified by its encoded arguments so that hashes of
                // different inputs are never forced to be equal.
                let encoded_inputs: Vec<String> =
                    inputs.iter().map(|v| self.encode_value(v)).collect();
                let key = encoded_inputs.join(" ");
                let call_id = match self.hash_calls.iter().position(|k| *k == key) {
                    Some(pos) => pos,
                    None => {
                        self.hash_calls.push(key);
                        self.hash_calls.len() - 1
                    }
                };
                let hash_name = format!("__hash_{}_{}_{}", inputs.len(), index, call_id);
                let name = sanitize_smt_name(&hash_name);
                self.declared_vars.insert(name.clone());
                name
            }
            SymValue::Divine(idx) => {
                let name = format!("divine_{}", idx);
                let smt_name = sanitize_smt_name(&name);
                self.declared_vars.insert(smt_name.clone());
                smt_name
            }
            SymValue::PubInput(idx) => {
                let name = format!("pub_in_{}", idx);
                let smt_name = sanitize_smt_name(&name);
                self.declared_vars.insert(smt_name.clone());
                smt_name
            }
            SymValue::Ite(cond, then_val, else_val) => {
                let sc = self.encode_value(cond);
                let st = self.encode_value(then_val);
                let se = self.encode_value(else_val);
                format!("(ite (not (= {} (_ bv0 128))) {} {})", sc, st, se)
            }
            SymValue::FieldAccess(inner, field) => {
                // Field access is opaque: one constant per (inner, field).
                // The inner term may be a compound expression, so the whole
                // name is sanitized to keep it a legal symbol.
                let inner_enc = self.encode_value(inner);
                let name = sanitize_smt_name(&format!("__field_{}_{}", inner_enc, field));
                self.declared_vars.insert(name.clone());
                name
            }
        }
    }
}

/// Sanitize a variable name into a valid SMT-LIB2 simple symbol.
///
/// Every character other than an ASCII letter, ASCII digit or `_` is
/// replaced by `_` (SMT-LIB2 symbols are ASCII-only, so Unicode letters and
/// digits are replaced too). The result is prefixed with `v_` when it would
/// otherwise be empty, start with a digit, or be the reserved word `_`.
fn sanitize_smt_name(name: &str) -> String {
    let sanitized: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();
    // SMT-LIB2 names can't be empty, can't start with a digit, and a lone
    // underscore is the reserved indexed-identifier marker.
    let needs_prefix = sanitized.is_empty()
        || sanitized == "_"
        || sanitized.starts_with(|c: char| c.is_ascii_digit());
    if needs_prefix {
        format!("v_{}", sanitized)
    } else {
        sanitized
    }
}

// --- Z3 Process Runner ---

mod runner;
pub use runner::run_z3;

#[cfg(test)]
mod tests;

Dimensions

trident/src/diagnostic/mod.rs
trident/src/ir/mod.rs
trident/src/deploy/mod.rs
trident/src/syntax/mod.rs
trident/src/api/mod.rs
nebu/rs/extension/mod.rs
optica/src/render/mod.rs
trident/src/config/mod.rs
trident/src/field/mod.rs
trident/src/cli/mod.rs
optica/src/parser/mod.rs
trident/src/neural/mod.rs
trident/src/cost/mod.rs
trident/src/typecheck/mod.rs
optica/src/server/mod.rs
trident/src/package/mod.rs
optica/src/scanner/mod.rs
optica/src/output/mod.rs
trident/src/verify/mod.rs
optica/src/graph/mod.rs
trident/src/ast/mod.rs
trident/src/lsp/mod.rs
trident/src/runtime/mod.rs
trident/src/gpu/mod.rs
optica/src/query/mod.rs
trident/src/lsp/semantic/mod.rs
trident/src/verify/equiv/mod.rs
trident/src/package/hash/mod.rs
trident/src/neural/training/mod.rs
trident/src/verify/synthesize/mod.rs
trident/src/ir/tir/mod.rs
rs/macros/src/addressed/mod.rs
trident/src/package/registry/mod.rs
rs/rsc/src/lints/mod.rs
trident/src/verify/report/mod.rs
trident/src/config/resolve/mod.rs
trident/src/verify/solve/mod.rs
rs/macros/src/registers/mod.rs
rs/macros/src/cell/mod.rs
rs/core/src/fixed_point/mod.rs
trident/src/neural/data/mod.rs
rs/core/src/bounded/mod.rs
trident/src/lsp/util/mod.rs
trident/src/typecheck/tests/mod.rs
trident/src/neural/model/mod.rs
trident/src/cost/stack_verifier/mod.rs
trident/src/syntax/grammar/mod.rs
trident/src/package/manifest/mod.rs
trident/src/syntax/parser/mod.rs
trident/src/ir/kir/mod.rs
trident/src/neural/inference/mod.rs
trident/src/syntax/lexer/mod.rs
trident/src/cost/model/mod.rs
trident/src/ir/lir/mod.rs
trident/src/syntax/format/mod.rs
trident/src/config/scaffold/mod.rs
trident/src/verify/sym/mod.rs
trident/src/api/tests/mod.rs
trident/src/package/store/mod.rs
trident/src/ir/tree/mod.rs
trident/src/ir/kir/lower/mod.rs
trident/src/ir/lir/lower/mod.rs
trident/src/ir/tir/lower/mod.rs
trident/src/ir/tir/builder/mod.rs
trident/src/ir/tir/neural/mod.rs
trident/src/neural/data/tir_graph/mod.rs
trident/src/syntax/parser/tests/mod.rs
cw-cyber/packages/cyber-std/src/tokenfactory/mod.rs
trident/src/ir/tree/lower/mod.rs
trident/src/ir/tir/stack/mod.rs
cw-cyber/contracts/cybernet/src/tests/mod.rs
trident/src/ir/tir/optimize/mod.rs

Local Graph