#[cfg(test)]
mod tests;
mod triton;
use super::neural::report::{
BlockDecision, DecisionReason, OptimizerReport, OptimizerStatus, Winner,
};
use super::TIROp;
use crate::cost::scorer;
pub use triton::TritonLowering;
/// A backend that lowers a sequence of [`TIROp`]s into textual target code.
pub trait StackLowering {
    /// Lowers `ops` and returns the emitted code as a list of strings
    /// (presumably one instruction/line per element — confirm against
    /// `TritonLowering`).
    fn lower(&self, ops: &[TIROp]) -> Vec<String>;
}
/// Creates the stack-lowering backend for a target.
///
/// Only the Triton backend exists today: `_target` is ignored and a boxed
/// [`TritonLowering`] is always returned.
pub fn create_stack_lowering(_target: &str) -> Box<dyn StackLowering> {
    let backend = TritonLowering::new();
    Box::new(backend)
}
/// Creates a [`SpeculativeLowering`] whose report starts empty and is tagged
/// with the given optimizer metadata.
///
/// * `_target` - currently ignored; the classical side is always Triton.
/// * `meta_generation` - stored as the report's `generation`.
/// * `meta_hash` - stored as the report's `weight_hash`.
/// * `meta_status` - stored as the report's `status`.
///
/// The report starts with no decisions and both cost totals at zero.
pub fn create_speculative_lowering(
    _target: &str,
    meta_generation: u64,
    meta_hash: String,
    meta_status: OptimizerStatus,
) -> SpeculativeLowering {
    let initial_report = OptimizerReport {
        status: meta_status,
        generation: meta_generation,
        weight_hash: meta_hash,
        decisions: Vec::new(),
        total_neural_cost: 0,
        total_classical_cost: 0,
    };
    SpeculativeLowering {
        classical: TritonLowering::new(),
        report: std::cell::RefCell::new(initial_report),
    }
}
/// A lowering that always emits the classical (Triton) output, while recording
/// in an [`OptimizerReport`] how neural-optimizer candidates compare to it.
pub struct SpeculativeLowering {
    /// Classical lowering used to produce all emitted code.
    classical: TritonLowering,
    /// Report accumulated by `inject_neural_candidate`. A `RefCell` lets it be
    /// updated through `&self`; note this makes the type `!Sync`.
    report: std::cell::RefCell<OptimizerReport>,
}
impl SpeculativeLowering {
    /// Returns a snapshot (clone) of the optimizer report accumulated so far.
    ///
    /// # Panics
    ///
    /// Panics if the report is currently mutably borrowed
    /// (see `RefCell::borrow`).
    pub fn report(&self) -> OptimizerReport {
        self.report.borrow().clone()
    }

    /// Compares a neural-optimizer candidate for one block against the
    /// classical lowering's cost, and appends the resulting [`BlockDecision`]
    /// to the report.
    ///
    /// * `block_id` - identifier stored on the decision.
    /// * `candidate_tasm` - the candidate's TASM lines. An empty slice records
    ///   a classical win with [`DecisionReason::NoCandidate`].
    /// * `baseline_cost` - cost of the classical lowering of the same block.
    ///
    /// The candidate wins only when its profiled cost is strictly lower than
    /// `baseline_cost`; ties go to the classical lowering. On every decision,
    /// `winner_cost` / `loser_cost` hold the cost of the winning / losing side.
    /// When there is no candidate, both are `baseline_cost`.
    ///
    /// # Panics
    ///
    /// Panics if the report is already borrowed when the decision is recorded
    /// (see `RefCell::borrow_mut`).
    pub fn inject_neural_candidate(
        &self,
        block_id: &str,
        candidate_tasm: &[String],
        baseline_cost: u64,
    ) {
        let decision = if candidate_tasm.is_empty() {
            BlockDecision {
                block_id: block_id.to_string(),
                winner: Winner::Classical,
                winner_cost: baseline_cost,
                loser_cost: baseline_cost,
                reason: DecisionReason::NoCandidate,
            }
        } else {
            let lines: Vec<&str> = candidate_tasm.iter().map(String::as_str).collect();
            let candidate_profile = scorer::profile_tasm(&lines);
            let candidate_cost = candidate_profile.cost();

            if candidate_cost < baseline_cost {
                // NOTE(review): placeholder baseline. The classical block's real
                // TASM is not available here, so cliff-jump detection compares
                // against a single `push 0`. Pass the classical profile in once
                // callers can provide it.
                let baseline_profile = scorer::profile_tasm_str("push 0\n");
                let reason = if candidate_profile.is_cliff_jump(&baseline_profile) {
                    DecisionReason::CliffJump
                } else {
                    DecisionReason::StackScheduling
                };
                BlockDecision {
                    block_id: block_id.to_string(),
                    winner: Winner::Neural,
                    winner_cost: candidate_cost,
                    loser_cost: baseline_cost,
                    reason,
                }
            } else {
                BlockDecision {
                    block_id: block_id.to_string(),
                    winner: Winner::Classical,
                    winner_cost: baseline_cost,
                    // The losing side here is the neural candidate, so its cost
                    // is the loser cost (mirrors the Neural-wins branch).
                    loser_cost: candidate_cost,
                    reason: DecisionReason::NeuralWorse(candidate_cost),
                }
            }
        };

        // Borrow only for the push, so no `RefMut` is held across scorer calls.
        self.report.borrow_mut().decisions.push(decision);
    }
}
impl StackLowering for SpeculativeLowering {
    /// Emits exactly what the classical lowering produces. Neural candidates
    /// only affect the report, never the emitted code.
    fn lower(&self, ops: &[TIROp]) -> Vec<String> {
        let classical = &self.classical;
        classical.lower(ops)
    }
}
/// Token vocabulary shared by [`decode_output`] and [`encode_tasm_line`].
///
/// Index 0 is the end-of-sequence marker (the empty string). It terminates
/// decoding and is never produced by encoding. Every other index maps to
/// exactly one TASM line. The position of each entry *is* its numeric code, so
/// reordering or removing entries changes every encoded sequence.
const TASM_VOCAB: &[&str] = &[
    "", // 0: end-of-sequence marker
    "push 0", "push 1", "push -1", // 1-3
    "pop 1", "pop 2", "pop 3", "pop 4", "pop 5", // 4-8
    "dup 0", "dup 1", "dup 2", "dup 3", "dup 4", "dup 5", // 9-14
    "swap 1", "swap 2", "swap 3", "swap 4", "swap 5", // 15-19
    "add", "mul", "eq", "lt", "and", "xor", "div_mod", "split", // 20-27
    "pop_count", "log_2_floor", "nop", "assert", // 28-31
    "dup 9", "write_io 1", "dup 11", "dup 12", "divine 1", "dup 14", "dup 15", // 32-38
    "swap 10", "swap 11", "swap 12", "swap 13", "halt", "swap 15", "write_io 5", // 39-45
    "pick 2", "pick 3", "divine 5", "pick 5", // 46-49
    "place 1", "place 2", "place 3", "place 4", "place 5", // 50-54
    "push 2", "push 3", "assert_vector", // 55-57
    "dup 6", "dup 7", "swap 6", "swap 7", "swap 8", "swap 9", // 58-63
];

/// Decodes a sequence of vocabulary codes into TASM lines.
///
/// Decoding stops at the first code that is `0` (end of sequence) or that
/// falls outside the vocabulary. Everything from that code onward is ignored.
pub fn decode_output(codes: &[u64]) -> Vec<String> {
    codes
        .iter()
        .map_while(|&code| {
            // Range-check in u64 before narrowing: `code as usize` alone would
            // truncate huge codes into valid-looking indices on 32-bit targets.
            let in_range = code != 0 && code < TASM_VOCAB.len() as u64;
            in_range.then(|| TASM_VOCAB[code as usize].to_owned())
        })
        .collect()
}

/// Returns the vocabulary code for a single TASM line.
///
/// Surrounding whitespace is trimmed first; otherwise matching is exact.
/// Returns `None` for lines not in the vocabulary, including empty or
/// whitespace-only lines (index 0 is never matched).
pub fn encode_tasm_line(line: &str) -> Option<u64> {
    let trimmed = line.trim();
    TASM_VOCAB
        .iter()
        .enumerate()
        .skip(1) // index 0 is the end-of-sequence marker, not an instruction
        .find(|&(_, &entry)| entry == trimmed)
        .map(|(idx, _)| idx as u64)
}

/// Encodes a block of TASM lines with [`encode_tasm_line`].
///
/// Lines that are not in the vocabulary are silently skipped, so the result
/// may be shorter than `lines`.
pub fn encode_tasm_block(lines: &[String]) -> Vec<u64> {
    lines.iter().filter_map(|l| encode_tasm_line(l)).collect()
}