//! NVIDIA Tensor Core backend via PTX wmma instructions
//!
//! Extends the PTX backend with Warp Matrix Multiply-Accumulate (wmma)
//! instructions for hardware tensor acceleration on Ampere+ GPUs.
//!
//! Two compilation modes:
//! - Scalar: regular PTX for non-mul operations (delegates to ptx.rs patterns)
//! - Tensor: wmma 16x16x16 matrix ops for mul (fp16 input, fp32 accumulate)
//!
//! Target: sm_80 (Ampere). Requires warp-synchronous execution (32 threads).
//!
//! wmma instruction set used:
//!   wmma.load.a.sync.aligned.row.m16n16k16.shared.f16  → load A fragment
//!   wmma.load.b.sync.aligned.row.m16n16k16.shared.f16  → load B fragment
//!   wmma.mma.sync.aligned.row.row.m16n16k16.f32.f32    → fused multiply-accumulate
//!   wmma.store.d.sync.aligned.row.m16n16k16.shared.f32 → store D fragment
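//!
//! A minimal usage sketch (mirroring the unit tests at the bottom of this
//! file; the noun-construction calls assume the `nox` `Order` API used there):
//!
//! ```ignore
//! use nox::noun::{Order, Tag};
//! use nebu::Goldilocks;
//!
//! // Build [7 [0 2] [0 6]]: multiply the values at subject axes 2 and 6.
//! let mut o = Order::<1024>::new();
//! let t0a = o.atom(Goldilocks::new(0), Tag::Field).unwrap();
//! let ax2 = o.atom(Goldilocks::new(2), Tag::Field).unwrap();
//! let lhs = o.cell(t0a, ax2).unwrap();
//! let t0b = o.atom(Goldilocks::new(0), Tag::Field).unwrap();
//! let ax6 = o.atom(Goldilocks::new(6), Tag::Field).unwrap();
//! let rhs = o.cell(t0b, ax6).unwrap();
//! let body = o.cell(lhs, rhs).unwrap();
//! let t7 = o.atom(Goldilocks::new(7), Tag::Field).unwrap();
//! let formula = o.cell(t7, body).unwrap();
//!
//! let ptx = compile_to_tensor_ptx(&o, formula, 2).unwrap();
//! assert!(ptx.contains("wmma.mma.sync.aligned"));
//! ```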
use nox::noun::{Order, NounId};
use super::{CompileError, formula_parts, body_pair, body_triple, atom_u64, axis_to_param,
detect_loop_setup, detect_back_edge};
const P: u64 = 0xFFFF_FFFF_0000_0001;
/// Fragment size for wmma m16n16k16.
const WMMA_M: u32 = 16;
const WMMA_N: u32 = 16;
const WMMA_K: u32 = 16;
/// Registers per fragment: 8 packed f16x2 (.b32) for A/B, 8 f32 for C/D.
const FRAG_REGS: u32 = 8;
/// Compile to PTX with tensor core acceleration (single evaluation).
///
/// Scalar operations use regular PTX. Mul operations targeting matrix data
/// emit wmma instructions for 16x16x16 tensor core execution.
pub fn compile_to_tensor_ptx<const N: usize>(
order: &Order<N>,
formula: NounId,
num_params: u32,
) -> Result<String, CompileError> {
let mut e = TensorPtxEmitter::new(num_params);
e.emit_formula(order, formula)?;
let result = e.pop_reg();
Ok(e.finish(&result))
}
/// Compile to PTX with tensor cores (parallel: one warp per matrix tile).
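///
/// The generated entry point is `main_tensor_parallel`; it takes one input
/// pointer per parameter, an output pointer, and an element count, with each
/// warp computing one output (indexed by warp id and bounds-checked against
/// `count`). A minimal sketch, reusing the formula construction from the
/// module-level example:
///
/// ```ignore
/// let ptx = compile_to_tensor_ptx_parallel(&o, formula, 2).unwrap();
/// assert!(ptx.contains("main_tensor_parallel"));
/// ```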
pub fn compile_to_tensor_ptx_parallel<const N: usize>(
order: &Order<N>,
formula: NounId,
num_params: u32,
) -> Result<String, CompileError> {
let mut e = TensorPtxEmitter::new(num_params);
e.emit_formula(order, formula)?;
let result = e.pop_reg();
Ok(e.finish_parallel(&result, num_params))
}
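/// PTX emitter with tensor-core support. Tracks register, predicate, label,
/// and fragment allocation, an evaluation stack of result registers, the
/// current subject registers, loop state, and whether any wmma fragments
/// were emitted (which requires shared-memory tile declarations).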
struct TensorPtxEmitter {
body: String,
num_params: u32,
next_reg: u32,
next_pred: u32,
next_label: u32,
next_frag: u32,
reg_stack: Vec<String>,
subject: Vec<String>,
loop_state: Option<TensorLoopState>,
/// Track whether wmma fragments were used (need shared memory decls).
uses_wmma: bool,
}
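/// State for the innermost loop being emitted: the carried-value registers,
/// the register standing in for the loop formula in the subject, and the
/// header label that back edges branch to.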
#[derive(Clone)]
struct TensorLoopState {
carried: Vec<String>,
formula_reg: String,
header_label: String,
}
impl TensorPtxEmitter {
fn new(num_params: u32) -> Self {
let subject: Vec<String> = (0..num_params).rev()
.map(|i| format!("%p{}", i))
.collect();
Self {
body: String::with_capacity(4096),
num_params,
next_reg: 0,
next_pred: 0,
next_label: 0,
next_frag: 0,
reg_stack: Vec::new(),
subject,
loop_state: None,
uses_wmma: false,
}
}
fn alloc_reg(&mut self) -> String {
let r = format!("%rd{}", self.next_reg);
self.next_reg += 1;
r
}
fn alloc_pred(&mut self) -> String {
let p = format!("%p_cond{}", self.next_pred);
self.next_pred += 1;
p
}
fn alloc_label(&mut self) -> String {
let l = format!("L{}", self.next_label);
self.next_label += 1;
l
}
/// Allocate a fragment register group (8 registers for wmma).
fn alloc_frag(&mut self, prefix: &str) -> String {
let f = format!("%{}{}", prefix, self.next_frag);
self.next_frag += 1;
f
}
fn push_reg(&mut self) -> String {
let r = self.alloc_reg();
self.reg_stack.push(r.clone());
r
}
fn pop_reg(&mut self) -> String {
self.reg_stack.pop().unwrap_or_else(|| "%rd0".to_string())
}
fn emit(&mut self, line: &str) {
self.body.push_str(" ");
self.body.push_str(line);
self.body.push('\n');
}
fn emit_label(&mut self, label: &str) {
self.body.push_str(label);
self.body.push_str(":\n");
}
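/// Dispatch on the formula tag. The tags mirror the scalar PTX backend:
/// 0 axis, 1 quote, 2 compose, 4 branch, 5 add, 6 sub, 7 mul (tensor path),
/// 9 eq, 10 lt, 11 xor, 12 and, 13 not, 14 shl.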
fn emit_formula<const N: usize>(&mut self, order: &Order<N>, formula: NounId) -> Result<(), CompileError> {
let (tag, body) = formula_parts(order, formula)?;
match tag {
0 => self.emit_axis(order, body),
1 => self.emit_quote(order, body),
2 => self.emit_compose(order, body),
4 => self.emit_branch(order, body),
5 => self.emit_add(order, body),
6 => self.emit_sub(order, body),
7 => self.emit_mul_tensor(order, body),
9 => self.emit_eq(order, body),
10 => self.emit_lt(order, body),
11 => self.emit_xor(order, body),
12 => self.emit_and(order, body),
13 => self.emit_not(order, body),
14 => self.emit_shl(order, body),
_ => Err(CompileError::UnsupportedPattern(tag)),
}
}
fn emit_axis<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let addr = atom_u64(order, body)?;
let depth = axis_to_param(addr)?;
if (depth as usize) >= self.subject.len() { return Err(CompileError::NoParams); }
let src = self.subject[depth as usize].clone();
let dst = self.push_reg();
self.emit(&format!("mov.u64 {}, {};", dst, src));
Ok(())
}
fn emit_quote<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let val = atom_u64(order, body)?;
let dst = self.push_reg();
self.emit(&format!("mov.u64 {}, {};", dst, val));
Ok(())
}
fn emit_compose<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
if let Some((loop_body, inits)) = detect_loop_setup(order, body) {
return self.emit_loop(order, loop_body, &inits);
}
if let Some((new_subj, _)) = detect_back_edge(order, body) {
return self.emit_back_edge(order, new_subj);
}
// Let-binding
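// Expected shape: the left formula is a cons (tag 3) of a value formula with
// the identity axis [0 1]; the right formula is a quote (tag 1) wrapping the
// body. The value is evaluated, pushed onto the subject, then the body runs.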
let (a_formula, b_formula) = body_pair(order, body)?;
let (a_tag, a_body) = formula_parts(order, a_formula)?;
if a_tag != 3 { return Err(CompileError::UnsupportedPattern(2)); }
let (value_formula, identity) = body_pair(order, a_body)?;
let (id_tag, id_body) = formula_parts(order, identity)?;
if id_tag != 0 || atom_u64(order, id_body)? != 1 {
return Err(CompileError::UnsupportedPattern(2));
}
let (b_tag, body_formula) = formula_parts(order, b_formula)?;
if b_tag != 1 { return Err(CompileError::UnsupportedPattern(2)); }
self.emit_formula(order, value_formula)?;
let val = self.pop_reg();
self.subject.insert(0, val);
let result = self.emit_formula(order, body_formula);
self.subject.remove(0);
result
}
fn emit_loop<const N: usize>(
&mut self, order: &Order<N>, loop_body: NounId, inits: &[NounId],
) -> Result<(), CompileError> {
let formula_reg = self.alloc_reg();
self.emit(&format!("mov.u64 {}, 0;", formula_reg));
let mut carried = Vec::new();
for &init in inits.iter() {
self.emit_formula(order, init)?;
let val = self.pop_reg();
let cr = self.alloc_reg();
self.emit(&format!("mov.u64 {}, {};", cr, val));
carried.push(cr);
}
let saved = self.subject.clone();
for cr in carried.iter() {
self.subject.insert(0, cr.clone());
}
self.subject.insert(0, formula_reg.clone());
let header = self.alloc_label();
let prev = self.loop_state.take();
self.loop_state = Some(TensorLoopState {
carried: carried.clone(),
formula_reg: formula_reg.clone(),
header_label: header.clone(),
});
self.emit_label(&header);
self.emit_formula(order, loop_body)?;
self.loop_state = prev;
self.subject = saved;
Ok(())
}
fn emit_back_edge<const N: usize>(
&mut self, order: &Order<N>, new_subj: NounId,
) -> Result<(), CompileError> {
let ls = self.loop_state.as_ref()
.ok_or(CompileError::UnsupportedPattern(2))?.clone();
let (tag, cons_body) = formula_parts(order, new_subj)?;
if tag != 3 { return Err(CompileError::UnsupportedPattern(2)); }
let (_, rest) = body_pair(order, cons_body)?;
let mut cur = rest;
let mut new_vals = Vec::new();
for _ in ls.carried.iter() {
let (tag, cb) = formula_parts(order, cur)?;
if tag != 3 { break; }
let (val_formula, tail) = body_pair(order, cb)?;
self.emit_formula(order, val_formula)?;
new_vals.push(self.pop_reg());
cur = tail;
}
for (i, cr) in ls.carried.iter().enumerate() {
if i < new_vals.len() {
self.emit(&format!("mov.u64 {}, {};", cr, new_vals[i]));
}
}
self.emit(&format!("bra {};", ls.header_label));
let _ = self.push_reg(); // dummy
Ok(())
}
fn emit_branch<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let (test, yes, no) = body_triple(order, body)?;
self.emit_formula(order, test)?;
let test_r = self.pop_reg();
let pred = self.alloc_pred();
let lbl_no = self.alloc_label();
let lbl_end = self.alloc_label();
let dst = self.push_reg(); // pushed so the caller can pop the branch result
// nox: 0=yes, nonzero=no
self.emit(&format!("setp.ne.u64 {}, {}, 0;", pred, test_r));
self.emit(&format!("@{} bra {};", pred, lbl_no));
// yes path (test==0)
self.emit_formula(order, yes)?;
let yes_r = self.pop_reg();
self.emit(&format!("mov.u64 {}, {};", dst, yes_r));
self.emit(&format!("bra {};", lbl_end));
// no path
self.emit_label(&lbl_no);
self.emit_formula(order, no)?;
let no_r = self.pop_reg();
self.emit(&format!("mov.u64 {}, {};", dst, no_r));
self.emit_label(&lbl_end);
Ok(())
}
// ── Scalar arithmetic (same as ptx.rs) ─────────────────────────
fn emit_add<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let (a, b) = body_pair(order, body)?;
self.emit_formula(order, a)?;
let ra = self.pop_reg();
self.emit_formula(order, b)?;
let rb = self.pop_reg();
let dst = self.push_reg();
let carry = self.alloc_pred();
let tmp = self.alloc_reg();
self.emit(&format!("add.cc.u64 {}, {}, {};", dst, ra, rb));
self.emit(&format!("setp.lt.u64 {}, {}, {};", carry, dst, ra));
self.emit(&format!("mov.u64 {}, 4294967295;", tmp)); // 0xFFFFFFFF
self.emit(&format!("@{} add.u64 {}, {}, {};", carry, dst, dst, tmp));
self.emit_reduce_mod_p(&dst);
Ok(())
}
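/// Host-side reference for the addition sequence emitted above (a sketch for
/// documentation only; the emitter never calls it). It mirrors the emitted
/// add-with-carry, the 2^64 mod P correction, and the final conditional
/// subtraction, assuming both inputs are already reduced below `P`.
#[allow(dead_code)]
fn add_mod_p_reference(a: u64, b: u64) -> u64 {
    let (mut s, overflow) = a.overflowing_add(b);
    if overflow {
        // 2^64 mod P = 2^32 - 1 = 0xFFFF_FFFF
        s = s.wrapping_add(0xFFFF_FFFF);
    }
    if s >= P {
        s -= P;
    }
    s
}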
fn emit_sub<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let (a, b) = body_pair(order, body)?;
self.emit_formula(order, a)?;
let ra = self.pop_reg();
self.emit_formula(order, b)?;
let rb = self.pop_reg();
let dst = self.push_reg();
let pred = self.alloc_pred();
let lbl = self.alloc_label();
let lbl_end = self.alloc_label();
let tmp = self.alloc_reg();
self.emit(&format!("setp.ge.u64 {}, {}, {};", pred, ra, rb));
self.emit(&format!("@{} bra {};", pred, lbl));
// underflow: P - rb + ra
self.emit(&format!("mov.u64 {}, {};", tmp, P));
self.emit(&format!("sub.u64 {}, {}, {};", dst, tmp, rb));
self.emit(&format!("add.u64 {}, {}, {};", dst, dst, ra));
self.emit(&format!("bra {};", lbl_end));
self.emit_label(&lbl);
self.emit(&format!("sub.u64 {}, {}, {};", dst, ra, rb));
self.emit_label(&lbl_end);
Ok(())
}
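/// Host-side reference for the subtraction sequence emitted above (a sketch
/// for documentation only; the emitter never calls it). On underflow the
/// emitted PTX computes P - b + a, assuming both inputs are below `P`.
#[allow(dead_code)]
fn sub_mod_p_reference(a: u64, b: u64) -> u64 {
    if a >= b { a - b } else { (P - b) + a }
}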
/// Mul via wmma tensor core instructions.
///
/// Strategy: store operands into shared memory as fp16 16x16 tiles,
/// execute wmma.mma for fused multiply-accumulate in hardware tensor
/// cores, read result back as fp32.
///
/// For scalar nox mul (two field elements), the tile is mostly zeros
/// with the operand in [0,0]. The wmma path pays setup cost but
/// enables batched matrix multiplication when the formula operates
/// on matrix-shaped data.
fn emit_mul_tensor<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let (a, b) = body_pair(order, body)?;
self.emit_formula(order, a)?;
let ra = self.pop_reg();
self.emit_formula(order, b)?;
let rb = self.pop_reg();
let dst = self.push_reg();
self.uses_wmma = true;
// Fragment register names
let frag_a = self.alloc_frag("fa");
let frag_b = self.alloc_frag("fb");
let frag_c = self.alloc_frag("fc");
let frag_d = self.alloc_frag("fd");
// Convert u64 operands to f16 (via f32) and store into tile [0,0].
// %f_cvt_* / %h_cvt_* are fixed f32/b16 temporaries declared in
// finish()/finish_parallel() whenever wmma is used.
self.emit(&format!("cvt.rn.f32.u64 %f_cvt_a, {};", ra));
self.emit(&format!("cvt.rn.f32.u64 %f_cvt_b, {};", rb));
self.emit("cvt.rn.f16.f32 %h_cvt_a, %f_cvt_a;");
self.emit("cvt.rn.f16.f32 %h_cvt_b, %f_cvt_b;");
// tile_a: 16x16 f16 (512 bytes); tile_b: 16x16 f16 (512 bytes).
// Only element [0,0] is written here; the remaining tile elements are
// assumed to be zero.
self.emit("st.shared.b16 [tile_a], %h_cvt_a;");
self.emit("st.shared.b16 [tile_b], %h_cvt_b;");
self.emit("bar.sync 0;");
// Load fragments from shared memory
self.emit(&format!(
"wmma.load.a.sync.aligned.m{}n{}k{}.shared.f16 {{{}}}, [tile_a], {};",
WMMA_M, WMMA_N, WMMA_K,
frag_list(&frag_a, FRAG_REGS),
WMMA_K // leading dimension
));
self.emit(&format!(
"wmma.load.b.sync.aligned.m{}n{}k{}.shared.f16 {{{}}}, [tile_b], {};",
WMMA_M, WMMA_N, WMMA_K,
frag_list(&frag_b, FRAG_REGS),
WMMA_N // leading dimension
));
// Zero-init accumulator C
for i in 0..FRAG_REGS {
self.emit(&format!("mov.f32 {}{}, 0f00000000;", frag_c, i));
}
// Tensor core multiply-accumulate: D = A * B + C
self.emit(&format!(
"wmma.mma.sync.aligned.m{}n{}k{}.f32.f16 {{{}}}, {{{}}}, {{{}}}, {{{}}};",
WMMA_M, WMMA_N, WMMA_K,
frag_list(&frag_d, FRAG_REGS),
frag_list(&frag_a, FRAG_REGS),
frag_list(&frag_b, FRAG_REGS),
frag_list(&frag_c, FRAG_REGS),
));
// Store result to shared memory
self.emit(&format!(
"wmma.store.d.sync.aligned.m{}n{}k{}.shared.f32 [tile_d], {{{}}}, {};",
WMMA_M, WMMA_N, WMMA_K,
frag_list(&frag_d, FRAG_REGS),
WMMA_N // leading dimension
));
self.emit("bar.sync 0;");
// Read result[0,0] back as f32, truncate to u64
self.emit("ld.shared.f32 %f_cvt_d, [tile_d];");
self.emit(&format!("cvt.rzi.u64.f32 {}, %f_cvt_d;", dst));
// Reduce mod Goldilocks
self.emit_reduce_mod_p(&dst);
Ok(())
}
fn emit_eq<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let (a, b) = body_pair(order, body)?;
self.emit_formula(order, a)?;
let ra = self.pop_reg();
self.emit_formula(order, b)?;
let rb = self.pop_reg();
let dst = self.push_reg();
let pred = self.alloc_pred();
// nox eq: 0 if equal, 1 if not
self.emit(&format!("setp.ne.u64 {}, {}, {};", pred, ra, rb));
self.emit(&format!("selp.u64 {}, 1, 0, {};", dst, pred));
Ok(())
}
fn emit_lt<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let (a, b) = body_pair(order, body)?;
self.emit_formula(order, a)?;
let ra = self.pop_reg();
self.emit_formula(order, b)?;
let rb = self.pop_reg();
let dst = self.push_reg();
let pred = self.alloc_pred();
// nox lt: 0 if a<b, 1 if a>=b
self.emit(&format!("setp.ge.u64 {}, {}, {};", pred, ra, rb));
self.emit(&format!("selp.u64 {}, 1, 0, {};", dst, pred));
Ok(())
}
fn emit_xor<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let (a, b) = body_pair(order, body)?;
self.emit_formula(order, a)?;
let ra = self.pop_reg();
self.emit_formula(order, b)?;
let rb = self.pop_reg();
let dst = self.push_reg();
self.emit(&format!("xor.b64 {}, {}, {};", dst, ra, rb));
Ok(())
}
fn emit_and<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let (a, b) = body_pair(order, body)?;
self.emit_formula(order, a)?;
let ra = self.pop_reg();
self.emit_formula(order, b)?;
let rb = self.pop_reg();
let dst = self.push_reg();
self.emit(&format!("and.b64 {}, {}, {};", dst, ra, rb));
Ok(())
}
fn emit_not<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
self.emit_formula(order, body)?;
let ra = self.pop_reg();
let dst = self.push_reg();
let tmp = self.alloc_reg();
self.emit(&format!("not.b64 {}, {};", dst, ra));
self.emit(&format!("mov.u64 {}, 4294967295;", tmp));
self.emit(&format!("and.b64 {}, {}, {};", dst, dst, tmp));
Ok(())
}
fn emit_shl<const N: usize>(&mut self, order: &Order<N>, body: NounId) -> Result<(), CompileError> {
let (a, b) = body_pair(order, body)?;
self.emit_formula(order, a)?;
let ra = self.pop_reg();
self.emit_formula(order, b)?;
let rb = self.pop_reg();
let dst = self.push_reg();
let _tmp = self.alloc_reg();
// PTX shl needs 32-bit shift amount
self.emit(&format!("cvt.u32.u64 %r_sh, {};", rb));
self.emit(&format!("shl.b64 {}, {}, %r_sh;", dst, ra));
let mask = self.alloc_reg();
self.emit(&format!("mov.u64 {}, 4294967295;", mask));
self.emit(&format!("and.b64 {}, {}, {};", dst, dst, mask));
Ok(())
}
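/// Reduce `dst` into [0, P) with a single conditional subtraction. Any u64
/// value is below 2·P (since P > 2^63), so one subtraction suffices.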
fn emit_reduce_mod_p(&mut self, dst: &str) {
let pred = self.alloc_pred();
let tmp = self.alloc_reg();
self.emit(&format!("mov.u64 {}, {};", tmp, P));
self.emit(&format!("setp.ge.u64 {}, {}, {};", pred, dst, tmp));
self.emit(&format!("@{} sub.u64 {}, {}, {};", pred, dst, dst, tmp));
}
/// Single evaluation kernel with tensor core support.
fn finish(self, result: &str) -> String {
let mut ptx = String::with_capacity(8192);
ptx.push_str(".version 7.0\n.target sm_80\n.address_size 64\n\n");
// Kernel signature
ptx.push_str(".entry main(\n");
for i in 0..self.num_params {
if i > 0 { ptx.push_str(",\n"); }
ptx.push_str(&format!(" .param .u64 param{}", i));
}
ptx.push_str(",\n .param .u64 result_ptr\n) {\n");
// Register declarations
ptx.push_str(&format!(" .reg .u64 %rd<{}>;\n", self.next_reg + 1));
ptx.push_str(&format!(" .reg .u64 %p<{}>;\n", self.num_params));
ptx.push_str(&format!(" .reg .pred %p_cond<{}>;\n", self.next_pred + 1));
if self.body.contains("%r_sh") {
ptx.push_str(" .reg .u32 %r_sh;\n");
}
// wmma fragment registers: A/B fragments are packed f16x2 pairs (.b32),
// C/D accumulators are f32.
if self.uses_wmma {
for i in 0..self.next_frag {
ptx.push_str(&format!(" .reg .b32 %fa{}<{}>;\n", i, FRAG_REGS));
ptx.push_str(&format!(" .reg .b32 %fb{}<{}>;\n", i, FRAG_REGS));
ptx.push_str(&format!(" .reg .f32 %fc{}<{}>;\n", i, FRAG_REGS));
ptx.push_str(&format!(" .reg .f32 %fd{}<{}>;\n", i, FRAG_REGS));
}
// Conversion temporaries used by the mul tensor path
ptx.push_str(" .reg .f32 %f_cvt_a, %f_cvt_b, %f_cvt_d;\n");
ptx.push_str(" .reg .b16 %h_cvt_a, %h_cvt_b;\n");
// Shared memory for 16x16 tiles
// tile_a: 16x16 f16 = 512 bytes
// tile_b: 16x16 f16 = 512 bytes
// tile_d: 16x16 f32 = 1024 bytes
ptx.push_str(&format!(
" .shared .align 32 .f16 tile_a[{}];\n", WMMA_M * WMMA_K
));
ptx.push_str(&format!(
" .shared .align 32 .f16 tile_b[{}];\n", WMMA_K * WMMA_N
));
ptx.push_str(&format!(
" .shared .align 32 .f32 tile_d[{}];\n", WMMA_M * WMMA_N
));
}
ptx.push('\n');
// Load params
for i in 0..self.num_params {
ptx.push_str(&format!(" ld.param.u64 %p{}, [param{}];\n", i, i));
}
ptx.push('\n');
// Body
ptx.push_str(&self.body);
// Store result
ptx.push_str("\n .reg .u64 %out_ptr;\n");
ptx.push_str(" ld.param.u64 %out_ptr, [result_ptr];\n");
ptx.push_str(&format!(" st.global.u64 [%out_ptr], {};\n", result));
ptx.push_str(" ret;\n}\n");
ptx
}
/// Parallel kernel: one warp (32 threads) per matrix tile.
fn finish_parallel(self, result: &str, num_params: u32) -> String {
let mut ptx = String::with_capacity(8192);
ptx.push_str(".version 7.0\n.target sm_80\n.address_size 64\n\n");
ptx.push_str(".entry main_tensor_parallel(\n");
for i in 0..num_params {
ptx.push_str(&format!(" .param .u64 input{}_ptr,\n", i));
}
ptx.push_str(" .param .u64 output_ptr,\n");
ptx.push_str(" .param .u64 count\n) {\n");
ptx.push_str(&format!(" .reg .u64 %rd<{}>;\n", self.next_reg + 16));
ptx.push_str(&format!(" .reg .u64 %p<{}>;\n", self.num_params));
ptx.push_str(&format!(" .reg .pred %p_cond<{}>;\n", self.next_pred + 2));
ptx.push_str(" .reg .u32 %tid, %warp_id;\n");
ptx.push_str(" .reg .u64 %tid64, %warp64, %cnt, %addr;\n");
if self.body.contains("%r_sh") {
ptx.push_str(" .reg .u32 %r_sh;\n");
}
// wmma fragment registers: A/B fragments are packed f16x2 pairs (.b32),
// C/D accumulators are f32.
if self.uses_wmma {
for i in 0..self.next_frag {
ptx.push_str(&format!(" .reg .b32 %fa{}<{}>;\n", i, FRAG_REGS));
ptx.push_str(&format!(" .reg .b32 %fb{}<{}>;\n", i, FRAG_REGS));
ptx.push_str(&format!(" .reg .f32 %fc{}<{}>;\n", i, FRAG_REGS));
ptx.push_str(&format!(" .reg .f32 %fd{}<{}>;\n", i, FRAG_REGS));
}
// Conversion temporaries used by the mul tensor path
ptx.push_str(" .reg .f32 %f_cvt_a, %f_cvt_b, %f_cvt_d;\n");
ptx.push_str(" .reg .b16 %h_cvt_a, %h_cvt_b;\n");
ptx.push_str(&format!(
" .shared .align 32 .f16 tile_a[{}];\n", WMMA_M * WMMA_K
));
ptx.push_str(&format!(
" .shared .align 32 .f16 tile_b[{}];\n", WMMA_K * WMMA_N
));
ptx.push_str(&format!(
" .shared .align 32 .f32 tile_d[{}];\n", WMMA_M * WMMA_N
));
}
ptx.push('\n');
// Thread ID and warp ID
ptx.push_str(" mov.u32 %tid, %tid.x;\n");
// Warp ID = tid / 32 (each warp handles one tile)
ptx.push_str(" shr.u32 %warp_id, %tid, 5;\n");
ptx.push_str(" cvt.u64.u32 %tid64, %tid;\n");
ptx.push_str(" cvt.u64.u32 %warp64, %warp_id;\n");
// Bounds check on warp count
ptx.push_str(" ld.param.u64 %cnt, [count];\n");
ptx.push_str(" setp.ge.u64 %p_cond0, %warp64, %cnt;\n");
ptx.push_str(" @%p_cond0 ret;\n\n");
// Load params: input[i] = input_ptr[warp_id]
for i in 0..num_params {
ptx.push_str(&format!(" ld.param.u64 %addr, [input{}_ptr];\n", i));
ptx.push_str(" mad.lo.u64 %addr, %warp64, 8, %addr;\n");
ptx.push_str(&format!(" ld.global.u64 %p{}, [%addr];\n", i));
}
ptx.push('\n');
// Formula body
ptx.push_str(&self.body);
// Store result[warp_id]
ptx.push_str("\n ld.param.u64 %addr, [output_ptr];\n");
ptx.push_str(" mad.lo.u64 %addr, %warp64, 8, %addr;\n");
ptx.push_str(&format!(" st.global.u64 [%addr], {};\n", result));
ptx.push_str(" ret;\n}\n");
ptx
}
}
/// Generate a comma-separated fragment register list: "base0, base1, ..., base{count-1}"
fn frag_list(base: &str, count: u32) -> String {
(0..count)
.map(|i| format!("{}{}", base, i))
.collect::<Vec<_>>()
.join(", ")
}
#[cfg(test)]
mod tests {
use super::*;
use nox::noun::{Order, Tag};
use nebu::Goldilocks;
fn g(v: u64) -> Goldilocks { Goldilocks::new(v) }
fn make_cell<const N: usize>(order: &mut Order<N>, left: NounId, right: NounId) -> NounId {
order.cell(left, right).unwrap()
}
fn make_atom<const N: usize>(order: &mut Order<N>, v: u64) -> NounId {
order.atom(g(v), Tag::Field).unwrap()
}
fn make_formula<const N: usize>(order: &mut Order<N>, tag: u64, body: NounId) -> NounId {
let t = make_atom(order, tag);
make_cell(order, t, body)
}
fn make_binary<const N: usize>(order: &mut Order<N>, tag: u64, left: NounId, right: NounId) -> NounId {
let body = make_cell(order, left, right);
make_formula(order, tag, body)
}
#[test]
fn tensor_quote_compiles() {
let mut o = Order::<1024>::new();
let body = make_atom(&mut o, 42);
let formula = make_formula(&mut o, 1, body);
let ptx = compile_to_tensor_ptx(&o, formula, 0).unwrap();
assert!(ptx.contains(".target sm_80"));
assert!(ptx.contains("mov.u64"));
}
#[test]
fn tensor_add_compiles() {
let mut o = Order::<1024>::new();
let ax2 = make_atom(&mut o, 2);
let a0 = make_formula(&mut o, 0, ax2);
let ax6 = make_atom(&mut o, 6);
let a1 = make_formula(&mut o, 0, ax6);
let formula = make_binary(&mut o, 5, a0, a1);
let ptx = compile_to_tensor_ptx(&o, formula, 2).unwrap();
assert!(ptx.contains(".target sm_80"));
assert!(ptx.contains("add.cc.u64"));
// Add should not use wmma
assert!(!ptx.contains("wmma"));
}
#[test]
fn tensor_mul_uses_wmma() {
let mut o = Order::<1024>::new();
let ax2 = make_atom(&mut o, 2);
let a0 = make_formula(&mut o, 0, ax2);
let ax6 = make_atom(&mut o, 6);
let a1 = make_formula(&mut o, 0, ax6);
let formula = make_binary(&mut o, 7, a0, a1);
let ptx = compile_to_tensor_ptx(&o, formula, 2).unwrap();
assert!(ptx.contains(".target sm_80"));
assert!(ptx.contains("wmma.load.a.sync.aligned.m16n16k16"));
assert!(ptx.contains("wmma.load.b.sync.aligned.m16n16k16"));
assert!(ptx.contains("wmma.mma.sync.aligned.m16n16k16.f32.f16"));
assert!(ptx.contains("wmma.store.d.sync.aligned.m16n16k16"));
assert!(ptx.contains("tile_a"));
assert!(ptx.contains("tile_b"));
assert!(ptx.contains("tile_d"));
}
#[test]
fn tensor_parallel_mul_compiles() {
let mut o = Order::<1024>::new();
let ax2 = make_atom(&mut o, 2);
let a0 = make_formula(&mut o, 0, ax2);
let ax6 = make_atom(&mut o, 6);
let a1 = make_formula(&mut o, 0, ax6);
let formula = make_binary(&mut o, 7, a0, a1);
let ptx = compile_to_tensor_ptx_parallel(&o, formula, 2).unwrap();
assert!(ptx.contains("main_tensor_parallel"));
assert!(ptx.contains(".target sm_80"));
assert!(ptx.contains("wmma.mma.sync.aligned"));
assert!(ptx.contains("%warp_id"));
}
#[test]
fn tensor_branch_compiles() {
let mut o = Order::<1024>::new();
let v0 = make_atom(&mut o, 0);
let test = make_formula(&mut o, 1, v0);
let v10 = make_atom(&mut o, 10);
let yes = make_formula(&mut o, 1, v10);
let v20 = make_atom(&mut o, 20);
let no = make_formula(&mut o, 1, v20);
let yes_no = make_cell(&mut o, yes, no);
let body = make_cell(&mut o, test, yes_no);
let formula = make_formula(&mut o, 4, body);
let ptx = compile_to_tensor_ptx(&o, formula, 0).unwrap();
assert!(ptx.contains("setp.ne.u64"));
assert!(ptx.contains("bra"));
}
#[test]
fn tensor_sm80_target() {
let mut o = Order::<1024>::new();
let body = make_atom(&mut o, 1);
let formula = make_formula(&mut o, 1, body);
let ptx = compile_to_tensor_ptx(&o, formula, 0).unwrap();
// This backend pins the target to sm_80 (Ampere)
assert!(ptx.contains(".target sm_80"));
// It must not fall back to an older target such as sm_70
assert!(!ptx.contains(".target sm_70"));
}
#[test]
fn unsupported_pattern_errors() {
let mut o = Order::<1024>::new();
let a1 = make_atom(&mut o, 1);
let body = make_formula(&mut o, 0, a1);
let formula = make_formula(&mut o, 15, body);
assert!(matches!(
compile_to_tensor_ptx(&o, formula, 1),
Err(CompileError::UnsupportedPattern(15))
));
}
}