#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
// ln(2) and friends for exp/log range reduction.
const LN2: f32 = core::f32::consts::LN_2;
const LN2_INV: f32 = 1.0 / LN2;
// Cody–Waite split: LN2_HI + LN2_LO ≈ ln(2) with extra effective precision,
// so `x - n*LN2_HI - n*LN2_LO` loses less accuracy than a single `x - n*LN2`.
const LN2_HI: f32 = 0.693_145_75;
const LN2_LO: f32 = 1.428_606_8e-6;
// Degree-6 Taylor coefficients of exp(r) for |r| <= ln(2)/2 (scalar path).
const EXP_P0: f32 = 1.0;
const EXP_P1: f32 = 1.0;
const EXP_P2: f32 = 0.5;
const EXP_P3: f32 = 0.166_666_7;
const EXP_P4: f32 = 0.041_666_67;
const EXP_P5: f32 = 0.008_333_34;
const EXP_P6: f32 = 0.001_388_89;
// Degree-5 fitted coefficients used by the NEON exp kernels (paired with
// EXP_P0/EXP_P1 for the constant and linear terms).
const EXP_C5: f32 = 0.008_371_6;
const EXP_C4: f32 = 0.041_675_6;
const EXP_C3: f32 = 0.166_665_6;
const EXP_C2: f32 = 0.500_000_1;
// Input clamp bounds: beyond these exp(x) overflows/underflows f32.
const EXP_HI: f32 = 88.376_26;
const EXP_LO: f32 = -87.336_55;
// sqrt(2/pi) and the cubic coefficient of the tanh-approximation GELU.
const SQRT_2_OVER_PI: f32 = 0.797_884_6;
const GELU_COEFF: f32 = 0.044_715;

/// Scalar exp(x) approximation.
///
/// Clamps x to [EXP_LO, EXP_HI], reduces x = n*ln(2) + r with |r| <= ~ln(2)/2,
/// evaluates a degree-6 Taylor polynomial for exp(r) (Horner), then scales by
/// 2^n assembled directly in the IEEE-754 exponent bits.
#[inline(always)]
pub(crate) fn exp_scalar(x: f32) -> f32 {
    let x = x.clamp(EXP_LO, EXP_HI);
    // n = round-half-up(x / ln2).
    let n = (x * LN2_INV + 0.5).floor();
    // Cody–Waite two-step subtraction keeps r accurate.
    let r = x - n * LN2_HI - n * LN2_LO;
    // Horner evaluation of exp(r).
    let mut p = EXP_P6;
    p = p * r + EXP_P5;
    p = p * r + EXP_P4;
    p = p * r + EXP_P3;
    p = p * r + EXP_P2;
    p = p * r + EXP_P1;
    p = p * r + EXP_P0;
    // Scale by 2^n. At the upper clamp, n can round to 128, which does not
    // fit a finite f32 exponent ((128+127)<<23 is the +inf bit pattern), so
    // split 2^n = 2^h * 2^(n-h); each half stays in the finite exponent
    // range, and multiplying left-to-right keeps intermediates finite.
    let ni = n as i32;
    let half = ni >> 1; // arithmetic shift: also correct for negative n
    let rest = ni - half;
    let scale_a = f32::from_bits(((half + 127) as u32) << 23);
    let scale_b = f32::from_bits(((rest + 127) as u32) << 23);
    p * scale_a * scale_b
}
/// Scalar ln(x) approximation.
///
/// Decomposes x = m * 2^e via the raw bits (m in [1, 2)), evaluates the
/// degree-7 Maclaurin series of ln(1 + f) with f = m - 1, and returns
/// e*ln(2) + ln(m).
///
/// Domain: zero AND negative inputs both return -inf (matching the original
/// contract used by the slice-level fallbacks); NaN returns NaN; +inf
/// returns +inf.
#[inline(always)]
fn log_scalar(x: f32) -> f32 {
    if x <= 0.0 {
        // log(0) = -inf; negative inputs also map to -inf by convention here
        // (the mathematical result would be NaN).
        return f32::NEG_INFINITY;
    }
    if !x.is_finite() {
        // log(+inf) = +inf, log(NaN) = NaN. Without this guard the bit
        // manipulation below would turn both into bogus finite values.
        return x;
    }
    // NOTE(review): subnormal inputs are not renormalized and come out
    // inaccurate; callers appear to pass normal values — confirm.
    let bits = x.to_bits();
    // Unbiased exponent and mantissa forced into [1, 2).
    let e = ((bits >> 23) & 0xFF) as i32 - 127;
    let m_bits = (bits & 0x007F_FFFF) | 0x3F80_0000;
    let m = f32::from_bits(m_bits);
    let exp_adj = e as f32;
    // ln(1 + f) ~= f - f^2/2 + f^3/3 - ... + f^7/7.
    let f = m - 1.0;
    let f2 = f * f;
    let f3 = f2 * f;
    let f4 = f2 * f2;
    let f5 = f4 * f;
    let f6 = f4 * f2;
    let f7 = f4 * f3;
    let r = f - f2 * 0.5 + f3 * (1.0 / 3.0) - f4 * 0.25 + f5 * 0.2 - f6 * (1.0 / 6.0)
        + f7 * (1.0 / 7.0);
    // ln(x) = e*ln(2) + ln(m). (LN_2 is the same value as the file's LN2.)
    exp_adj * core::f32::consts::LN_2 + r
}
/// Scalar tanh via the identity tanh(x) = (e^{2x} - 1) / (e^{2x} + 1).
///
/// Inputs with |x| > 9 saturate to ±1 directly: tanh(±9) already rounds to
/// ±1 at f32 precision, and skipping the exp avoids needless work.
#[inline(always)]
fn tanh_scalar(x: f32) -> f32 {
    let saturated = x.abs() > 9.0;
    if saturated {
        x.signum()
    } else {
        let e = exp_scalar(2.0 * x);
        (e - 1.0) / (e + 1.0)
    }
}
#[inline(always)]
fn sigmoid_scalar(x: f32) -> f32 {
1.0 / (1.0 + exp_scalar(-x))
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
// 4-lane NEON exp. Same scheme as `exp_scalar` — clamp, reduce x = n*ln2 + r,
// polynomial in r, scale by 2^n — but with the degree-5 EXP_C* fit evaluated
// in Estrin form and a single-constant n*LN2 reduction instead of the
// Cody–Waite split, so it is marginally less accurate than the scalar path.
// NOTE(review): at the very top of the clamp range n can round to 128, and
// (128 + 127) << 23 is the +inf bit pattern, so those lanes return +inf even
// though the true value is finite; confirm callers tolerate that.
pub(crate) unsafe fn exp_neon(x: float32x4_t) -> float32x4_t {
    // Clamp into the representable range.
    let x = vmaxq_f32(vminq_f32(x, vdupq_n_f32(EXP_HI)), vdupq_n_f32(EXP_LO));
    // n = round-to-nearest(x / ln2).
    let n = vrndnq_f32(vmulq_f32(x, vdupq_n_f32(LN2_INV)));
    // r = x - n*ln2 (fused multiply-subtract).
    let r = vfmsq_f32(x, n, vdupq_n_f32(LN2));
    let r2 = vmulq_f32(r, r);
    let r4 = vmulq_f32(r2, r2);
    // Estrin: (P0 + P1*r) + (C2 + C3*r)*r^2 + (C4 + C5*r)*r^4.
    let p01 = vfmaq_f32(vdupq_n_f32(EXP_P0), vdupq_n_f32(EXP_P1), r);
    let p23 = vfmaq_f32(vdupq_n_f32(EXP_C2), vdupq_n_f32(EXP_C3), r);
    let p45 = vfmaq_f32(vdupq_n_f32(EXP_C4), vdupq_n_f32(EXP_C5), r);
    let p = vfmaq_f32(vfmaq_f32(p01, p23, r2), p45, r4);
    // 2^n assembled in the IEEE-754 exponent field: (n + 127) << 23.
    let ni = vcvtq_s32_f32(n);
    let pow2n = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(ni, vdupq_n_s32(127))));
    vmulq_f32(p, pow2n)
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
// 4-lane NEON natural log, mirroring `log_scalar`'s normal-input path:
// split x = m * 2^e via the raw bits, evaluate the degree-7 ln(1+f) series
// in Horner form over f = m - 1, then add e*ln(2).
// NOTE(review): unlike `log_scalar` there is no guard for x <= 0, NaN, inf
// or subnormals — such lanes produce unspecified values (the exponent shift
// is arithmetic, so the sign bit smears into e). Callers must pre-filter.
unsafe fn log_neon(x: float32x4_t) -> float32x4_t {
    let one = vdupq_n_f32(1.0);
    let ln2_v = vdupq_n_f32(LN2);
    // Unbiased exponent e = (bits >> 23) - 127 (valid for positive normals).
    let bits = vreinterpretq_s32_f32(x);
    let e_raw = vsubq_s32(vshrq_n_s32::<23>(bits), vdupq_n_s32(127));
    // Mantissa forced into [1, 2) by rewriting the exponent field.
    let m_bits = vorrq_s32(
        vandq_s32(bits, vdupq_n_s32(0x007F_FFFFu32 as i32)),
        vdupq_n_s32(0x3F80_0000u32 as i32),
    );
    let m = vreinterpretq_f32_s32(m_bits);
    let f = vsubq_f32(m, one);
    // Horner from the f^7/7 coefficient down; final multiply by f yields
    // f - f^2/2 + f^3/3 - ... + f^7/7.
    let mut p = vdupq_n_f32(1.0 / 7.0);
    p = vfmaq_f32(vdupq_n_f32(-1.0 / 6.0), p, f);
    p = vfmaq_f32(vdupq_n_f32(0.2), p, f);
    p = vfmaq_f32(vdupq_n_f32(-0.25), p, f);
    p = vfmaq_f32(vdupq_n_f32(1.0 / 3.0), p, f);
    p = vfmaq_f32(vdupq_n_f32(-0.5), p, f);
    p = vfmaq_f32(one, p, f);
    let r = vmulq_f32(p, f);
    // ln(x) = ln(m) + e*ln(2).
    let ef = vcvtq_f32_s32(e_raw);
    vfmaq_f32(r, ef, ln2_v)
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
// 4x-unrolled variant of `exp_neon` (16 lanes per call). The arithmetic is
// identical — clamp, n = round(x/ln2), r = x - n*ln2, Estrin polynomial,
// scale by 2^n via exponent-bit assembly — but the four vectors are
// interleaved so independent FMA chains can overlap in the pipeline.
unsafe fn exp_neon_x4(
    x0: float32x4_t,
    x1: float32x4_t,
    x2: float32x4_t,
    x3: float32x4_t,
) -> (float32x4_t, float32x4_t, float32x4_t, float32x4_t) {
    // Clamp all four vectors into the representable range.
    let lo = vdupq_n_f32(EXP_LO);
    let hi = vdupq_n_f32(EXP_HI);
    let x0 = vmaxq_f32(vminq_f32(x0, hi), lo);
    let x1 = vmaxq_f32(vminq_f32(x1, hi), lo);
    let x2 = vmaxq_f32(vminq_f32(x2, hi), lo);
    let x3 = vmaxq_f32(vminq_f32(x3, hi), lo);
    // n = round-to-nearest(x / ln2).
    let inv_ln2 = vdupq_n_f32(LN2_INV);
    let n0 = vrndnq_f32(vmulq_f32(x0, inv_ln2));
    let n1 = vrndnq_f32(vmulq_f32(x1, inv_ln2));
    let n2 = vrndnq_f32(vmulq_f32(x2, inv_ln2));
    let n3 = vrndnq_f32(vmulq_f32(x3, inv_ln2));
    // r = x - n*ln2.
    let ln2 = vdupq_n_f32(LN2);
    let r0 = vfmsq_f32(x0, n0, ln2);
    let r1 = vfmsq_f32(x1, n1, ln2);
    let r2 = vfmsq_f32(x2, n2, ln2);
    let r3 = vfmsq_f32(x3, n3, ln2);
    // Polynomial coefficients, splatted once.
    let c5 = vdupq_n_f32(EXP_C5);
    let c4 = vdupq_n_f32(EXP_C4);
    let c3 = vdupq_n_f32(EXP_C3);
    let c2 = vdupq_n_f32(EXP_C2);
    let c1 = vdupq_n_f32(EXP_P1);
    let c0 = vdupq_n_f32(EXP_P0);
    // r^2 and r^4 powers for the Estrin evaluation.
    let r2_0 = vmulq_f32(r0, r0);
    let r2_1 = vmulq_f32(r1, r1);
    let r2_2 = vmulq_f32(r2, r2);
    let r2_3 = vmulq_f32(r3, r3);
    let r4_0 = vmulq_f32(r2_0, r2_0);
    let r4_1 = vmulq_f32(r2_1, r2_1);
    let r4_2 = vmulq_f32(r2_2, r2_2);
    let r4_3 = vmulq_f32(r2_3, r2_3);
    // Estrin pairs: (c0 + c1*r), (c2 + c3*r), (c4 + c5*r).
    let p01_0 = vfmaq_f32(c0, c1, r0);
    let p01_1 = vfmaq_f32(c0, c1, r1);
    let p01_2 = vfmaq_f32(c0, c1, r2);
    let p01_3 = vfmaq_f32(c0, c1, r3);
    let p23_0 = vfmaq_f32(c2, c3, r0);
    let p23_1 = vfmaq_f32(c2, c3, r1);
    let p23_2 = vfmaq_f32(c2, c3, r2);
    let p23_3 = vfmaq_f32(c2, c3, r3);
    let p45_0 = vfmaq_f32(c4, c5, r0);
    let p45_1 = vfmaq_f32(c4, c5, r1);
    let p45_2 = vfmaq_f32(c4, c5, r2);
    let p45_3 = vfmaq_f32(c4, c5, r3);
    // Combine: p = p01 + p23*r^2 + p45*r^4.
    let p0 = vfmaq_f32(vfmaq_f32(p01_0, p23_0, r2_0), p45_0, r4_0);
    let p1 = vfmaq_f32(vfmaq_f32(p01_1, p23_1, r2_1), p45_1, r4_1);
    let p2 = vfmaq_f32(vfmaq_f32(p01_2, p23_2, r2_2), p45_2, r4_2);
    let p3 = vfmaq_f32(vfmaq_f32(p01_3, p23_3, r2_3), p45_3, r4_3);
    // 2^n per vector via (n + 127) << 23 in the exponent field.
    let bias = vdupq_n_s32(127);
    let pow0 = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(vcvtq_s32_f32(n0), bias)));
    let pow1 = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(vcvtq_s32_f32(n1), bias)));
    let pow2 = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(vcvtq_s32_f32(n2), bias)));
    let pow3 = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(vcvtq_s32_f32(n3), bias)));
    (
        vmulq_f32(p0, pow0),
        vmulq_f32(p1, pow1),
        vmulq_f32(p2, pow2),
        vmulq_f32(p3, pow3),
    )
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
// 4x-unrolled variant of `log_neon` (16 lanes per call): identical
// arithmetic, with the four vectors interleaved at each Horner step so
// the dependent FMA chains of different vectors can overlap.
// Same caveat as `log_neon`: positive normal inputs only.
unsafe fn log_neon_x4(
    x0: float32x4_t,
    x1: float32x4_t,
    x2: float32x4_t,
    x3: float32x4_t,
) -> (float32x4_t, float32x4_t, float32x4_t, float32x4_t) {
    let one = vdupq_n_f32(1.0);
    let ln2_v = vdupq_n_f32(LN2);
    // Mantissa mask, exponent field for 1.0, and the exponent bias.
    let mask = vdupq_n_s32(0x007F_FFFFu32 as i32);
    let base = vdupq_n_s32(0x3F80_0000u32 as i32);
    let bias = vdupq_n_s32(127);
    let b0 = vreinterpretq_s32_f32(x0);
    let b1 = vreinterpretq_s32_f32(x1);
    let b2 = vreinterpretq_s32_f32(x2);
    let b3 = vreinterpretq_s32_f32(x3);
    // Unbiased exponents.
    let e0 = vsubq_s32(vshrq_n_s32::<23>(b0), bias);
    let e1 = vsubq_s32(vshrq_n_s32::<23>(b1), bias);
    let e2 = vsubq_s32(vshrq_n_s32::<23>(b2), bias);
    let e3 = vsubq_s32(vshrq_n_s32::<23>(b3), bias);
    // f = m - 1 with the mantissa forced into [1, 2).
    let f0 = vsubq_f32(
        vreinterpretq_f32_s32(vorrq_s32(vandq_s32(b0, mask), base)),
        one,
    );
    let f1 = vsubq_f32(
        vreinterpretq_f32_s32(vorrq_s32(vandq_s32(b1, mask), base)),
        one,
    );
    let f2 = vsubq_f32(
        vreinterpretq_f32_s32(vorrq_s32(vandq_s32(b2, mask), base)),
        one,
    );
    let f3 = vsubq_f32(
        vreinterpretq_f32_s32(vorrq_s32(vandq_s32(b3, mask), base)),
        one,
    );
    // ln(1+f) series coefficients, f^7/7 term first (Horner order).
    let c7 = vdupq_n_f32(1.0 / 7.0);
    let c6 = vdupq_n_f32(-1.0 / 6.0);
    let c5 = vdupq_n_f32(0.2);
    let c4 = vdupq_n_f32(-0.25);
    let c3 = vdupq_n_f32(1.0 / 3.0);
    let c2 = vdupq_n_f32(-0.5);
    let mut p0 = c7;
    let mut p1 = c7;
    let mut p2 = c7;
    let mut p3 = c7;
    p0 = vfmaq_f32(c6, p0, f0);
    p1 = vfmaq_f32(c6, p1, f1);
    p2 = vfmaq_f32(c6, p2, f2);
    p3 = vfmaq_f32(c6, p3, f3);
    p0 = vfmaq_f32(c5, p0, f0);
    p1 = vfmaq_f32(c5, p1, f1);
    p2 = vfmaq_f32(c5, p2, f2);
    p3 = vfmaq_f32(c5, p3, f3);
    p0 = vfmaq_f32(c4, p0, f0);
    p1 = vfmaq_f32(c4, p1, f1);
    p2 = vfmaq_f32(c4, p2, f2);
    p3 = vfmaq_f32(c4, p3, f3);
    p0 = vfmaq_f32(c3, p0, f0);
    p1 = vfmaq_f32(c3, p1, f1);
    p2 = vfmaq_f32(c3, p2, f2);
    p3 = vfmaq_f32(c3, p3, f3);
    p0 = vfmaq_f32(c2, p0, f0);
    p1 = vfmaq_f32(c2, p1, f1);
    p2 = vfmaq_f32(c2, p2, f2);
    p3 = vfmaq_f32(c2, p3, f3);
    p0 = vfmaq_f32(one, p0, f0);
    p1 = vfmaq_f32(one, p1, f1);
    p2 = vfmaq_f32(one, p2, f2);
    p3 = vfmaq_f32(one, p3, f3);
    // ln(m) = p * f.
    let r0 = vmulq_f32(p0, f0);
    let r1 = vmulq_f32(p1, f1);
    let r2 = vmulq_f32(p2, f2);
    let r3 = vmulq_f32(p3, f3);
    // ln(x) = ln(m) + e*ln(2).
    (
        vfmaq_f32(r0, vcvtq_f32_s32(e0), ln2_v),
        vfmaq_f32(r1, vcvtq_f32_s32(e1), ln2_v),
        vfmaq_f32(r2, vcvtq_f32_s32(e2), ln2_v),
        vfmaq_f32(r3, vcvtq_f32_s32(e3), ln2_v),
    )
}
/// Computes exp element-wise, in place.
///
/// On aarch64 the bulk runs on NEON (16 lanes per iteration via
/// `exp_neon_x4`, then 4 at a time via `exp_neon`); the remainder — and the
/// whole slice on other targets — falls back to `exp_scalar`.
pub fn exp(x: &mut [f32]) {
    let len = x.len();
    let mut i = 0;
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: every load/store is at offset < len; the loop conditions
        // guarantee i + 16 <= len (resp. i + 4 <= len) before each access.
        unsafe {
            let ptr = x.as_mut_ptr();
            while i + 16 <= len {
                let (r0, r1, r2, r3) = exp_neon_x4(
                    vld1q_f32(ptr.add(i)),
                    vld1q_f32(ptr.add(i + 4)),
                    vld1q_f32(ptr.add(i + 8)),
                    vld1q_f32(ptr.add(i + 12)),
                );
                vst1q_f32(ptr.add(i), r0);
                vst1q_f32(ptr.add(i + 4), r1);
                vst1q_f32(ptr.add(i + 8), r2);
                vst1q_f32(ptr.add(i + 12), r3);
                i += 16;
            }
            while i + 4 <= len {
                vst1q_f32(ptr.add(i), exp_neon(vld1q_f32(ptr.add(i))));
                i += 4;
            }
        }
    }
    // Scalar tail (and full path on non-aarch64 targets).
    while i < len {
        x[i] = exp_scalar(x[i]);
        i += 1;
    }
}
/// Computes `exp(src[i])` into `dst[i]` for `i < min(src.len(), dst.len())`;
/// extra elements of the longer slice are left untouched.
///
/// On aarch64 the largest multiple-of-16 prefix is handed to the external
/// assembly kernel `exp_kern::exp_asm`, the next 4-wide chunks use
/// `exp_neon`, and the remainder uses `exp_scalar`.
/// NOTE(review): assumes `exp_asm` computes the same approximation as the
/// intrinsic/scalar paths — confirm against `exp_kern`.
pub fn exp_to(src: &[f32], dst: &mut [f32]) {
    let len = src.len().min(dst.len());
    let mut i = 0;
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: bulk <= len, and the 4-wide loop checks i + 4 <= len, so
        // all accesses stay within both slices.
        unsafe {
            let inp = src.as_ptr();
            let out = dst.as_mut_ptr();
            // Round len down to a multiple of 16 for the asm kernel.
            let bulk = len & !15;
            if bulk > 0 {
                super::exp_kern::exp_asm(out, inp, bulk);
                i = bulk;
            }
            while i + 4 <= len {
                vst1q_f32(out.add(i), exp_neon(vld1q_f32(inp.add(i))));
                i += 4;
            }
        }
    }
    // Scalar tail (and full path on non-aarch64 targets).
    while i < len {
        dst[i] = exp_scalar(src[i]);
        i += 1;
    }
}
/// Computes `ln(src[i])` into `dst[i]` for `i < min(src.len(), dst.len())`.
///
/// aarch64 bulk uses `log_neon_x4` / `log_neon`; the tail (and other
/// targets) use `log_scalar`. The NEON paths have no domain guards, so
/// non-positive / non-finite lanes in the vectorized region yield
/// unspecified values, while the scalar tail maps x <= 0 to -inf.
pub fn log_to(src: &[f32], dst: &mut [f32]) {
    let len = src.len().min(dst.len());
    let mut i = 0;
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: loop conditions bound every access by len for both slices.
        unsafe {
            let inp = src.as_ptr();
            let out = dst.as_mut_ptr();
            while i + 16 <= len {
                let (r0, r1, r2, r3) = log_neon_x4(
                    vld1q_f32(inp.add(i)),
                    vld1q_f32(inp.add(i + 4)),
                    vld1q_f32(inp.add(i + 8)),
                    vld1q_f32(inp.add(i + 12)),
                );
                vst1q_f32(out.add(i), r0);
                vst1q_f32(out.add(i + 4), r1);
                vst1q_f32(out.add(i + 8), r2);
                vst1q_f32(out.add(i + 12), r3);
                i += 16;
            }
            while i + 4 <= len {
                vst1q_f32(out.add(i), log_neon(vld1q_f32(inp.add(i))));
                i += 4;
            }
        }
    }
    // Scalar tail (and full path on non-aarch64 targets).
    while i < len {
        dst[i] = log_scalar(src[i]);
        i += 1;
    }
}
/// Computes ln element-wise, in place. Same structure and domain caveats as
/// `log_to`: NEON bulk on aarch64, `log_scalar` tail/fallback.
pub fn log(x: &mut [f32]) {
    let len = x.len();
    let mut i = 0;
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: loop conditions bound every access by len.
        unsafe {
            let ptr = x.as_mut_ptr();
            while i + 16 <= len {
                let (r0, r1, r2, r3) = log_neon_x4(
                    vld1q_f32(ptr.add(i)),
                    vld1q_f32(ptr.add(i + 4)),
                    vld1q_f32(ptr.add(i + 8)),
                    vld1q_f32(ptr.add(i + 12)),
                );
                vst1q_f32(ptr.add(i), r0);
                vst1q_f32(ptr.add(i + 4), r1);
                vst1q_f32(ptr.add(i + 8), r2);
                vst1q_f32(ptr.add(i + 12), r3);
                i += 16;
            }
            while i + 4 <= len {
                vst1q_f32(ptr.add(i), log_neon(vld1q_f32(ptr.add(i))));
                i += 4;
            }
        }
    }
    // Scalar tail (and full path on non-aarch64 targets).
    while i < len {
        x[i] = log_scalar(x[i]);
        i += 1;
    }
}
/// Computes tanh element-wise, in place, via (e^{2x} - 1) / (e^{2x} + 1).
///
/// Inputs are clamped to [-9, 9] before the exp-based formula (like the
/// scalar path's saturation: tanh is ±1 to f32 precision beyond that, and
/// the clamp keeps 2x inside `exp_neon`'s accurate range).
pub fn tanh(x: &mut [f32]) {
    let len = x.len();
    let mut i = 0;
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: loop conditions bound every access by len.
        unsafe {
            let two = vdupq_n_f32(2.0);
            let one = vdupq_n_f32(1.0);
            let clamp_pos = vdupq_n_f32(9.0);
            let clamp_neg = vdupq_n_f32(-9.0);
            let ptr = x.as_mut_ptr();
            while i + 16 <= len {
                // Clamp, then tanh(v) = (e^{2v} - 1) / (e^{2v} + 1) per vector.
                let v0 = vmaxq_f32(vminq_f32(vld1q_f32(ptr.add(i)), clamp_pos), clamp_neg);
                let v1 = vmaxq_f32(vminq_f32(vld1q_f32(ptr.add(i + 4)), clamp_pos), clamp_neg);
                let v2 = vmaxq_f32(vminq_f32(vld1q_f32(ptr.add(i + 8)), clamp_pos), clamp_neg);
                let v3 = vmaxq_f32(vminq_f32(vld1q_f32(ptr.add(i + 12)), clamp_pos), clamp_neg);
                let e0 = exp_neon(vmulq_f32(two, v0));
                let e1 = exp_neon(vmulq_f32(two, v1));
                let e2 = exp_neon(vmulq_f32(two, v2));
                let e3 = exp_neon(vmulq_f32(two, v3));
                vst1q_f32(
                    ptr.add(i),
                    vdivq_f32(vsubq_f32(e0, one), vaddq_f32(e0, one)),
                );
                vst1q_f32(
                    ptr.add(i + 4),
                    vdivq_f32(vsubq_f32(e1, one), vaddq_f32(e1, one)),
                );
                vst1q_f32(
                    ptr.add(i + 8),
                    vdivq_f32(vsubq_f32(e2, one), vaddq_f32(e2, one)),
                );
                vst1q_f32(
                    ptr.add(i + 12),
                    vdivq_f32(vsubq_f32(e3, one), vaddq_f32(e3, one)),
                );
                i += 16;
            }
            while i + 4 <= len {
                let v = vmaxq_f32(vminq_f32(vld1q_f32(ptr.add(i)), clamp_pos), clamp_neg);
                let e2x = exp_neon(vmulq_f32(two, v));
                vst1q_f32(
                    ptr.add(i),
                    vdivq_f32(vsubq_f32(e2x, one), vaddq_f32(e2x, one)),
                );
                i += 4;
            }
        }
    }
    // Scalar tail (and full path on non-aarch64 targets).
    while i < len {
        x[i] = tanh_scalar(x[i]);
        i += 1;
    }
}
/// Computes the logistic sigmoid 1 / (1 + e^{-x}) element-wise, in place.
/// NEON bulk on aarch64, `sigmoid_scalar` tail/fallback.
pub fn sigmoid(x: &mut [f32]) {
    let len = x.len();
    let mut i = 0;
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: loop conditions bound every access by len.
        unsafe {
            let one = vdupq_n_f32(1.0);
            let ptr = x.as_mut_ptr();
            while i + 16 <= len {
                // e = exp(-x), result = 1 / (1 + e).
                let e0 = exp_neon(vnegq_f32(vld1q_f32(ptr.add(i))));
                let e1 = exp_neon(vnegq_f32(vld1q_f32(ptr.add(i + 4))));
                let e2 = exp_neon(vnegq_f32(vld1q_f32(ptr.add(i + 8))));
                let e3 = exp_neon(vnegq_f32(vld1q_f32(ptr.add(i + 12))));
                vst1q_f32(ptr.add(i), vdivq_f32(one, vaddq_f32(one, e0)));
                vst1q_f32(ptr.add(i + 4), vdivq_f32(one, vaddq_f32(one, e1)));
                vst1q_f32(ptr.add(i + 8), vdivq_f32(one, vaddq_f32(one, e2)));
                vst1q_f32(ptr.add(i + 12), vdivq_f32(one, vaddq_f32(one, e3)));
                i += 16;
            }
            while i + 4 <= len {
                let e = exp_neon(vnegq_f32(vld1q_f32(ptr.add(i))));
                vst1q_f32(ptr.add(i), vdivq_f32(one, vaddq_f32(one, e)));
                i += 4;
            }
        }
    }
    // Scalar tail (and full path on non-aarch64 targets).
    while i < len {
        x[i] = sigmoid_scalar(x[i]);
        i += 1;
    }
}
/// In-place tanh-approximation GELU:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// The tanh argument is clamped to [-9, 9] and tanh is evaluated via
/// exp(2t), matching the `tanh` kernel above. NEON bulk on aarch64, scalar
/// tail/fallback.
pub fn gelu(x: &mut [f32]) {
    let len = x.len();
    let mut i = 0;
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: loop conditions bound every access by len.
        unsafe {
            let half = vdupq_n_f32(0.5);
            let one = vdupq_n_f32(1.0);
            let two = vdupq_n_f32(2.0);
            let coeff = vdupq_n_f32(GELU_COEFF);
            let s2pi = vdupq_n_f32(SQRT_2_OVER_PI);
            let clamp_pos = vdupq_n_f32(9.0);
            let clamp_neg = vdupq_n_f32(-9.0);
            let ptr = x.as_mut_ptr();
            while i + 16 <= len {
                let v0 = vld1q_f32(ptr.add(i));
                let v1 = vld1q_f32(ptr.add(i + 4));
                let v2 = vld1q_f32(ptr.add(i + 8));
                let v3 = vld1q_f32(ptr.add(i + 12));
                // inner = clamp(sqrt(2/pi) * (v + coeff * v^3), -9, 9).
                let i0 = vmaxq_f32(
                    vminq_f32(
                        vmulq_f32(s2pi, vfmaq_f32(v0, coeff, vmulq_f32(vmulq_f32(v0, v0), v0))),
                        clamp_pos,
                    ),
                    clamp_neg,
                );
                let i1 = vmaxq_f32(
                    vminq_f32(
                        vmulq_f32(s2pi, vfmaq_f32(v1, coeff, vmulq_f32(vmulq_f32(v1, v1), v1))),
                        clamp_pos,
                    ),
                    clamp_neg,
                );
                let i2 = vmaxq_f32(
                    vminq_f32(
                        vmulq_f32(s2pi, vfmaq_f32(v2, coeff, vmulq_f32(vmulq_f32(v2, v2), v2))),
                        clamp_pos,
                    ),
                    clamp_neg,
                );
                let i3 = vmaxq_f32(
                    vminq_f32(
                        vmulq_f32(s2pi, vfmaq_f32(v3, coeff, vmulq_f32(vmulq_f32(v3, v3), v3))),
                        clamp_pos,
                    ),
                    clamp_neg,
                );
                // t = tanh(inner) via (e^{2*inner} - 1) / (e^{2*inner} + 1).
                let e0 = exp_neon(vmulq_f32(two, i0));
                let e1 = exp_neon(vmulq_f32(two, i1));
                let e2 = exp_neon(vmulq_f32(two, i2));
                let e3 = exp_neon(vmulq_f32(two, i3));
                let t0 = vdivq_f32(vsubq_f32(e0, one), vaddq_f32(e0, one));
                let t1 = vdivq_f32(vsubq_f32(e1, one), vaddq_f32(e1, one));
                let t2 = vdivq_f32(vsubq_f32(e2, one), vaddq_f32(e2, one));
                let t3 = vdivq_f32(vsubq_f32(e3, one), vaddq_f32(e3, one));
                // result = 0.5 * v * (1 + t).
                vst1q_f32(
                    ptr.add(i),
                    vmulq_f32(half, vmulq_f32(v0, vaddq_f32(one, t0))),
                );
                vst1q_f32(
                    ptr.add(i + 4),
                    vmulq_f32(half, vmulq_f32(v1, vaddq_f32(one, t1))),
                );
                vst1q_f32(
                    ptr.add(i + 8),
                    vmulq_f32(half, vmulq_f32(v2, vaddq_f32(one, t2))),
                );
                vst1q_f32(
                    ptr.add(i + 12),
                    vmulq_f32(half, vmulq_f32(v3, vaddq_f32(one, t3))),
                );
                i += 16;
            }
            while i + 4 <= len {
                let v = vld1q_f32(ptr.add(i));
                let x3 = vmulq_f32(vmulq_f32(v, v), v);
                let inner = vmulq_f32(s2pi, vfmaq_f32(v, coeff, x3));
                let inner_c = vmaxq_f32(vminq_f32(inner, clamp_pos), clamp_neg);
                let e2 = exp_neon(vmulq_f32(two, inner_c));
                let th = vdivq_f32(vsubq_f32(e2, one), vaddq_f32(e2, one));
                vst1q_f32(
                    ptr.add(i),
                    vmulq_f32(half, vmulq_f32(v, vaddq_f32(one, th))),
                );
                i += 4;
            }
        }
    }
    // Scalar tail (and full path on non-aarch64 targets).
    while i < len {
        let xi = x[i];
        let inner = SQRT_2_OVER_PI * (xi + GELU_COEFF * xi * xi * xi);
        x[i] = 0.5 * xi * (1.0 + tanh_scalar(inner));
        i += 1;
    }
}
/// Computes SiLU (a.k.a. swish), `x * sigmoid(x)`, element-wise in place.
/// NEON bulk on aarch64, scalar tail/fallback.
pub fn silu(x: &mut [f32]) {
    let len = x.len();
    let mut i = 0;
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: loop conditions bound every access by len.
        unsafe {
            let one = vdupq_n_f32(1.0);
            let ptr = x.as_mut_ptr();
            while i + 16 <= len {
                let v0 = vld1q_f32(ptr.add(i));
                let v1 = vld1q_f32(ptr.add(i + 4));
                let v2 = vld1q_f32(ptr.add(i + 8));
                let v3 = vld1q_f32(ptr.add(i + 12));
                // result = v * (1 / (1 + e^{-v})).
                let e0 = exp_neon(vnegq_f32(v0));
                let e1 = exp_neon(vnegq_f32(v1));
                let e2 = exp_neon(vnegq_f32(v2));
                let e3 = exp_neon(vnegq_f32(v3));
                vst1q_f32(
                    ptr.add(i),
                    vmulq_f32(v0, vdivq_f32(one, vaddq_f32(one, e0))),
                );
                vst1q_f32(
                    ptr.add(i + 4),
                    vmulq_f32(v1, vdivq_f32(one, vaddq_f32(one, e1))),
                );
                vst1q_f32(
                    ptr.add(i + 8),
                    vmulq_f32(v2, vdivq_f32(one, vaddq_f32(one, e2))),
                );
                vst1q_f32(
                    ptr.add(i + 12),
                    vmulq_f32(v3, vdivq_f32(one, vaddq_f32(one, e3))),
                );
                i += 16;
            }
            while i + 4 <= len {
                let v = vld1q_f32(ptr.add(i));
                let e = exp_neon(vnegq_f32(v));
                vst1q_f32(ptr.add(i), vmulq_f32(v, vdivq_f32(one, vaddq_f32(one, e))));
                i += 4;
            }
        }
    }
    // Scalar tail (and full path on non-aarch64 targets).
    while i < len {
        let xi = x[i];
        x[i] = xi * sigmoid_scalar(xi);
        i += 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-difference comparison helper.
    fn approx_eq(a: f32, b: f32, tol: f32) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_exp_basic() {
        let mut v = vec![0.0, 1.0, -1.0, 2.0];
        exp(&mut v);
        assert!(approx_eq(v[0], 1.0, 1e-5));
        assert!(approx_eq(v[1], std::f32::consts::E, 1e-4));
        assert!(approx_eq(v[2], 1.0 / std::f32::consts::E, 1e-4));
        // Previously unchecked last element: exp(2) = e^2.
        assert!(approx_eq(
            v[3],
            std::f32::consts::E * std::f32::consts::E,
            1e-3
        ));
    }

    #[test]
    fn test_exp_log_roundtrip() {
        // log(exp(x)) should recover x within the approximation error of
        // both kernels.
        let orig = vec![0.5f32, 1.5, 2.5, 4.0];
        let mut v = orig.clone();
        exp(&mut v);
        log(&mut v);
        for (got, want) in v.iter().zip(orig.iter()) {
            assert!(approx_eq(*got, *want, 1e-2));
        }
    }

    #[test]
    fn test_sigmoid_bounds() {
        let mut v = vec![-10.0, 0.0, 10.0];
        sigmoid(&mut v);
        assert!(v[0] < 0.01);
        assert!(approx_eq(v[1], 0.5, 1e-5));
        assert!(v[2] > 0.99);
    }

    #[test]
    fn test_tanh_bounds() {
        let mut v = vec![-20.0, 0.0, 20.0];
        tanh(&mut v);
        assert!(approx_eq(v[0], -1.0, 1e-5));
        assert!(approx_eq(v[1], 0.0, 1e-5));
        assert!(approx_eq(v[2], 1.0, 1e-5));
    }

    #[test]
    fn test_log_basic() {
        let mut v = vec![1.0, std::f32::consts::E, 0.5];
        log(&mut v);
        assert!(approx_eq(v[0], 0.0, 1e-5));
        assert!(approx_eq(v[1], 1.0, 1e-3));
        assert!(approx_eq(v[2], -(2.0f32.ln()), 1e-3));
    }

    #[test]
    fn test_gelu_zero() {
        let mut v = vec![0.0];
        gelu(&mut v);
        assert!(approx_eq(v[0], 0.0, 1e-6));
    }

    #[test]
    fn test_gelu_one() {
        // tanh-approximation GELU at 1.0 is ~0.8412.
        let mut v = vec![1.0];
        gelu(&mut v);
        assert!(approx_eq(v[0], 0.841_2, 1e-2));
    }

    #[test]
    fn test_silu_zero() {
        let mut v = vec![0.0];
        silu(&mut v);
        assert!(approx_eq(v[0], 0.0, 1e-6));
    }

    #[test]
    fn test_silu_one() {
        // silu(1) = 1 * sigmoid(1) = ~0.731059.
        let mut v = vec![1.0];
        silu(&mut v);
        assert!(approx_eq(v[0], 0.731_058_6, 1e-3));
    }
}