pub mod triton;
use crate::ast::BinOp;
pub(crate) use triton::TritonCostModel;
/// Capacity of the fixed-size value array inside `TableCost`.
///
/// `TableCost::from_slice` keeps at most this many values and drops the rest.
pub const MAX_TABLES: usize = 8;
/// Target-specific source of per-table costs for language constructs.
///
/// Implemented here by `TritonCostModel`; `create_cost_model` hands out an
/// implementation as a `&'static dyn CostModel`.
pub(crate) trait CostModel {
    /// Full names of the cost tables.
    /// NOTE(review): presumably in the same order as `TableCost::values`,
    /// e.g. for use as `TableCost::to_json_value` keys — confirm.
    fn table_names(&self) -> &[&str];
    /// Abbreviated table names.
    /// NOTE(review): presumably the `short_names` passed to
    /// `TableCost::format_annotation` / `dominant_table` — confirm.
    fn table_short_names(&self) -> &[&str];
    /// Cost of invoking the builtin called `name` (queried by `cost_builtin`).
    fn builtin_cost(&self, name: &str) -> TableCost;
    /// Cost of evaluating the binary operator `op`.
    fn binop_cost(&self, op: &BinOp) -> TableCost;
    /// Fixed cost attributed to a function call, independent of the callee.
    fn call_overhead(&self) -> TableCost;
    /// Cost of a single stack operation (exact operation set: TODO confirm).
    fn stack_op(&self) -> TableCost;
    /// Fixed cost attributed to an `if` construct.
    fn if_overhead(&self) -> TableCost;
    /// Cost attributed to a loop construct (per iteration vs. once: TODO confirm).
    fn loop_overhead(&self) -> TableCost;
    /// Number of hash-table rows one permutation occupies (target-defined).
    fn hash_rows_per_permutation(&self) -> u64;
    /// Number of columns in the target's execution trace (target-defined).
    fn trace_column_count(&self) -> u64;
}
/// A vector of per-table cost values, one slot per table tracked by a
/// [`CostModel`].
///
/// Only `values[..count]` is "active". Every constructor in this impl leaves
/// the remaining slots zero. Element-wise operations (`add`, `max`) rely on
/// this so they can combine costs with different `count`s. The derived
/// `PartialEq` compares all slots, so the invariant also matters for
/// equality.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TableCost {
    /// Per-table values; only the first `count` entries are meaningful.
    pub values: [u64; MAX_TABLES],
    /// Number of active tables. Methods treat anything above `MAX_TABLES` as
    /// `MAX_TABLES`.
    pub count: u8,
}

impl Default for TableCost {
    /// Returns [`TableCost::ZERO`].
    fn default() -> Self {
        Self::ZERO
    }
}

impl TableCost {
    /// A cost with no active tables and every slot zero.
    pub const ZERO: TableCost = TableCost {
        values: [0; MAX_TABLES],
        count: 0,
    };

    /// Builds a cost from per-table values.
    ///
    /// At most `MAX_TABLES` values are kept; any extra values are dropped.
    pub fn from_slice(vals: &[u64]) -> TableCost {
        let mut values = [0u64; MAX_TABLES];
        let n = vals.len().min(MAX_TABLES);
        values[..n].copy_from_slice(&vals[..n]);
        TableCost {
            values,
            // n <= MAX_TABLES (8), so this cannot truncate.
            count: n as u8,
        }
    }

    /// Returns the raw value stored in slot `i`.
    ///
    /// # Panics
    /// Panics if `i >= MAX_TABLES`. Slots in `count..MAX_TABLES` are returned
    /// as stored (zero for costs built by this type).
    pub fn get(&self, i: usize) -> u64 {
        self.values[i]
    }

    /// Returns `true` if any active table has a non-zero value.
    pub fn is_nonzero(&self) -> bool {
        self.active().iter().any(|&v| v > 0)
    }

    /// Element-wise sum over the union of both costs' active tables.
    ///
    /// The result's `count` is the larger of the two counts. Sums saturate at
    /// `u64::MAX` rather than overflowing, consistent with [`TableCost::scale`].
    pub fn add(&self, other: &TableCost) -> TableCost {
        self.zip_with(other, u64::saturating_add)
    }

    /// Multiplies every active table by `factor`, saturating at `u64::MAX`.
    pub fn scale(&self, factor: u64) -> TableCost {
        let mut values = [0u64; MAX_TABLES];
        for (dst, &src) in values.iter_mut().zip(self.active()) {
            *dst = src.saturating_mul(factor);
        }
        TableCost {
            values,
            count: self.active_len() as u8,
        }
    }

    /// Element-wise maximum over the union of both costs' active tables.
    pub fn max(&self, other: &TableCost) -> TableCost {
        self.zip_with(other, std::cmp::max)
    }

    /// Largest value among the active tables, or 0 when none are active.
    pub fn max_height(&self) -> u64 {
        self.active().iter().copied().max().unwrap_or(0)
    }

    /// Returns the short name of the table holding the largest value.
    ///
    /// Returns:
    /// - `"?"` when there are no active tables or no names.
    /// - `short_names[0]` when every active value is zero, or when the
    ///   maximum lies beyond the supplied names.
    /// - On ties, the first matching table.
    pub fn dominant_table<'a>(&self, short_names: &[&'a str]) -> &'a str {
        let first = match short_names.first() {
            Some(&name) if self.active_len() > 0 => name,
            _ => return "?",
        };
        let max = self.max_height();
        if max == 0 {
            return first;
        }
        self.active()
            .iter()
            .zip(short_names)
            .find(|&(&v, _)| v == max)
            .map_or(first, |(_, &name)| name)
    }

    /// Serializes the active tables as a flat JSON object, e.g. `{"a": 1, "b": 2}`.
    ///
    /// Keys come from `names`, paired positionally with the values. Output
    /// stops at whichever of `names` or the active tables runs out first.
    pub fn to_json_value(&self, names: &[&str]) -> String {
        let parts: Vec<String> = self
            .active()
            .iter()
            .zip(names)
            .map(|(value, name)| format!("\"{}\": {}", Self::escape_json_string(name), value))
            .collect();
        format!("{{{}}}", parts.join(", "))
    }

    /// Parses a cost previously written by [`TableCost::to_json_value`].
    ///
    /// Looks up each of the first `MAX_TABLES` entries of `names` as a key in
    /// `s`; extra names are ignored, matching [`TableCost::from_slice`].
    /// Returns `None` if any of those keys is missing or its value is not an
    /// unsigned integer that fits in `u64`.
    ///
    /// This is a lightweight key scan, not a full JSON parser.
    pub fn from_json_value(s: &str, names: &[&str]) -> Option<TableCost> {
        let n = names.len().min(MAX_TABLES);
        let mut values = [0u64; MAX_TABLES];
        for (slot, name) in values.iter_mut().zip(&names[..n]) {
            *slot = Self::extract_u64(s, name)?;
        }
        Some(TableCost {
            values,
            count: n as u8,
        })
    }

    /// Renders the non-zero active tables as space-separated `name=value` pairs.
    pub fn format_annotation(&self, short_names: &[&str]) -> String {
        self.active()
            .iter()
            .zip(short_names)
            .filter(|&(&v, _)| v > 0)
            .map(|(v, name)| format!("{}={}", name, v))
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Number of active slots, clamped so a hand-set `count` can never index
    /// past the backing array.
    fn active_len(&self) -> usize {
        (self.count as usize).min(MAX_TABLES)
    }

    /// The active prefix of `values`.
    fn active(&self) -> &[u64] {
        &self.values[..self.active_len()]
    }

    /// Combines both costs slot-by-slot with `f` over the union of their
    /// active tables.
    ///
    /// A slot that is inactive in one operand is read as stored there, which
    /// is zero for costs built by this type.
    fn zip_with(&self, other: &TableCost, f: impl Fn(u64, u64) -> u64) -> TableCost {
        let n = self.active_len().max(other.active_len());
        let mut values = [0u64; MAX_TABLES];
        let pairs = self.values.iter().zip(other.values.iter());
        for (slot, (&a, &b)) in values.iter_mut().zip(pairs).take(n) {
            *slot = f(a, b);
        }
        TableCost {
            values,
            count: n as u8,
        }
    }

    /// Escapes `s` for use inside a JSON string literal.
    ///
    /// Handles quotes, backslashes and control characters. Shared by
    /// serialization and key lookup so the two always agree on a key's form.
    fn escape_json_string(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        for c in s.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                c => out.push(c),
            }
        }
        out
    }

    /// Finds the first occurrence of `"key"` that is followed by a colon.
    ///
    /// Returns the run of ASCII digits after that colon parsed as `u64`, or
    /// `None` if the key never appears in key position or the value does not
    /// parse. Occurrences not followed by `:` (for example string values)
    /// are skipped.
    fn extract_u64(s: &str, key: &str) -> Option<u64> {
        let needle = format!("\"{}\"", Self::escape_json_string(key));
        let mut search_from = 0;
        while let Some(pos) = s[search_from..].find(&needle) {
            let after_key = search_from + pos + needle.len();
            if let Some(after_colon) = s[after_key..].trim_start().strip_prefix(':') {
                let digits = after_colon.trim_start();
                let end = digits
                    .find(|c: char| !c.is_ascii_digit())
                    .unwrap_or(digits.len());
                return digits[..end].parse().ok();
            }
            search_from = after_key;
        }
        None
    }
}
/// Returns the cost of the builtin `name` under the cost model selected for
/// `target` via [`create_cost_model`].
pub(crate) fn cost_builtin(target: &str, name: &str) -> TableCost {
    let model = create_cost_model(target);
    model.builtin_cost(name)
}
/// Returns the cost model for the target named `target_name`.
///
/// `"triton"` selects [`TritonCostModel`]. Every other name also resolves
/// to the Triton model, because it is the only model implemented here.
pub(crate) fn create_cost_model(target_name: &str) -> &'static dyn CostModel {
    if target_name == "triton" {
        return &TritonCostModel;
    }
    // No other backend is modelled yet; unknown targets fall back to Triton.
    &TritonCostModel
}