use super::asm::{
self, OP_FMA16, OP_FMA32, OP_LDX, OP_LDY, OP_LDZ, OP_LDZI, OP_MAC16, OP_STX, OP_STY, OP_STZ,
};
use super::regs::{XRow, YRow, ZRow};
use super::Matrix;
/// Packs a memory address and a register row index into the 64-bit operand
/// passed to the AMX load/store operations in this file.
///
/// Layout produced:
/// - bits 0..=55: the low 56 bits of `ptr`
/// - bits 56..=61: `row` (6 bits)
/// - bits 62..=63: always clear
///
/// Both fields are masked to their widths so that stray high bits cannot bleed
/// into neighbouring fields. Examples of stray bits are a tagged or otherwise
/// non-canonical pointer, or an out-of-range row.
/// NOTE(review): per the public AMX reverse-engineering notes, bit 62 requests
/// a multi-register transfer that touches more memory than a single row. Keeping
/// it clear here is therefore a memory-safety measure. Confirm against the
/// reference for the targeted chip.
///
/// In debug builds, a `row` that does not fit in 6 bits triggers a panic.
#[inline(always)]
fn load_store_operand(ptr: *const u8, row: u8) -> u64 {
    const PTR_MASK: u64 = (1 << 56) - 1;
    const ROW_MASK: u64 = 0x3F;
    debug_assert!(row < 64, "AMX row index {} does not fit in 6 bits", row);
    ((ptr as u64) & PTR_MASK) | (((row as u64) & ROW_MASK) << 56)
}
impl Matrix {
#[inline]
pub unsafe fn ldx(&self, src: *const u8, row: XRow) {
let operand = load_store_operand(src, row.index());
asm::amx_op::<OP_LDX>(operand);
}
#[inline]
pub unsafe fn ldy(&self, src: *const u8, row: YRow) {
let operand = load_store_operand(src, row.index());
asm::amx_op::<OP_LDY>(operand);
}
#[inline]
pub unsafe fn ldz(&self, src: *const u8, row: ZRow) {
let operand = load_store_operand(src, row.index());
asm::amx_op::<OP_LDZ>(operand);
}
#[inline]
pub unsafe fn ldzi(&self, src: *const u8, row: ZRow) {
let operand = load_store_operand(src, row.index());
asm::amx_op::<OP_LDZI>(operand);
}
#[inline]
pub unsafe fn stx(&self, dst: *mut u8, row: XRow) {
let operand = load_store_operand(dst, row.index());
asm::amx_op::<OP_STX>(operand);
}
#[inline]
pub unsafe fn sty(&self, dst: *mut u8, row: YRow) {
let operand = load_store_operand(dst, row.index());
asm::amx_op::<OP_STY>(operand);
}
#[inline]
pub unsafe fn stz(&self, dst: *mut u8, row: ZRow) {
let operand = load_store_operand(dst, row.index());
asm::amx_op::<OP_STZ>(operand);
}
#[inline]
pub unsafe fn fma32(&self, operand: u64) {
asm::amx_op::<OP_FMA32>(operand);
}
#[inline]
pub unsafe fn fma16(&self, operand: u64) {
asm::amx_op::<OP_FMA16>(operand);
}
#[inline]
pub unsafe fn mac16(&self, operand: u64) {
asm::amx_op::<OP_MAC16>(operand);
}
}