use super::regs::{XRow, YRow};
/// Builder for a 64-bit FMA operand word.
///
/// The value is a thin wrapper around the raw bit pattern: each builder
/// method rewrites one bitfield of `bits` and `build` hands the finished
/// word back as a plain `u64`.
///
/// Equality and hashing compare the raw bit pattern, so two builders are
/// equal exactly when they would `build` the same operand.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct FmaOp {
    /// Raw operand bits accumulated so far.
    bits: u64,
}
impl FmaOp {
    /// Mask of the Y-row byte-offset field (bits 0..=8).
    const Y_MASK: u64 = 0x1FF;
    /// Bit position of the X-row byte-offset field.
    const X_SHIFT: u32 = 10;
    /// Mask of the X-row byte-offset field (bits 10..=18).
    const X_MASK: u64 = 0x1FF << Self::X_SHIFT;
    /// Bit position of the Z tile selector field.
    const Z_SHIFT: u32 = 20;
    /// Mask of the Z selector field (bits 20..=25).
    const Z_MASK: u64 = 0x3F << Self::Z_SHIFT;
    /// Flag bit set by [`FmaOp::no_accumulate`] and cleared by [`FmaOp::accumulate`].
    const NO_ACCUMULATE_BIT: u64 = 1 << 27;
    /// Flag bit set by [`FmaOp::vector_mode`].
    const VECTOR_MODE_BIT: u64 = 1 << 63;

    /// Creates a builder with every bit cleared.
    ///
    /// With all bits zero the no-accumulate flag (bit 27) and the vector-mode
    /// flag (bit 63) are both off and every offset/selector field is 0.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        Self { bits: 0 }
    }

    /// Replaces the X field (bits 10..=18) with `row.byte_offset()`.
    ///
    /// The offset is shifted into place without masking, so this relies on
    /// the offset fitting in 9 bits.
    // NOTE(review): presumably `XRow` guarantees that range — confirm in `regs`.
    #[inline]
    #[must_use = "builder methods return the updated value; the receiver is not modified"]
    pub const fn x(mut self, row: XRow) -> Self {
        self.bits = (self.bits & !Self::X_MASK) | (row.byte_offset() << Self::X_SHIFT);
        self
    }

    /// Replaces the Y field (bits 0..=8) with `row.byte_offset()`.
    ///
    /// The offset is OR-ed in without masking, so this relies on the offset
    /// fitting in 9 bits.
    // NOTE(review): presumably `YRow` guarantees that range — confirm in `regs`.
    #[inline]
    #[must_use = "builder methods return the updated value; the receiver is not modified"]
    pub const fn y(mut self, row: YRow) -> Self {
        self.bits = (self.bits & !Self::Y_MASK) | row.byte_offset();
        self
    }

    /// Selects the Z tile.
    ///
    /// Only the low two bits of `tile` are used (values 0..=3); higher bits
    /// are silently discarded. The whole 6-bit field at bits 20..=25 is
    /// cleared before the tile index is written at bit 20.
    #[inline]
    #[must_use = "builder methods return the updated value; the receiver is not modified"]
    pub const fn z_tile(mut self, tile: u8) -> Self {
        let index = (tile & 3) as u64;
        self.bits = (self.bits & !Self::Z_MASK) | (index << Self::Z_SHIFT);
        self
    }

    /// Sets the no-accumulate flag (bit 27).
    // NOTE(review): presumably this makes the operation overwrite Z instead
    // of adding to it — confirm against the instruction reference.
    #[inline]
    #[must_use = "builder methods return the updated value; the receiver is not modified"]
    pub const fn no_accumulate(mut self) -> Self {
        self.bits |= Self::NO_ACCUMULATE_BIT;
        self
    }

    /// Clears the no-accumulate flag (bit 27), undoing [`FmaOp::no_accumulate`].
    ///
    /// This is already the state of a fresh [`FmaOp::new`] builder.
    #[inline]
    #[must_use = "builder methods return the updated value; the receiver is not modified"]
    pub const fn accumulate(mut self) -> Self {
        self.bits &= !Self::NO_ACCUMULATE_BIT;
        self
    }

    /// Sets the vector-mode flag (bit 63).
    #[inline]
    #[must_use = "builder methods return the updated value; the receiver is not modified"]
    pub const fn vector_mode(mut self) -> Self {
        self.bits |= Self::VECTOR_MODE_BIT;
        self
    }

    /// Returns the finished operand word.
    #[inline]
    #[must_use]
    pub const fn build(self) -> u64 {
        self.bits
    }
}
impl Default for FmaOp {
fn default() -> Self {
Self::new()
}
}
/// Encodes an FMA operand for rows `xr` / `yr` and Z tile `tile` with the
/// no-accumulate flag set.
///
/// Differs from [`fma_acc`] only in that `no_accumulate` is applied before
/// the word is built.
#[inline]
pub const fn fma_first(xr: XRow, yr: YRow, tile: u8) -> u64 {
    let operand = FmaOp::new().x(xr).y(yr).z_tile(tile);
    operand.no_accumulate().build()
}
/// Encodes an FMA operand for rows `xr` / `yr` and Z tile `tile`, leaving the
/// no-accumulate flag clear (the state of a fresh `FmaOp::new()` builder).
#[inline]
pub const fn fma_acc(xr: XRow, yr: YRow, tile: u8) -> u64 {
    let operand = FmaOp::new()
        .x(xr)
        .y(yr)
        .z_tile(tile);
    operand.build()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Flag bit expected to be set by `fma_first` and clear in `fma_acc`.
    const NO_ACCUMULATE: u64 = 1 << 27;
    /// X byte-offset field in place (bits 10..=18).
    const X_FIELD: u64 = 0x7FC00;
    /// Width mask shared by the 9-bit X and Y offset fields.
    const OFFSET_MASK: u64 = 0x1FF;

    /// `fma_first` on row 0 / row 0 sets the flag and leaves both offsets zero.
    #[test]
    fn operand_first_rank1() {
        // SAFETY: NOTE(review) — row index 0 is assumed to satisfy the
        // `new_unchecked` contracts; confirm against `regs`.
        let xr = unsafe { XRow::new_unchecked(0) };
        let yr = unsafe { YRow::new_unchecked(0) };

        let op = fma_first(xr, yr, 0);

        assert_ne!(op & NO_ACCUMULATE, 0);
        assert_eq!(op & X_FIELD, 0);
        assert_eq!(op & OFFSET_MASK, 0);
    }

    /// `fma_acc` leaves the flag clear and encodes rows 3 / 5 as byte
    /// offsets 192 / 320 in the X / Y fields.
    #[test]
    fn operand_acc_xn_yn() {
        // SAFETY: NOTE(review) — row indices 3 and 5 are assumed to satisfy
        // the `new_unchecked` contracts; confirm against `regs`.
        let xr = unsafe { XRow::new_unchecked(3) };
        let yr = unsafe { YRow::new_unchecked(5) };

        let op = fma_acc(xr, yr, 0);

        assert_eq!(op & NO_ACCUMULATE, 0);
        assert_eq!((op >> 10) & OFFSET_MASK, 192);
        assert_eq!(op & OFFSET_MASK, 320);
    }

    /// The tile index lands at bit 20 of the built word.
    #[test]
    fn operand_tile_select() {
        let builder = FmaOp::new().z_tile(2);
        let op = builder.build();
        assert_eq!((op >> 20) & 0x3F, 2);
    }
}