use super::asm::{amx_op, OP_FMA16, OP_LDX, OP_LDY, OP_LDZ, OP_STZ};
use super::regs::{XRow, YRow};
/// Encode an AMX FMA16 operand whose result OVERWRITES Z (first K-step)
/// instead of accumulating into it.
///
/// NOTE(review): assumed operand layout — Y byte offset in bits 0..8,
/// X byte offset in bits 10..18, Z-row parity in bit 20, and bit 27 as the
/// "overwrite Z" flag (matches the `_first` naming) — confirm against the
/// AMX ISA notes.
const fn fma16_first(xr: XRow, yr: YRow, tile: u8) -> u64 {
    let z_sel = ((tile & 1) as u64) << 20;
    (xr.byte_offset() << 10) | yr.byte_offset() | z_sel | (1 << 27)
}
/// Encode an AMX FMA16 operand that ACCUMULATES into Z (`Z += X*Y` form):
/// identical field layout to `fma16_first`, but without the bit-27 flag.
///
/// NOTE(review): field positions (Y bits 0..8, X bits 10..18, Z parity
/// bit 20) are assumed from usage — confirm against the AMX ISA notes.
const fn fma16_acc(xr: XRow, yr: YRow, tile: u8) -> u64 {
    let z_sel = ((tile & 1) as u64) << 20;
    (xr.byte_offset() << 10) | yr.byte_offset() | z_sel
}
#[inline]
/// Load a 32x32 f16 C tile from memory into the AMX Z registers.
///
/// `ldc` is the row stride of `c` in ELEMENTS (u16); each row is 64 bytes.
/// Z rows are consumed interleaved: even rows for `tile` 0, odd for
/// `tile` 1 (only the low bit of `tile` is used).
///
/// NOTE(review): row-select is assumed to live in operand bits 56..61 of
/// OP_LDZ — confirm against the AMX ISA notes.
pub unsafe fn preload_c_f16(c: *const u16, ldc: usize, tile: u8) {
    let base = c as *const u8;
    let parity = (tile & 1) as u64;
    let row_bytes = ldc * 2; // stride converted from u16 elements to bytes
    for row in 0usize..32 {
        let z_row = (row as u64) * 2 + parity;
        let addr = base.add(row * row_bytes) as u64;
        amx_op::<OP_LDZ>(addr | (z_row << 56));
    }
}
#[inline]
/// Store the 32x32 f16 tile held in the AMX Z registers out to memory.
///
/// Mirror of `preload_c_f16`: `ldc` is the row stride of `c` in ELEMENTS
/// (u16), and Z rows are read interleaved — even rows for `tile` 0, odd
/// for `tile` 1 (only the low bit of `tile` is used).
///
/// NOTE(review): row-select is assumed to live in operand bits 56..61 of
/// OP_STZ — confirm against the AMX ISA notes.
pub unsafe fn store_c_f16(c: *mut u16, ldc: usize, tile: u8) {
    let base = c as *mut u8;
    let parity = (tile & 1) as u64;
    let row_bytes = ldc * 2; // stride converted from u16 elements to bytes
    for row in 0usize..32 {
        let z_row = (row as u64) * 2 + parity;
        let addr = base.add(row * row_bytes) as u64;
        amx_op::<OP_STZ>(addr | (z_row << 56));
    }
}
#[inline]
pub unsafe fn microkernel_32x32_f16(a_panel: *const u8, b_panel: *const u8, k: usize) {
let mut first = true;
let mut p = 0usize;
while p + 8 <= k {
for i in 0u8..8 {
amx_op::<OP_LDX>((b_panel.add((p + i as usize) * 64) as u64) | ((i as u64) << 56));
amx_op::<OP_LDY>((a_panel.add((p + i as usize) * 64) as u64) | ((i as u64) << 56));
}
if first {
amx_op::<OP_FMA16>(fma16_first(
XRow::new_unchecked(0),
YRow::new_unchecked(0),
0,
));
first = false;
} else {
amx_op::<OP_FMA16>(fma16_acc(XRow::new_unchecked(0), YRow::new_unchecked(0), 0));
}
for i in 1u8..8 {
amx_op::<OP_FMA16>(fma16_acc(XRow::new_unchecked(i), YRow::new_unchecked(i), 0));
}
p += 8;
}
while p < k {
amx_op::<OP_LDX>(b_panel.add(p * 64) as u64);
amx_op::<OP_LDY>(a_panel.add(p * 64) as u64);
if first {
amx_op::<OP_FMA16>(fma16_first(
XRow::new_unchecked(0),
YRow::new_unchecked(0),
0,
));
first = false;
} else {
amx_op::<OP_FMA16>(fma16_acc(XRow::new_unchecked(0), YRow::new_unchecked(0), 0));
}
p += 1;
}
}
#[inline]
/// Spill the 32x32 f16 tile from AMX Z (tile 0) and ADD it, widened to
/// f32, into the column-major-by-row region at `c` with row stride `ldc`
/// (in f32 elements).
///
/// # Safety
/// Caller must guarantee AMX holds the tile, `c` is valid for 32 rows of
/// 32 f32 read-writes at stride `ldc`, and the target is aarch64 (the
/// NEON asm below is unconditional).
pub unsafe fn accumulate_tile_f16_to_f32(c: *mut f32, ldc: usize) {
    // Scratch row-major 32x32 f16 buffer; Z is stored here first so the
    // widening add can be done with ordinary NEON loads/stores.
    let mut tmp = [0u16; 32 * 32];
    store_c_f16(tmp.as_mut_ptr(), 32, 0);
    for j in 0..32 {
        let dst = c.add(j * ldc);
        let src = tmp.as_ptr().add(j * 32);
        let mut i = 0;
        // 8 f16 lanes per iteration: ldr q0 reads 8 u16; fcvtl/fcvtl2
        // widen the low/high halves to 2x4 f32; ldp/fadd/stp read-modify-
        // write 8 f32 at dst. 32 % 8 == 0, so the row is fully covered.
        while i + 8 <= 32 {
            core::arch::asm!(
                "ldr q0, [{src}]", "fcvtl v1.4s, v0.4h", "fcvtl2 v2.4s, v0.8h", "ldp q3, q4, [{dst}]", "fadd v3.4s, v3.4s, v1.4s",
                "fadd v4.4s, v4.4s, v2.4s",
                "stp q3, q4, [{dst}]", src = in(reg) src.add(i),
                dst = in(reg) dst.add(i),
                // v0-v4 are scratch; declared as outputs so rustc treats
                // them as clobbered.
                out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _,
            );
            i += 8;
        }
    }
}
#[inline]
/// Spill the 32x32 f16 tile from AMX Z (tile 0) and STORE it, widened to
/// f32, to the region at `c` with row stride `ldc` (in f32 elements) —
/// overwrite variant of `accumulate_tile_f16_to_f32` (no read of `c`).
///
/// # Safety
/// Caller must guarantee AMX holds the tile, `c` is valid for 32 rows of
/// 32 f32 writes at stride `ldc`, and the target is aarch64 (the NEON
/// asm below is unconditional).
pub unsafe fn store_tile_f16_to_f32(c: *mut f32, ldc: usize) {
    // Scratch row-major 32x32 f16 buffer; Z is stored here first so the
    // widening can be done with ordinary NEON loads/stores.
    let mut tmp = [0u16; 32 * 32];
    store_c_f16(tmp.as_mut_ptr(), 32, 0);
    for j in 0..32 {
        let dst = c.add(j * ldc);
        let src = tmp.as_ptr().add(j * 32);
        let mut i = 0;
        // 8 f16 lanes per iteration: ldr q0 reads 8 u16; fcvtl/fcvtl2
        // widen low/high halves to 2x4 f32; stp writes 8 f32 to dst.
        // 32 % 8 == 0, so the row is fully covered.
        while i + 8 <= 32 {
            core::arch::asm!(
                "ldr q0, [{src}]",
                "fcvtl v1.4s, v0.4h",
                "fcvtl2 v2.4s, v0.8h",
                "stp q1, q2, [{dst}]",
                src = in(reg) src.add(i),
                dst = in(reg) dst.add(i),
                // v0-v2 are scratch; declared as outputs so rustc treats
                // them as clobbered.
                out("v0") _, out("v1") _, out("v2") _,
            );
            i += 8;
        }
    }
}