use super::analyzer::CostAnalyzer;
use super::model::TableCost;
use crate::ast::*;
// --- Per-function cost result ---
/// Cost analysis result for a single function.
impl<'a> CostAnalyzer<'a> {
pub(crate) fn cost_expr(&mut self, expr: &Expr) -> TableCost {
let stack_op = self.cost_model.stack_op();
match expr {
Expr::Literal(_) => {
// push instruction: 1 cc, 1 opstack.
stack_op
}
Expr::Var(_) => {
// dup instruction: 1 cc, 1 opstack.
stack_op
}
Expr::BinOp { op, lhs, rhs } => {
let lhs_cost = self.cost_expr(&lhs.node);
let rhs_cost = self.cost_expr(&rhs.node);
lhs_cost.add(&rhs_cost).add(&self.cost_model.binop_cost(op))
}
Expr::Call { path, args, .. } => {
let fn_name = path.node.as_dotted();
let args_cost = args
.iter()
.fold(TableCost::ZERO, |acc, a| acc.add(&self.cost_expr(&a.node)));
// Check if it's a builtin โ try full name first, then short name
// to handle cross-module calls like "hash.tip5" โ "tip5" โ "hash"
let base_name = fn_name.rsplit('.').next().unwrap_or(&fn_name);
let fn_cost = {
let c = self.cost_model.builtin_cost(&fn_name);
if c.is_nonzero() {
c
} else {
self.cost_model.builtin_cost(base_name)
}
};
if fn_cost.is_nonzero() {
// Builtin: use the cost table.
args_cost.add(&fn_cost)
} else {
// User-defined: look up body cost + call overhead.
let body_cost = if let Some(func) = self.fn_bodies.get(base_name).cloned() {
self.cost_fn(&func)
} else {
TableCost::ZERO
};
args_cost
.add(&body_cost)
.add(&self.cost_model.call_overhead())
}
}
Expr::FieldAccess { expr: inner, .. } => {
// Evaluate inner struct + dup field elements.
self.cost_expr(&inner.node).add(&stack_op)
}
Expr::Index { expr: inner, .. } => {
// Evaluate inner array + dup indexed element.
self.cost_expr(&inner.node).add(&stack_op)
}
Expr::StructInit { fields, .. } => {
fields.iter().fold(TableCost::ZERO, |acc, (_, val)| {
acc.add(&self.cost_expr(&val.node))
})
}
Expr::ArrayInit(elems) => elems
.iter()
.fold(TableCost::ZERO, |acc, e| acc.add(&self.cost_expr(&e.node))),
Expr::Tuple(elems) => elems
.iter()
.fold(TableCost::ZERO, |acc, e| acc.add(&self.cost_expr(&e.node))),
}
}
/// Find the first loop in a function and return its per-iteration cost + bound.
pub(crate) fn find_loop_iteration_cost(&mut self, func: &FnDef) -> Option<(TableCost, u64)> {
if let Some(body) = &func.body {
for stmt in &body.node.stmts {
if let Stmt::For {
bound,
body: loop_body,
end,
..
} = &stmt.node
{
let body_cost = self.cost_block(&loop_body.node);
let per_iter = body_cost.add(&self.cost_model.loop_overhead());
let iterations = if let Some(b) = bound {
*b
} else if let Expr::Literal(Literal::Integer(n)) = &end.node {
*n
} else {
1
};
return Some((per_iter, iterations));
}
}
}
None
}
/// Collect per-statement costs mapped to line numbers.
///
/// Walks every function body and records the cost of each statement
/// along with the 1-based line number derived from the statement's span.
/// Also records function definition lines with call overhead.
pub(crate) fn stmt_costs(&mut self, file: &File, source: &str) -> Vec<(u32, TableCost)> {
// Build line offset table: line_starts[i] = byte offset of line i+1
let line_starts: Vec<u32> = std::iter::once(0)
.chain(source.bytes().enumerate().filter_map(|(i, b)| {
if b == b'\n' {
Some((i + 1) as u32)
} else {
None
}
}))
.collect();
let byte_to_line = |offset: u32| -> u32 {
match line_starts.binary_search(&offset) {
Ok(i) => (i + 1) as u32,
Err(i) => i as u32,
}
};
// Ensure fn_bodies are populated (analyze_file does this, but in case
// stmt_costs is called standalone)
for item in &file.items {
if let Item::Fn(func) = &item.node {
self.fn_bodies
.entry(func.name.node.clone())
.or_insert_with(|| func.clone());
}
}
let mut result: Vec<(u32, TableCost)> = Vec::new();
for item in &file.items {
if let Item::Fn(func) = &item.node {
// Record function header with call overhead
let fn_line = byte_to_line(item.span.start);
result.push((fn_line, self.cost_model.call_overhead()));
if let Some(body) = &func.body {
self.collect_block_costs(&body.node, &byte_to_line, &mut result);
}
}
}
result.sort_by_key(|(line, _)| *line);
result
}
/// Recursively collect costs for all statements in a block.
pub(crate) fn collect_block_costs(
&mut self,
block: &Block,
byte_to_line: &dyn Fn(u32) -> u32,
result: &mut Vec<(u32, TableCost)>,
) {
for stmt in &block.stmts {
let line = byte_to_line(stmt.span.start);
let cost = self.cost_stmt(&stmt.node);
result.push((line, cost));
// Recurse into nested blocks
match &stmt.node {
Stmt::If {
then_block,
else_block,
..
} => {
self.collect_block_costs(&then_block.node, byte_to_line, result);
if let Some(eb) = else_block {
self.collect_block_costs(&eb.node, byte_to_line, result);
}
}
Stmt::For { body, .. } => {
self.collect_block_costs(&body.node, byte_to_line, result);
}
Stmt::Match { arms, .. } => {
for arm in arms {
self.collect_block_costs(&arm.body.node, byte_to_line, result);
}
}
_ => {}
}
}
if let Some(tail) = &block.tail_expr {
let line = byte_to_line(tail.span.start);
let cost = self.cost_expr(&tail.node);
result.push((line, cost));
}
}
/// H0004: scan a block for loops where declared bound >> constant end value.
pub(crate) fn scan_loop_bound_waste(&mut self, fn_name: &str, block: &Block) {
for stmt in &block.stmts {
if let Stmt::For {
end, bound, body, ..
} = &stmt.node
{
// Check if end is a constant and bound is declared
if let (Some(declared_bound), Expr::Literal(Literal::Integer(end_val))) =
(bound, &end.node)
{
if *declared_bound > *end_val * 4 && *declared_bound > 8 {
self.loop_bound_waste.push((
fn_name.to_string(),
*end_val,
*declared_bound,
));
}
}
// Recurse into loop body
self.scan_loop_bound_waste(fn_name, &body.node);
}
// Recurse into if/else blocks
if let Stmt::If {
then_block,
else_block,
..
} = &stmt.node
{
self.scan_loop_bound_waste(fn_name, &then_block.node);
if let Some(eb) = else_block {
self.scan_loop_bound_waste(fn_name, &eb.node);
}
}
// Recurse into match arms
if let Stmt::Match { arms, .. } = &stmt.node {
for arm in arms {
self.scan_loop_bound_waste(fn_name, &arm.body.node);
}
}
}
}
}
/// Smallest power of 2 >= n.
///
/// Delegates to `field::proof::padded_height` โ same formula, different name.
pub(crate) fn next_power_of_two(n: u64) -> u64 {
crate::field::proof::padded_height(n)
}
trident/src/cost/visit.rs
ฯ 0.0%
use CostAnalyzer;
use TableCost;
use crate*;
// --- Per-function cost result ---
/// Cost analysis result for a single function.
/// Smallest power of 2 >= n.
///
/// Delegates to `field::proof::padded_height` โ same formula, different name.
pub