//! Function call dispatch: intrinsic resolution and user-defined calls.

use crate::ast::*;
use crate::span::Spanned;
use crate::tir::TIROp;
use crate::typecheck::MonoInstance;

use super::TIRBuilder;

impl TIRBuilder {
    /// Emit a function call (intrinsic or user-defined).
    pub(crate) fn build_call(
        &mut self,
        name: &str,
        generic_args: &[Spanned<ArraySize>],
        args: &[Spanned<Expr>],
    ) {
        // Evaluate arguments โ€” each pushes a temp.
        for arg in args {
            self.build_expr(&arg.node);
        }

        // Pop all arg temps from the model.
        let arg_count = args.len();
        for _ in 0..arg_count {
            self.stack.pop();
        }

        // Resolve intrinsic name.
        let resolved_name = self.intrinsic_map.get(name).cloned().or_else(|| {
            name.rsplit('.')
                .next()
                .and_then(|short| self.intrinsic_map.get(short).cloned())
        });
        let effective_name = resolved_name.as_deref().unwrap_or(name);

        match effective_name {
            // โ”€โ”€ I/O โ”€โ”€
            "pub_read" => {
                self.emit_and_push(TIROp::ReadIo(1), 1);
            }
            "pub_read2" => {
                self.emit_and_push(TIROp::ReadIo(2), 2);
            }
            "pub_read3" => {
                self.emit_and_push(TIROp::ReadIo(3), 3);
            }
            "pub_read4" => {
                self.emit_and_push(TIROp::ReadIo(4), 4);
            }
            "pub_read5" => {
                self.emit_and_push(TIROp::ReadIo(5), 5);
            }
            "pub_write" => {
                self.ops.push(TIROp::WriteIo(1));
                self.push_temp(0);
            }
            "pub_write2" => {
                self.ops.push(TIROp::WriteIo(2));
                self.push_temp(0);
            }
            "pub_write3" => {
                self.ops.push(TIROp::WriteIo(3));
                self.push_temp(0);
            }
            "pub_write4" => {
                self.ops.push(TIROp::WriteIo(4));
                self.push_temp(0);
            }
            "pub_write5" => {
                self.ops.push(TIROp::WriteIo(5));
                self.push_temp(0);
            }

            // โ”€โ”€ Non-deterministic input โ”€โ”€
            "divine" => {
                self.emit_and_push(TIROp::Hint(1), 1);
            }
            "divine3" => {
                self.emit_and_push(TIROp::Hint(3), 3);
            }
            "divine5" => {
                self.emit_and_push(TIROp::Hint(5), 5);
            }

            // โ”€โ”€ Assertions โ”€โ”€
            "assert" => {
                self.ops.push(TIROp::Assert(1));
                self.push_temp(0);
            }
            "assert_eq" => {
                self.ops.push(TIROp::Eq);
                self.ops.push(TIROp::Assert(1));
                self.push_temp(0);
            }
            "assert_digest" => {
                self.ops.push(TIROp::Assert(5));
                self.ops.push(TIROp::Pop(self.target_config.digest_width));
                self.push_temp(0);
            }

            // โ”€โ”€ Field operations โ”€โ”€
            "field_add" => {
                self.ops.push(TIROp::Add);
                self.push_temp(1);
            }
            "field_mul" => {
                self.ops.push(TIROp::Mul);
                self.push_temp(1);
            }
            "inv" => {
                self.ops.push(TIROp::Invert);
                self.push_temp(1);
            }
            "neg" => {
                self.ops.push(TIROp::Neg);
                self.push_temp(1);
            }
            "sub" => {
                self.ops.push(TIROp::Sub);
                self.push_temp(1);
            }

            // โ”€โ”€ U32 operations โ”€โ”€
            "split" => {
                self.ops.push(TIROp::Split);
                self.push_temp(2);
            }
            "log2" => {
                self.ops.push(TIROp::Log2);
                self.push_temp(1);
            }
            "pow" => {
                self.ops.push(TIROp::Pow);
                self.push_temp(1);
            }
            "popcount" => {
                self.ops.push(TIROp::PopCount);
                self.push_temp(1);
            }

            // โ”€โ”€ Hash operations โ”€โ”€
            "hash" => {
                self.ops.push(TIROp::Hash {
                    width: self.target_config.digest_width,
                });
                self.push_temp(self.target_config.digest_width);
            }
            "sponge_init" => {
                self.ops.push(TIROp::SpongeInit);
                self.push_temp(0);
            }
            "sponge_absorb" => {
                self.ops.push(TIROp::SpongeAbsorb);
                self.push_temp(0);
            }
            "sponge_squeeze" => {
                self.emit_and_push(TIROp::SpongeSqueeze, self.target_config.hash_rate);
            }
            "sponge_absorb_mem" => {
                self.ops.push(TIROp::SpongeLoad);
                self.push_temp(0);
            }

            // โ”€โ”€ Merkle โ”€โ”€
            "merkle_step" => {
                self.emit_and_push(TIROp::MerkleStep, 6);
            }
            "merkle_step_mem" => {
                self.emit_and_push(TIROp::MerkleLoad, 7);
            }

            // โ”€โ”€ RAM โ”€โ”€
            "ram_read" => {
                self.ops.push(TIROp::RamRead { width: 1 });
                self.push_temp(1);
            }
            "ram_write" => {
                self.ops.push(TIROp::RamWrite { width: 1 });
                self.push_temp(0);
            }
            "ram_read_block" => {
                self.ops.push(TIROp::RamRead { width: 5 });
                self.push_temp(5);
            }
            "ram_write_block" => {
                self.ops.push(TIROp::RamWrite { width: 5 });
                self.push_temp(0);
            }

            // โ”€โ”€ Conversion โ”€โ”€
            "as_u32" => {
                // split: st0=lo (u32), st1=hi. Keep lo, discard hi.
                self.ops.push(TIROp::Split);
                self.ops.push(TIROp::Swap(1));
                self.ops.push(TIROp::Pop(1));
                self.push_temp(1);
            }
            "as_field" => {
                self.push_temp(1);
            }

            // โ”€โ”€ XField โ”€โ”€
            "xfield" => {
                self.push_temp(3);
            }
            "xinvert" => {
                self.ops.push(TIROp::ExtInvert);
                self.push_temp(3);
            }
            "xx_dot_step" => {
                self.emit_and_push(TIROp::FoldExt, 5);
            }
            "xb_dot_step" => {
                self.emit_and_push(TIROp::FoldBase, 5);
            }

            // โ”€โ”€ User-defined function โ”€โ”€
            _ => {
                self.build_user_call(name, generic_args);
            }
        }
    }

    /// Emit only the call/intrinsic opcode for a pass-through function.
    /// Does NOT evaluate arguments or touch the stack model โ€” the caller's
    /// params are already in place on the real stack.
    pub(crate) fn emit_call_only(
        &mut self,
        name: &str,
        generic_args: &[Spanned<ArraySize>],
        _arg_count: usize,
    ) {
        let resolved_name = self.intrinsic_map.get(name).cloned().or_else(|| {
            name.rsplit('.')
                .next()
                .and_then(|short| self.intrinsic_map.get(short).cloned())
        });
        let effective_name = resolved_name.as_deref().unwrap_or(name);

        match effective_name {
            "hash" => {
                self.ops.push(TIROp::Hash {
                    width: self.target_config.digest_width,
                });
            }
            "sponge_init" => self.ops.push(TIROp::SpongeInit),
            "sponge_absorb" => self.ops.push(TIROp::SpongeAbsorb),
            "sponge_squeeze" => self.ops.push(TIROp::SpongeSqueeze),
            "sponge_absorb_mem" => self.ops.push(TIROp::SpongeLoad),
            "assert" => self.ops.push(TIROp::Assert(1)),
            "assert_eq" => {
                self.ops.push(TIROp::Eq);
                self.ops.push(TIROp::Assert(1));
            }
            "pub_read" => self.ops.push(TIROp::ReadIo(1)),
            "pub_read2" => self.ops.push(TIROp::ReadIo(2)),
            "pub_read3" => self.ops.push(TIROp::ReadIo(3)),
            "pub_read4" => self.ops.push(TIROp::ReadIo(4)),
            "pub_read5" => self.ops.push(TIROp::ReadIo(5)),
            "pub_write" => self.ops.push(TIROp::WriteIo(1)),
            "pub_write2" => self.ops.push(TIROp::WriteIo(2)),
            "pub_write3" => self.ops.push(TIROp::WriteIo(3)),
            "pub_write4" => self.ops.push(TIROp::WriteIo(4)),
            "pub_write5" => self.ops.push(TIROp::WriteIo(5)),
            "divine" => self.ops.push(TIROp::Hint(1)),
            "divine3" => self.ops.push(TIROp::Hint(3)),
            "divine5" => self.ops.push(TIROp::Hint(5)),
            "split" => self.ops.push(TIROp::Split),
            "log2" => self.ops.push(TIROp::Log2),
            "pow" => self.ops.push(TIROp::Pow),
            "popcount" => self.ops.push(TIROp::PopCount),
            "inv" => self.ops.push(TIROp::Invert),
            "neg" => self.ops.push(TIROp::Neg),
            "sub" => self.ops.push(TIROp::Sub),
            "field_add" => self.ops.push(TIROp::Add),
            "field_mul" => self.ops.push(TIROp::Mul),
            "ram_read" => self.ops.push(TIROp::RamRead { width: 1 }),
            "ram_write" => self.ops.push(TIROp::RamWrite { width: 1 }),
            "ram_read_block" => self.ops.push(TIROp::RamRead { width: 5 }),
            "ram_write_block" => self.ops.push(TIROp::RamWrite { width: 5 }),
            "merkle_step" => self.ops.push(TIROp::MerkleStep),
            "merkle_step_mem" => self.ops.push(TIROp::MerkleLoad),
            "xinvert" => self.ops.push(TIROp::ExtInvert),
            "xx_dot_step" => self.ops.push(TIROp::FoldExt),
            "xb_dot_step" => self.ops.push(TIROp::FoldBase),
            "assert_digest" => {
                self.ops.push(TIROp::Assert(5));
                self.ops.push(TIROp::Pop(self.target_config.digest_width));
            }
            _ => {
                // User-defined call โ€” resolve label the same way as
                // build_user_call but skip stack model updates.
                let call_label = self.resolve_call_label(name, generic_args);
                self.ops.push(TIROp::Call(call_label));
            }
        }
    }

    /// Resolve a user-defined call name to its TASM label.
    /// Returns `(call_label, base_name)` where `base_name` is used for
    /// return width lookup.
    fn resolve_call_label(&mut self, name: &str, generic_args: &[Spanned<ArraySize>]) -> String {
        let is_generic = self.generic_fn_defs.contains_key(name);

        if is_generic {
            let size_args: Vec<u64> = if !generic_args.is_empty() {
                generic_args
                    .iter()
                    .map(|ga| ga.node.eval(&self.current_subs))
                    .collect()
            } else if !self.current_subs.is_empty() {
                if let Some(gdef) = self.generic_fn_defs.get(name) {
                    gdef.type_params
                        .iter()
                        .map(|p| self.current_subs.get(&p.node).copied().unwrap_or(0))
                        .collect()
                } else {
                    vec![]
                }
            } else {
                let idx = self.call_resolution_idx;
                if idx < self.call_resolutions.len() && self.call_resolutions[idx].name == name {
                    self.call_resolution_idx += 1;
                    self.call_resolutions[idx].size_args.clone()
                } else {
                    let mut found = vec![];
                    for (i, res) in self.call_resolutions.iter().enumerate() {
                        if i >= self.call_resolution_idx && res.name == name {
                            self.call_resolution_idx = i + 1;
                            found = res.size_args.clone();
                            break;
                        }
                    }
                    found
                }
            };
            let inst = MonoInstance {
                name: name.to_string(),
                size_args,
            };
            inst.mangled_name()
        } else if name.contains('.') {
            let parts: Vec<&str> = name.rsplitn(2, '.').collect();
            let fn_name = parts[0];
            let short_module = parts[1];
            let full_module = self
                .module_aliases
                .get(short_module)
                .map(|s| s.as_str())
                .unwrap_or(short_module);
            let mangled = full_module.replace('.', "_");
            // @ prefix marks cross-module calls so the linker doesn't re-prefix them
            format!("@{}__{}", mangled, fn_name)
        } else {
            name.to_string()
        }
    }

    /// Emit a call to a user-defined (non-intrinsic) function.
    fn build_user_call(&mut self, name: &str, generic_args: &[Spanned<ArraySize>]) {
        let call_label = self.resolve_call_label(name, generic_args);

        // For return width lookup, use the base name (without module prefix).
        let base_name = if name.contains('.') && !self.generic_fn_defs.contains_key(name) {
            name.rsplitn(2, '.').next().unwrap_or(name).to_string()
        } else {
            call_label.clone()
        };

        let ret_width = self.fn_return_widths.get(&base_name).copied().unwrap_or(0);
        if ret_width > 0 {
            self.emit_and_push(TIROp::Call(call_label), ret_width);
        } else {
            self.ops.push(TIROp::Call(call_label));
            self.push_temp(0);
        }
    }
}

Local Graph