use nox::noun::{Order, NounId};
use super::{CompileError, formula_parts, body_pair, body_triple, atom_u64, axis_to_param,
detect_loop_setup, detect_back_edge};
/// Goldilocks prime modulus, p = 2^64 - 2^32 + 1 (0xFFFF_FFFF_0000_0001).
const P: u64 = u64::MAX - 0xFFFF_FFFF + 1;
/// Compiles `formula` into a single-invocation assembly listing.
///
/// A fresh emitter is created for `num_params` parameters, the formula is
/// lowered into it, and the register left on top of the emitter's result
/// stack is handed to `finish` as the function's return value.
///
/// # Errors
/// Propagates any `CompileError` raised while lowering the formula.
pub fn compile_to_amx<const N: usize>(
    order: &Order<N>,
    formula: NounId,
    num_params: u32,
) -> Result<String, CompileError> {
    let mut emitter = AmxEmitter::new(num_params);
    emitter.emit_formula(order, formula)?;
    let result_reg = emitter.pop_reg();
    let listing = emitter.finish(&result_reg);
    Ok(listing)
}
/// Compiles `formula` into the looping ("parallel") assembly listing.
///
/// Lowering is identical to `compile_to_amx`; only the wrapper differs:
/// `finish_parallel` surrounds the formula body with a loop that reads each
/// parameter from an input array and stores the result to an output array.
///
/// # Errors
/// Propagates any `CompileError` raised while lowering the formula.
pub fn compile_to_amx_parallel<const N: usize>(
    order: &Order<N>,
    formula: NounId,
    num_params: u32,
) -> Result<String, CompileError> {
    let mut emitter = AmxEmitter::new(num_params);
    emitter.emit_formula(order, formula)?;
    let result_reg = emitter.pop_reg();
    let listing = emitter.finish_parallel(&result_reg, num_params);
    Ok(listing)
}
/// Accumulates the assembly text for one formula and tracks the bookkeeping
/// (virtual registers, labels, operand stack, subject) needed while lowering.
struct AmxEmitter {
/// Instruction and label text appended by `emit` / `emit_label`.
body: String,
/// Number of formula parameters; `finish` uses it to spill register params.
num_params: u32,
/// Index of the next `%vrN` virtual register handed out by `alloc_vreg`.
next_reg: u32,
/// Index of the next `.LN` label handed out by `alloc_label`.
next_label: u32,
/// Stack of operand names holding results of already-emitted sub-formulas.
reg_stack: Vec<String>,
/// Operand names addressable by axis lookups; index 0 is the innermost binding.
subject: Vec<String>,
/// State of the loop currently being emitted, consulted by `emit_back_edge`.
loop_state: Option<AmxLoopState>,
}
/// Registers the emitted prologue pushes (in this order) and the epilogue
/// pops (in reverse order).
const SCRATCH_REGS: &[&str] = &["rbx", "r12", "r13", "r14", "r15"];
/// Temporary register names.
///
/// NOTE(review): nothing in the visible part of this file references this
/// table; confirm whether it is used elsewhere before removing it.
const TEMP_REGS: &[&str] = &["r8", "r9", "r10", "r11"];
/// Registers that carry the first four parameters, in parameter order.
/// Parameters beyond this table are addressed relative to `%rbp` (see
/// `AmxEmitter::new`).
const PARAM_REGS: &[&str] = &["rdi", "rsi", "rdx", "rcx"];
/// Bookkeeping for the loop currently being emitted; created by `emit_loop`
/// and read by `emit_back_edge`.
#[derive(Clone)]
struct AmxLoopState {
/// Virtual registers holding the loop-carried values, in `inits` order.
carried: Vec<String>,
/// Virtual register zeroed at loop entry and placed at the head of the subject.
formula_reg: String,
/// Label emitted at the top of the loop; the back edge jumps here.
header_label: String,
}
impl AmxEmitter {
/// Creates an emitter whose subject is pre-populated with the parameters.
///
/// The subject lists parameters from the highest index down to index 0.
/// Parameters covered by `PARAM_REGS` are named by their register
/// (`%rdi`, ...); the rest are `%rbp`-relative slots starting at `16(%rbp)`
/// and advancing 8 bytes per parameter.
fn new(num_params: u32) -> Self {
    let reg_count = PARAM_REGS.len();
    let mut subject = Vec::with_capacity(num_params as usize);
    for idx in (0..num_params as usize).rev() {
        let operand = match PARAM_REGS.get(idx) {
            Some(name) => format!("%{}", name),
            None => format!("{}(%rbp)", 16 + (idx - reg_count) * 8),
        };
        subject.push(operand);
    }
    Self {
        body: String::with_capacity(4096),
        num_params,
        next_reg: 0,
        next_label: 0,
        reg_stack: Vec::new(),
        subject,
        loop_state: None,
    }
}
/// Returns the name of a fresh virtual register (`%vr0`, `%vr1`, ...) and
/// advances the register counter.
fn alloc_vreg(&mut self) -> String {
    let index = self.next_reg;
    self.next_reg = index + 1;
    format!("%vr{}", index)
}
/// Returns the name of a fresh local label (`.L0`, `.L1`, ...) and advances
/// the label counter.
fn alloc_label(&mut self) -> String {
    let index = self.next_label;
    self.next_label = index + 1;
    format!(".L{}", index)
}
/// Allocates a fresh virtual register, records it on the result stack, and
/// returns its name so the caller can emit code that writes to it.
fn push_reg(&mut self) -> String {
    let fresh = self.alloc_vreg();
    let name = fresh.clone();
    self.reg_stack.push(fresh);
    name
}
/// Removes and returns the most recent result register.
///
/// Falls back to the name `%vr0` when the result stack is empty.
fn pop_reg(&mut self) -> String {
    match self.reg_stack.pop() {
        Some(reg) => reg,
        None => String::from("%vr0"),
    }
}
/// Appends one instruction line to the body: a leading space, the text of
/// `line`, then a newline.
fn emit(&mut self, line: &str) {
    for piece in [" ", line, "\n"] {
        self.body.push_str(piece);
    }
}
/// Appends a label definition (`label:` on its own line, unindented) to the body.
fn emit_label(&mut self, label: &str) {
    self.body += label;
    self.body += ":\n";
}
fn emit_formula<const N: usize>(&mut self, order: &Order<N>, formula: NounId) -> Result<(), CompileError> {
let (tag, body) = formula_parts(order, formula)?;
match tag {
0 => self.emit_axis(order, body),
1 => self.emit_quote(order, body),
2 => self.emit_compose(order, body),
4 => self.emit_branch(order, body),
5 => self.emit_add(order, body),
6 => self.emit_sub(order, body),
7 => self.emit_mul(order, body),
9 => self.emit_eq(order, body),
10 => self.emit_lt(order, body),
11 => self.emit_xor(order, body),
12 => self.emit_and(order, body),
13 => self.emit_not(order, body),
14 => self.emit_shl(order, body),
_ => Err(CompileError::UnsupportedPattern(tag)),
}
}
/// Lowers an axis lookup (tag 0): copies the subject entry selected by the
/// axis into a fresh result register.
///
/// # Errors
/// Returns `CompileError::NoParams` when the axis maps to a depth beyond the
/// current subject; propagates errors from `atom_u64` and `axis_to_param`.
fn emit_axis<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let addr = atom_u64(order, body)?;
    let depth = axis_to_param(addr)? as usize;
    let src = self.subject.get(depth).cloned().ok_or(CompileError::NoParams)?;
    let dst = self.push_reg();
    let line = format!("movq {}, {} # axis {}", src, dst, addr);
    self.emit(&line);
    Ok(())
}
/// Lowers a quoted constant (tag 1): loads the atom's 64-bit value into a
/// fresh result register with `movabsq`.
///
/// # Errors
/// Propagates the error from `atom_u64` when the body is not a usable atom.
fn emit_quote<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let value = atom_u64(order, body)?;
    let dst = self.push_reg();
    let line = format!("movabsq ${}, {}", value, dst);
    self.emit(&line);
    Ok(())
}
/// Lowers a compose formula (tag 2).
///
/// Three shapes are recognised, tried in this order:
/// 1. a loop setup (per `detect_loop_setup`) -> `emit_loop`;
/// 2. a loop back edge (per `detect_back_edge`) -> `emit_back_edge`;
/// 3. a let-binding: the first half must be a tag-3 pair whose second element
///    is the axis-1 formula (tag 0, atom 1), and the second half must be a
///    tag-1 formula wrapping the body. The bound value is evaluated, pushed
///    onto the front of the subject while the body is lowered, then removed.
///
/// # Errors
/// Returns `CompileError::UnsupportedPattern(2)` for any other shape and
/// propagates errors from the helpers and nested formulas.
fn emit_compose<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    if let Some((loop_body, inits)) = detect_loop_setup(order, body) {
        return self.emit_loop(order, loop_body, &inits);
    }
    if let Some((new_subj, _)) = detect_back_edge(order, body) {
        return self.emit_back_edge(order, new_subj);
    }

    let unsupported = || CompileError::UnsupportedPattern(2);

    let (subject_formula, code_formula) = body_pair(order, body)?;
    let (subject_tag, cons_body) = formula_parts(order, subject_formula)?;
    if subject_tag != 3 {
        return Err(unsupported());
    }
    let (value_formula, identity) = body_pair(order, cons_body)?;
    let (id_tag, id_body) = formula_parts(order, identity)?;
    let is_identity = id_tag == 0 && atom_u64(order, id_body)? == 1;
    if !is_identity {
        return Err(unsupported());
    }
    let (code_tag, body_formula) = formula_parts(order, code_formula)?;
    if code_tag != 1 {
        return Err(unsupported());
    }

    self.emit_formula(order, value_formula)?;
    let bound = self.pop_reg();
    self.subject.insert(0, bound);
    // Restore the subject whether or not the body lowered successfully.
    let outcome = self.emit_formula(order, body_formula);
    self.subject.remove(0);
    outcome
}
/// Lowers a loop: evaluates each initialiser into its own carried register,
/// emits the header label, then lowers the loop body with an extended subject.
///
/// Inside the body the subject is, from the front: the zeroed formula
/// register, the carried registers in reverse `inits` order, then the outer
/// subject. The previous loop state and subject are restored after the body.
///
/// # Errors
/// Propagates errors from lowering the initialisers or the body; on such an
/// early return the loop state and subject are left as they were at that point.
fn emit_loop<const N: usize>(
    &mut self, order: &Order<N>, loop_body: NounId, inits: &[NounId],
) -> Result<(), CompileError> {
    let formula_reg = self.alloc_vreg();
    self.emit(&format!("xorq {0}, {0}", formula_reg));

    let mut carried = Vec::with_capacity(inits.len());
    for init in inits.iter().copied() {
        self.emit_formula(order, init)?;
        let value = self.pop_reg();
        let slot = self.alloc_vreg();
        self.emit(&format!("movq {}, {}", value, slot));
        carried.push(slot);
    }

    // Build the loop-body subject in one go instead of repeated front inserts.
    let mut inner = Vec::with_capacity(self.subject.len() + carried.len() + 1);
    inner.push(formula_reg.clone());
    inner.extend(carried.iter().rev().cloned());
    inner.extend(self.subject.iter().cloned());
    let outer_subject = std::mem::replace(&mut self.subject, inner);

    let header = self.alloc_label();
    let outer_loop = self.loop_state.replace(AmxLoopState {
        carried,
        formula_reg,
        header_label: header.clone(),
    });

    self.emit_label(&header);
    self.emit_formula(order, loop_body)?;

    self.loop_state = outer_loop;
    self.subject = outer_subject;
    Ok(())
}
/// Lowers a loop back edge: evaluates the new values for the carried
/// registers, copies them into place, and jumps to the loop header.
///
/// `new_subj` must be a tag-3 pair; its first element is skipped and the
/// rest is walked as a chain of tag-3 pairs, one value per carried register.
/// The walk stops early at the first link that is not tag 3. A placeholder
/// register is pushed at the end so the caller still finds a result to pop.
///
/// # Errors
/// Returns `CompileError::UnsupportedPattern(2)` when no loop is active or
/// `new_subj` is not a tag-3 pair; propagates helper and nested errors.
fn emit_back_edge<const N: usize>(
    &mut self, order: &Order<N>, new_subj: NounId,
) -> Result<(), CompileError> {
    let state = self.loop_state.clone().ok_or(CompileError::UnsupportedPattern(2))?;

    let (outer_tag, cons_body) = formula_parts(order, new_subj)?;
    if outer_tag != 3 {
        return Err(CompileError::UnsupportedPattern(2));
    }
    let (_, mut cursor) = body_pair(order, cons_body)?;

    let mut fresh_values = Vec::new();
    for _ in 0..state.carried.len() {
        let (link_tag, link_body) = formula_parts(order, cursor)?;
        if link_tag != 3 {
            break;
        }
        let (value_formula, tail) = body_pair(order, link_body)?;
        self.emit_formula(order, value_formula)?;
        fresh_values.push(self.pop_reg());
        cursor = tail;
    }

    // `zip` stops at the shorter side, i.e. only the values actually produced.
    for (value, slot) in fresh_values.iter().zip(state.carried.iter()) {
        self.emit(&format!("movq {}, {}", value, slot));
    }
    self.emit(&format!("jmp {}", state.header_label));

    let _placeholder = self.push_reg();
    Ok(())
}
/// Lowers a conditional (tag 4) over `(test, yes, no)`.
///
/// The test value is checked with `testq`; a non-zero value jumps to the
/// "no" arm, zero falls through into the "yes" arm. Both arms copy their
/// result into one shared destination register, which becomes the result.
///
/// # Errors
/// Propagates errors from `body_triple` and from lowering any sub-formula.
fn emit_branch<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let (test, yes, no) = body_triple(order, body)?;

    self.emit_formula(order, test)?;
    let cond = self.pop_reg();
    let else_label = self.alloc_label();
    let join_label = self.alloc_label();
    let merged = self.alloc_vreg();
    self.emit(&format!("testq {0}, {0}", cond));
    self.emit(&format!("jnz {}", else_label));

    // "yes" arm (test == 0).
    self.emit_formula(order, yes)?;
    let yes_value = self.pop_reg();
    self.emit(&format!("movq {}, {}", yes_value, merged));
    self.emit(&format!("jmp {}", join_label));

    // "no" arm (test != 0).
    self.emit_label(&else_label);
    self.emit_formula(order, no)?;
    let no_value = self.pop_reg();
    self.emit(&format!("movq {}, {}", no_value, merged));

    self.emit_label(&join_label);
    self.reg_stack.push(merged);
    Ok(())
}
/// Lowers field addition (tag 5): `dst = (a + b) mod p`.
///
/// The 64-bit add may carry out of the register. Since
/// 2^64 ≡ 2^32 - 1 (mod p), a carry is folded back in by adding
/// 0xFFFF_FFFF; with both inputs below p that correction cannot overflow.
/// A final conditional subtraction of p handles a no-carry sum that is >= p.
///
/// The earlier sequence (`adcq`/`subq $1`/`jc`) never skipped the correction
/// and therefore added 0xFFFF_FFFE to every carry-free sum.
///
/// # Errors
/// Propagates errors from `body_pair` and from lowering either operand.
fn emit_add<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let (a, b) = body_pair(order, body)?;
    self.emit_formula(order, a)?;
    let ra = self.pop_reg();
    self.emit_formula(order, b)?;
    let rb = self.pop_reg();
    let dst = self.push_reg();
    let eps = self.alloc_vreg();
    let no_carry = self.alloc_label();

    self.emit(&format!("movq {}, {}", ra, dst));
    self.emit(&format!("addq {}, {}", rb, dst));
    // Only a carry out of bit 63 needs the 2^64 mod p correction.
    self.emit(&format!("jnc {}", no_carry));
    self.emit(&format!("movabsq $4294967295, {}", eps));
    self.emit(&format!("addq {}, {}", eps, dst));
    self.emit_label(&no_carry);
    self.emit_reduce_mod_p(&dst);
    Ok(())
}
/// Lowers field subtraction (tag 6): `dst = (a - b) mod p`.
///
/// The operands are subtracted as 64-bit integers; when the subtraction
/// borrows, p is added back so the wrapped value becomes `a - b + p`.
/// The result is canonical as long as both inputs are below p.
///
/// # Errors
/// Propagates errors from `body_pair` and from lowering either operand.
fn emit_sub<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let (a, b) = body_pair(order, body)?;
    self.emit_formula(order, a)?;
    let ra = self.pop_reg();
    self.emit_formula(order, b)?;
    let rb = self.pop_reg();
    let dst = self.push_reg();
    let tmp = self.alloc_vreg();
    let lbl_no_borrow = self.alloc_label();

    self.emit(&format!("movq {}, {}", ra, dst));
    self.emit(&format!("subq {}, {}", rb, dst));
    self.emit(&format!("jnc {}", lbl_no_borrow));
    // Borrow: bring the wrapped difference back into [0, p).
    self.emit(&format!("movabsq ${}, {}", P, tmp));
    self.emit(&format!("addq {}, {}", tmp, dst));
    self.emit_label(&lbl_no_borrow);
    Ok(())
}
/// Lowers field multiplication (tag 7): `dst = (a * b) mod p`.
///
/// `mulq` leaves the 128-bit product in `%rdx:%rax` (hi:lo). With
/// EPS = 2^32 - 1, the identities 2^64 ≡ EPS and 2^96 ≡ -1 (mod p) give
///
/// ```text
/// product ≡ lo - (hi >> 32) + (hi & EPS) * EPS   (mod p)
/// ```
///
/// Every intermediate fits in 64 bits: a borrow on the subtraction is
/// compensated by subtracting EPS, a carry on the addition by adding EPS,
/// and one conditional subtraction of p makes the result canonical.
///
/// The earlier sequence computed `hi * EPS` in a single 64-bit register,
/// which silently truncated whenever `hi >= 2^32`.
///
/// NOTE(review): this sequence overwrites `%rax` and `%rdx`; `%rdx` is also
/// `PARAM_REGS[2]`. Confirm that the register-assignment stage accounts for it.
///
/// # Errors
/// Propagates errors from `body_pair` and from lowering either operand.
fn emit_mul<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let (a, b) = body_pair(order, body)?;
    self.emit_formula(order, a)?;
    let ra = self.pop_reg();
    self.emit_formula(order, b)?;
    let rb = self.pop_reg();
    let dst = self.push_reg();
    let hi = self.alloc_vreg();
    let tmp = self.alloc_vreg();
    let eps = self.alloc_vreg();
    let no_borrow = self.alloc_label();
    let no_carry = self.alloc_label();

    // 128-bit product: dst = lo, hi = high half.
    self.emit(&format!("movq {}, %rax", ra));
    self.emit(&format!("mulq {}", rb));
    self.emit(&format!("movq %rax, {}", dst));
    self.emit(&format!("movq %rdx, {}", hi));
    self.emit(&format!("movabsq $4294967295, {}", eps));

    // dst = lo - (hi >> 32); a borrow means the wrapped value is 2^64 too
    // large, and 2^64 ≡ EPS, so subtract EPS.
    self.emit(&format!("movq {}, {}", hi, tmp));
    self.emit(&format!("shrq $32, {}", tmp));
    self.emit(&format!("subq {}, {}", tmp, dst));
    self.emit(&format!("jnc {}", no_borrow));
    self.emit(&format!("subq {}, {}", eps, dst));
    self.emit_label(&no_borrow);

    // tmp = (hi & EPS) * EPS, computed as (x << 32) - x; fits in 64 bits.
    self.emit(&format!("andq {}, {}", eps, hi));
    self.emit(&format!("movq {}, {}", hi, tmp));
    self.emit(&format!("shlq $32, {}", tmp));
    self.emit(&format!("subq {}, {}", hi, tmp));

    // dst += tmp; a carry drops 2^64 ≡ EPS, so add EPS back.
    self.emit(&format!("addq {}, {}", tmp, dst));
    self.emit(&format!("jnc {}", no_carry));
    self.emit(&format!("addq {}, {}", eps, dst));
    self.emit_label(&no_carry);

    self.emit_reduce_mod_p(&dst);
    Ok(())
}
/// Lowers an equality test (tag 9): the result register is 0 when the two
/// operands are equal and 1 otherwise (`setne`), matching the
/// zero-means-yes convention used by `emit_branch`.
///
/// The emitted code uses `%al` as a scratch byte.
///
/// # Errors
/// Propagates errors from `body_pair` and from lowering either operand.
fn emit_eq<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let (lhs_formula, rhs_formula) = body_pair(order, body)?;
    self.emit_formula(order, lhs_formula)?;
    let lhs = self.pop_reg();
    self.emit_formula(order, rhs_formula)?;
    let rhs = self.pop_reg();
    let out = self.push_reg();

    let lines = [
        format!("xorq {0}, {0}", out),
        format!("cmpq {}, {}", rhs, lhs),
        String::from("setne %al"),
        format!("movzbq %al, {}", out),
    ];
    for line in lines.iter() {
        self.emit(line);
    }
    Ok(())
}
/// Lowers an unsigned less-than test (tag 10): the result register is 0 when
/// `a < b` and 1 when `a >= b` (`setae`), matching the zero-means-yes
/// convention used by `emit_branch`.
///
/// The emitted code uses `%al` as a scratch byte.
///
/// # Errors
/// Propagates errors from `body_pair` and from lowering either operand.
fn emit_lt<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let (lhs_formula, rhs_formula) = body_pair(order, body)?;
    self.emit_formula(order, lhs_formula)?;
    let lhs = self.pop_reg();
    self.emit_formula(order, rhs_formula)?;
    let rhs = self.pop_reg();
    let out = self.push_reg();

    let lines = [
        format!("xorq {0}, {0}", out),
        format!("cmpq {}, {}", rhs, lhs),
        String::from("setae %al"),
        format!("movzbq %al, {}", out),
    ];
    for line in lines.iter() {
        self.emit(line);
    }
    Ok(())
}
/// Lowers a bitwise XOR (tag 11): `dst = a ^ b` on the full 64-bit values.
///
/// # Errors
/// Propagates errors from `body_pair` and from lowering either operand.
fn emit_xor<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let (lhs_formula, rhs_formula) = body_pair(order, body)?;
    self.emit_formula(order, lhs_formula)?;
    let lhs = self.pop_reg();
    self.emit_formula(order, rhs_formula)?;
    let rhs = self.pop_reg();
    let out = self.push_reg();

    let copy = format!("movq {}, {}", lhs, out);
    let combine = format!("xorq {}, {}", rhs, out);
    self.emit(&copy);
    self.emit(&combine);
    Ok(())
}
/// Lowers a bitwise AND (tag 12): `dst = a & b` on the full 64-bit values.
///
/// # Errors
/// Propagates errors from `body_pair` and from lowering either operand.
fn emit_and<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let (lhs_formula, rhs_formula) = body_pair(order, body)?;
    self.emit_formula(order, lhs_formula)?;
    let lhs = self.pop_reg();
    self.emit_formula(order, rhs_formula)?;
    let rhs = self.pop_reg();
    let out = self.push_reg();

    let copy = format!("movq {}, {}", lhs, out);
    let combine = format!("andq {}, {}", rhs, out);
    self.emit(&copy);
    self.emit(&combine);
    Ok(())
}
/// Lowers a bitwise NOT (tag 13): inverts the operand and then masks the
/// result with 0xFFFFFFFF through the 32-bit (`d`-suffixed) register name.
///
/// # Errors
/// Propagates errors from lowering the operand.
fn emit_not<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    self.emit_formula(order, body)?;
    let operand = self.pop_reg();
    let out = self.push_reg();

    let lines = [
        format!("movq {}, {}", operand, out),
        format!("notq {}", out),
        format!("andl $0xFFFFFFFF, {}d", out),
    ];
    for line in lines.iter() {
        self.emit(line);
    }
    Ok(())
}
/// Lowers a left shift (tag 14): shifts `a` left by the count in `b` (taken
/// from `%cl`) and then masks the result with 0xFFFFFFFF through the 32-bit
/// (`d`-suffixed) register name.
///
/// NOTE(review): the shift count is loaded into `%rcx`, which is also
/// `PARAM_REGS[3]`; confirm the fourth parameter is not still live there.
///
/// # Errors
/// Propagates errors from `body_pair` and from lowering either operand.
fn emit_shl<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
    let (value_formula, count_formula) = body_pair(order, body)?;
    self.emit_formula(order, value_formula)?;
    let value = self.pop_reg();
    self.emit_formula(order, count_formula)?;
    let count = self.pop_reg();
    let out = self.push_reg();

    let lines = [
        format!("movq {}, %rcx", count),
        format!("movq {}, {}", value, out),
        format!("shlq %cl, {}", out),
        format!("andl $0xFFFFFFFF, {}d", out),
    ];
    for line in lines.iter() {
        self.emit(line);
    }
    Ok(())
}
/// Emits one conditional subtraction of p: if `dst >= p` (unsigned) then
/// `dst -= p`, otherwise `dst` is left untouched.
fn emit_reduce_mod_p(&mut self, dst: &str) {
    let modulus = self.alloc_vreg();
    let done = self.alloc_label();

    let lines = [
        format!("movabsq ${}, {}", P, modulus),
        format!("cmpq {}, {}", modulus, dst),
        format!("jb {}", done),
        format!("subq {}, {}", modulus, dst),
    ];
    for line in lines.iter() {
        self.emit(line);
    }
    self.emit_label(&done);
}
/// Wraps the accumulated body into a complete `nox_formula` function and
/// returns the full assembly listing.
///
/// Layout of the listing: header comments, the AMX tile-configuration data,
/// the function (prologue, `ldtilecfg`, stack frame, parameter spills, the
/// formula body, result move into `%rax`, `tilerelease`, epilogue), and
/// finally the AMX helper routines.
///
/// The function itself is written in AT&T syntax (`%`-prefixed registers,
/// source-first operands), while the helper routines use Intel syntax. The
/// listing therefore switches the assembler to `.att_syntax prefix` for the
/// function and back to `.intel_syntax noprefix` before the helpers.
///
/// `result` is the operand holding the formula's value.
///
/// NOTE(review): register parameters are spilled to `i*8(%rsp)`, but the body
/// refers to parameters by register name (see `new`); confirm which of the
/// two a later register-assignment stage relies on.
fn finish(self, result: &str) -> String {
    let mut asm = String::with_capacity(8192);
    asm.push_str("# nox formula -> x86-64 + Intel AMX assembly\n");
    asm.push_str("# Goldilocks field: p = 2^64 - 2^32 + 1\n");
    asm.push_str("# Requires: Intel Sapphire Rapids or newer (AMX support)\n\n");
    asm.push_str(".intel_syntax noprefix\n\n");
    asm.push_str(AMX_TILE_CONFIG);

    // Everything from here to the end of nox_formula is AT&T syntax.
    asm.push_str(".att_syntax prefix\n");
    asm.push_str(".text\n");
    asm.push_str(".globl nox_formula\n");
    asm.push_str(".type nox_formula, @function\n");
    asm.push_str("nox_formula:\n");
    asm.push_str(" pushq %rbp\n");
    asm.push_str(" movq %rsp, %rbp\n");
    for &reg in SCRATCH_REGS {
        asm.push_str(&format!(" pushq %{}\n", reg));
    }
    asm.push_str(" # Configure AMX tiles\n");
    asm.push_str(" ldtilecfg .Ltile_config(%rip)\n");

    // One 8-byte slot per virtual register plus one spare, rounded up to 16.
    let frame_size = (self.next_reg as usize + 1) * 8;
    let aligned = (frame_size + 15) & !15;
    asm.push_str(&format!(" subq ${}, %rsp\n", aligned));
    asm.push('\n');

    // Spill the parameters that arrive in registers.
    for (i, reg) in PARAM_REGS.iter().enumerate().take(self.num_params as usize) {
        asm.push_str(&format!(" movq %{}, {}(%rsp) # param {}\n", reg, i * 8, i));
    }
    asm.push('\n');

    asm.push_str(" # --- formula body ---\n");
    asm.push_str(&self.body);
    asm.push_str(&format!("\n movq {}, %rax\n", result));
    asm.push_str(" tilerelease\n");
    asm.push_str(&format!(" addq ${}, %rsp\n", aligned));
    for &reg in SCRATCH_REGS.iter().rev() {
        asm.push_str(&format!(" popq %{}\n", reg));
    }
    asm.push_str(" popq %rbp\n");
    asm.push_str(" retq\n");
    asm.push_str(".Lfunc_end_nox_formula:\n");
    asm.push_str(" .size nox_formula, .Lfunc_end_nox_formula - nox_formula\n\n");

    // The helper routines are written in Intel syntax.
    asm.push_str(".intel_syntax noprefix\n");
    asm.push_str(AMX_MATRIX_HELPERS);
    asm
}
/// Wraps the accumulated body into a complete `nox_formula_parallel`
/// function that applies the formula to every element of its input arrays.
///
/// Register arguments, in `PARAM_REGS` order: one pointer per formula
/// parameter, then the output pointer, then the element count. All of them
/// are saved in stack slots above the virtual-register area. The loop keeps
/// the element index in `%rbx`, reloads each parameter from
/// `input_ptr[idx]`, runs the formula body, and stores `result` to
/// `output_ptr[idx]`.
///
/// The function is written in AT&T syntax and the helper routines in Intel
/// syntax, so the listing switches to `.att_syntax prefix` for the function
/// and back to `.intel_syntax noprefix` before the helpers.
///
/// NOTE(review): only `PARAM_REGS.len()` (4) register arguments are
/// captured. With 3 or more formula parameters the count (and with 4 the
/// output pointer) is never stored, yet its slot is still read by the loop.
/// Parameters beyond the fourth are loaded into `%rax` and then overwritten.
/// Confirm the supported parameter range with the callers.
fn finish_parallel(self, result: &str, num_params: u32) -> String {
    let param_count = num_params as usize;
    let mut asm = String::with_capacity(8192);
    asm.push_str("# nox formula -> x86-64 + Intel AMX assembly (parallel)\n");
    asm.push_str("# Goldilocks field: p = 2^64 - 2^32 + 1\n");
    asm.push_str("# Requires: Intel Sapphire Rapids or newer (AMX support)\n\n");
    asm.push_str(".intel_syntax noprefix\n\n");
    asm.push_str(AMX_TILE_CONFIG);

    // Everything from here to the end of nox_formula_parallel is AT&T syntax.
    asm.push_str(".att_syntax prefix\n");
    asm.push_str(".text\n");
    asm.push_str(".globl nox_formula_parallel\n");
    asm.push_str(".type nox_formula_parallel, @function\n");
    asm.push_str("nox_formula_parallel:\n");
    asm.push_str(" pushq %rbp\n");
    asm.push_str(" movq %rsp, %rbp\n");
    for &reg in SCRATCH_REGS {
        asm.push_str(&format!(" pushq %{}\n", reg));
    }
    asm.push_str(" ldtilecfg .Ltile_config(%rip)\n");

    // Frame: virtual-register slots plus 16 extra slots, rounded up to 16 bytes.
    let frame_size = ((self.next_reg as usize + 16) * 8 + 15) & !15;
    asm.push_str(&format!(" subq ${}, %rsp\n", frame_size));

    // Saved arguments live just above the virtual-register slots.
    let base_slot = (self.next_reg as usize + 1) * 8;
    for (i, reg) in PARAM_REGS.iter().enumerate().take(param_count) {
        asm.push_str(&format!(" movq %{}, {}(%rsp) # input{}_ptr\n",
            reg, base_slot + i * 8, i));
    }
    let out_slot = base_slot + param_count * 8;
    let cnt_slot = out_slot + 8;
    if let Some(reg) = PARAM_REGS.get(param_count) {
        asm.push_str(&format!(" movq %{}, {}(%rsp) # output_ptr\n", reg, out_slot));
    }
    if let Some(reg) = PARAM_REGS.get(param_count + 1) {
        asm.push_str(&format!(" movq %{}, {}(%rsp) # count\n", reg, cnt_slot));
    }

    asm.push_str("\n xorq %rbx, %rbx # idx = 0\n");
    asm.push_str(".Lpar_loop:\n");
    asm.push_str(&format!(" cmpq {}(%rsp), %rbx\n", cnt_slot));
    asm.push_str(" jge .Lpar_done\n\n");

    // Reload every parameter from input_ptr[idx] before running the body.
    for i in 0..param_count {
        let target = PARAM_REGS.get(i).copied().unwrap_or("rax");
        asm.push_str(&format!(" movq {}(%rsp), %rax # input{}_ptr\n",
            base_slot + i * 8, i));
        asm.push_str(&format!(" movq (%rax,%rbx,8), %{}\n", target));
    }
    asm.push('\n');

    asm.push_str(" # --- formula body ---\n");
    asm.push_str(&self.body);
    asm.push_str(&format!("\n movq {}(%rsp), %rax # output_ptr\n", out_slot));
    asm.push_str(&format!(" movq {}, (%rax,%rbx,8)\n", result));
    asm.push_str(" incq %rbx\n");
    asm.push_str(" jmp .Lpar_loop\n");
    asm.push_str(".Lpar_done:\n");
    asm.push_str(" tilerelease\n");
    asm.push_str(&format!(" addq ${}, %rsp\n", frame_size));
    for &reg in SCRATCH_REGS.iter().rev() {
        asm.push_str(&format!(" popq %{}\n", reg));
    }
    asm.push_str(" popq %rbp\n");
    asm.push_str(" retq\n");
    asm.push_str(".Lfunc_end_nox_formula_parallel:\n");
    asm.push_str(" .size nox_formula_parallel, .Lfunc_end_nox_formula_parallel - nox_formula_parallel\n\n");

    // The helper routines are written in Intel syntax.
    asm.push_str(".intel_syntax noprefix\n");
    asm.push_str(AMX_MATRIX_HELPERS);
    asm
}
}
/// Read-only data for `ldtilecfg`: a 64-byte, 64-byte-aligned tile
/// configuration that sets tmm0-tmm7 to 16 rows of 64 bytes each.
///
/// The layout follows the LDTILECFG memory format: byte 0 palette, byte 1
/// start_row, bytes 2-15 reserved, bytes 16-31 colsb for tiles 0-7, bytes
/// 32-47 reserved, bytes 48-55 rows for tiles 0-7, bytes 56-63 reserved.
/// The previous table placed the row counts at bytes 32-39 and was only
/// 56 bytes long.
const AMX_TILE_CONFIG: &str = r#"
.section .rodata
.align 64
.Ltile_config:
.byte 1 # palette = 1 (AMX-INT8 / AMX-BF16)
.zero 15 # start_row (byte 1) + reserved (bytes 2-15)
# colbytes for tmm0-tmm7 (16 bits each, bytes 16-31)
.word 64 # tmm0: 64 bytes per row
.word 64 # tmm1
.word 64 # tmm2
.word 64 # tmm3
.word 64 # tmm4
.word 64 # tmm5
.word 64 # tmm6
.word 64 # tmm7
.zero 16 # reserved (bytes 32-47)
# rows for tmm0-tmm7 (8 bits each, bytes 48-55)
.byte 16 # tmm0: 16 rows
.byte 16 # tmm1
.byte 16 # tmm2
.byte 16 # tmm3
.byte 16 # tmm4
.byte 16 # tmm5
.byte 16 # tmm6
.byte 16 # tmm7
.zero 8 # reserved (bytes 56-63), pads to 64 bytes
"#;
/// Intel-syntax helper routines appended to every listing: BF16 and INT8
/// tile multiply-accumulate, and a tile-zeroing store.
///
/// `tileloadd` / `tilestored` take the row stride from the index register of
/// their memory operand; without one the stride is zero and every tile row
/// aliases the same 64 bytes. Each helper therefore loads a 64-byte stride
/// into `rax` (one contiguous 64-byte row per tile row) and uses it as index.
///
/// The helpers expect a tile configuration to be active already (the
/// generated functions issue `ldtilecfg` before their body).
const AMX_MATRIX_HELPERS: &str = r#"
# --- AMX matrix helpers for Goldilocks batch operations ---
# amx_matmul_bf16: C += A * B using BF16 tiles
# rdi = A ptr (16x32 BF16 matrix, row-major)
# rsi = B ptr (16x32 BF16 matrix, row-major)
# rdx = C ptr (16x16 FP32 matrix, row-major, accumulator)
.globl amx_matmul_bf16
.type amx_matmul_bf16, @function
amx_matmul_bf16:
mov eax, 64 # row stride in bytes
# Load C accumulator into tmm0
tileloadd tmm0, [rdx + rax*1] # C tile
# Load A into tmm1
tileloadd tmm1, [rdi + rax*1] # A tile
# Load B into tmm2
tileloadd tmm2, [rsi + rax*1] # B tile
# C += A * B (BF16 dot product -> FP32)
tdpbf16ps tmm0, tmm1, tmm2
# Store result
tilestored [rdx + rax*1], tmm0
retq
.Lfunc_end_amx_matmul_bf16:
.size amx_matmul_bf16, .Lfunc_end_amx_matmul_bf16 - amx_matmul_bf16
# amx_matmul_int8: C += A * B using INT8 tiles
# rdi = A ptr (16x64 uint8 matrix)
# rsi = B ptr (16x64 int8 matrix)
# rdx = C ptr (16x16 int32 matrix, accumulator)
.globl amx_matmul_int8
.type amx_matmul_int8, @function
amx_matmul_int8:
mov eax, 64 # row stride in bytes
tileloadd tmm0, [rdx + rax*1] # C accumulator
tileloadd tmm1, [rdi + rax*1] # A (unsigned int8)
tileloadd tmm2, [rsi + rax*1] # B (signed int8)
tdpbusd tmm0, tmm1, tmm2 # C += A * B (uint8 x int8 -> int32)
tilestored [rdx + rax*1], tmm0
retq
.Lfunc_end_amx_matmul_int8:
.size amx_matmul_int8, .Lfunc_end_amx_matmul_int8 - amx_matmul_int8
# amx_tilezero: zero a tile register
# Uses tmm0 for zeroing, then stores to memory
# rdi = dest ptr
.globl amx_tilezero
.type amx_tilezero, @function
amx_tilezero:
mov eax, 64 # row stride in bytes
tilezero tmm0
tilestored [rdi + rax*1], tmm0
retq
.Lfunc_end_amx_tilezero:
.size amx_tilezero, .Lfunc_end_amx_tilezero - amx_tilezero
"#;