pub mod scalar {
pub mod f32 {
type T = f32;
const N: usize = 1;
type Pack = [T; N];
#[inline(always)]
unsafe fn splat(value: T) -> Pack {
[value]
}
#[inline(always)]
unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
[lhs[0] * rhs[0]]
}
#[inline(always)]
unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
[lhs[0] + rhs[0]]
}
#[inline(always)]
unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
add(mul(a, b), c)
}
#[inline(always)]
pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
lhs * rhs
}
#[inline(always)]
pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
lhs + rhs
}
#[inline(always)]
pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
a * b + c
}
microkernel!(, 2, x1x1, 1, 1);
microkernel!(, 2, x1x2, 1, 2);
microkernel!(, 2, x1x3, 1, 3);
microkernel!(, 2, x1x4, 1, 4);
microkernel!(, 2, x2x1, 2, 1);
microkernel!(, 2, x2x2, 2, 2);
microkernel!(, 2, x2x3, 2, 3);
microkernel!(, 2, x2x4, 2, 4);
microkernel_fn_array! {
[x1x1, x1x2, x1x3, x1x4,],
[x2x1, x2x2, x2x3, x2x4,],
}
pub const H_M: usize = 0;
pub const H_N: usize = 0;
pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub mod fma {
pub mod f32 {
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use core::mem::transmute;
use gemm_common::pulp::u32x8;
type T = f32;
const N: usize = 8;
type Pack = [T; N];
#[inline(always)]
unsafe fn splat(value: T) -> Pack {
transmute(_mm256_set1_ps(value))
}
#[inline(always)]
unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
transmute(_mm256_mul_ps(transmute(lhs), transmute(rhs)))
}
#[inline(always)]
unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
transmute(_mm256_add_ps(transmute(lhs), transmute(rhs)))
}
#[inline(always)]
unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
transmute(_mm256_fmadd_ps(transmute(a), transmute(b), transmute(c)))
}
#[inline(always)]
pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
lhs * rhs
}
#[inline(always)]
pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
lhs + rhs
}
#[inline(always)]
pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
gemm_common::simd::v3_fmaf(a, b, c)
}
static U32_MASKS: [u32x8; 9] = [
u32x8(0, 0, 0, 0, 0, 0, 0, 0),
u32x8(!0, 0, 0, 0, 0, 0, 0, 0),
u32x8(!0, !0, 0, 0, 0, 0, 0, 0),
u32x8(!0, !0, !0, 0, 0, 0, 0, 0),
u32x8(!0, !0, !0, !0, 0, 0, 0, 0),
u32x8(!0, !0, !0, !0, !0, 0, 0, 0),
u32x8(!0, !0, !0, !0, !0, !0, 0, 0),
u32x8(!0, !0, !0, !0, !0, !0, !0, 0),
u32x8(!0, !0, !0, !0, !0, !0, !0, !0),
];
#[inline(always)]
pub unsafe fn partial_load(ptr: *const T, len: usize) -> [T; N] {
transmute(_mm256_maskload_ps(
ptr,
transmute(*(U32_MASKS.as_ptr().add(len))),
))
}
#[inline(always)]
pub unsafe fn reduce_sum(x: [T; N]) -> T {
let x: __m256 = transmute(x);
let x = _mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps::<1>(x));
let hi = _mm_movehl_ps(x, x);
let r0 = _mm_add_ps(x, hi);
let r0_shuffled = _mm_shuffle_ps::<0b0001>(r0, r0);
let r = _mm_add_ss(r0, r0_shuffled);
_mm_cvtss_f32(r)
}
microkernel!(["fma"], 2, x1x1, 1, 1);
microkernel!(["fma"], 2, x1x2, 1, 2);
microkernel!(["fma"], 2, x1x3, 1, 3);
microkernel!(["fma"], 2, x1x4, 1, 4);
microkernel!(["fma"], 2, x1x5, 1, 5);
microkernel!(["fma"], 2, x1x6, 1, 6);
microkernel!(["fma"], 2, x2x1, 2, 1);
microkernel!(["fma"], 2, x2x2, 2, 2);
microkernel!(["fma"], 2, x2x3, 2, 3);
microkernel!(["fma"], 2, x2x4, 2, 4);
microkernel!(["fma"], 2, x2x5, 2, 5);
microkernel!(["fma"], 2, x2x6, 2, 6);
microkernel_fn_array! {
[x1x1, x1x2, x1x3, x1x4, x1x5, x1x6,],
[x2x1, x2x2, x2x3, x2x4, x2x5, x2x6,],
}
horizontal_kernel!(["fma"], hx1x1, 1, 1);
horizontal_kernel!(["fma"], hx1x2, 1, 2);
horizontal_kernel!(["fma"], hx2x1, 2, 1);
horizontal_kernel!(["fma"], hx2x2, 2, 2);
hmicrokernel_fn_array! {
[hx1x1, hx1x2,],
[hx2x1, hx2x2,],
}
}
}
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
pub mod avx512f {
pub mod f32 {
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use core::mem::transmute;
type T = f32;
const N: usize = 16;
type Pack = [T; N];
#[inline(always)]
unsafe fn splat(value: T) -> Pack {
transmute(_mm512_set1_ps(value))
}
#[inline(always)]
unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
transmute(_mm512_mul_ps(transmute(lhs), transmute(rhs)))
}
#[inline(always)]
unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
transmute(_mm512_add_ps(transmute(lhs), transmute(rhs)))
}
#[inline(always)]
unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
transmute(_mm512_fmadd_ps(transmute(a), transmute(b), transmute(c)))
}
#[inline(always)]
pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
lhs * rhs
}
#[inline(always)]
pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
lhs + rhs
}
#[inline(always)]
pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
gemm_common::simd::v3_fmaf(a, b, c)
}
static U32_MASKS: [u16; 17] = [
0b0000000000000000,
0b0000000000000001,
0b0000000000000011,
0b0000000000000111,
0b0000000000001111,
0b0000000000011111,
0b0000000000111111,
0b0000000001111111,
0b0000000011111111,
0b0000000111111111,
0b0000001111111111,
0b0000011111111111,
0b0000111111111111,
0b0001111111111111,
0b0011111111111111,
0b0111111111111111,
0b1111111111111111,
];
#[inline(always)]
pub unsafe fn partial_load(ptr: *const T, len: usize) -> [T; N] {
transmute(_mm512_maskz_loadu_ps(
transmute(*(U32_MASKS.as_ptr().add(len))),
ptr,
))
}
#[inline(always)]
pub unsafe fn reduce_sum(x: [T; N]) -> T {
let x = transmute(x);
let x = _mm256_add_ps(
_mm512_castps512_ps256(x),
transmute(_mm512_extractf64x4_pd::<1>(transmute(x))),
);
let x = _mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps::<1>(x));
let hi = _mm_movehl_ps(x, x);
let r0 = _mm_add_ps(x, hi);
let r0_shuffled = _mm_shuffle_ps::<0b0001>(r0, r0);
let r = _mm_add_ss(r0, r0_shuffled);
_mm_cvtss_f32(r)
}
microkernel!(["avx512f"], 4, x1x1, 1, 1);
microkernel!(["avx512f"], 4, x1x2, 1, 2);
microkernel!(["avx512f"], 4, x1x3, 1, 3);
microkernel!(["avx512f"], 4, x1x4, 1, 4);
microkernel!(["avx512f"], 4, x1x5, 1, 5);
microkernel!(["avx512f"], 4, x1x6, 1, 6);
microkernel!(["avx512f"], 4, x2x1, 2, 1);
microkernel!(["avx512f"], 4, x2x2, 2, 2);
microkernel!(["avx512f"], 4, x2x3, 2, 3);
microkernel!(["avx512f"], 4, x2x4, 2, 4);
microkernel!(["avx512f"], 4, x2x5, 2, 5);
microkernel!(["avx512f"], 4, x2x6, 2, 6);
microkernel!(["avx512f"], 4, x3x1, 3, 1);
microkernel!(["avx512f"], 4, x3x2, 3, 2);
microkernel!(["avx512f"], 4, x3x3, 3, 3);
microkernel!(["avx512f"], 4, x3x4, 3, 4);
microkernel!(["avx512f"], 4, x3x5, 3, 5);
microkernel!(["avx512f"], 4, x3x6, 3, 6);
microkernel!(["avx512f"], 4, x4x1, 4, 1);
microkernel!(["avx512f"], 4, x4x2, 4, 2);
microkernel!(["avx512f"], 4, x4x3, 4, 3);
microkernel!(["avx512f"], 4, x4x4, 4, 4);
microkernel!(["avx512f"], 4, x4x5, 4, 5);
microkernel!(["avx512f"], 4, x4x6, 4, 6);
microkernel_fn_array! {
[x1x1, x1x2, x1x3, x1x4, x1x5, x1x6,],
[x2x1, x2x2, x2x3, x2x4, x2x5, x2x6,],
[x3x1, x3x2, x3x3, x3x4, x3x5, x3x6,],
[x4x1, x4x2, x4x3, x4x4, x4x5, x4x6,],
}
horizontal_kernel!(["avx512f"], hx1x1, 1, 1);
horizontal_kernel!(["avx512f"], hx1x2, 1, 2);
horizontal_kernel!(["avx512f"], hx1x3, 1, 3);
horizontal_kernel!(["avx512f"], hx1x4, 1, 4);
horizontal_kernel!(["avx512f"], hx2x1, 2, 1);
horizontal_kernel!(["avx512f"], hx2x2, 2, 2);
horizontal_kernel!(["avx512f"], hx2x3, 2, 3);
horizontal_kernel!(["avx512f"], hx2x4, 2, 4);
horizontal_kernel!(["avx512f"], hx3x1, 3, 1);
horizontal_kernel!(["avx512f"], hx3x2, 3, 2);
horizontal_kernel!(["avx512f"], hx3x3, 3, 3);
horizontal_kernel!(["avx512f"], hx3x4, 3, 4);
horizontal_kernel!(["avx512f"], hx4x1, 4, 1);
horizontal_kernel!(["avx512f"], hx4x2, 4, 2);
horizontal_kernel!(["avx512f"], hx4x3, 4, 3);
horizontal_kernel!(["avx512f"], hx4x4, 4, 4);
hmicrokernel_fn_array! {
[hx1x1, hx1x2, hx1x3, hx1x4,],
[hx2x1, hx2x2, hx2x3, hx2x4,],
[hx3x1, hx3x2, hx3x3, hx3x4,],
[hx4x1, hx4x2, hx4x3, hx4x4,],
}
}
}
#[allow(dead_code)]
mod v128_common {
pub mod f32 {
pub type T = f32;
pub const N: usize = 4;
pub type Pack = [T; N];
#[inline(always)]
pub unsafe fn splat(value: T) -> Pack {
[value, value, value, value]
}
}
}
#[cfg(target_arch = "aarch64")]
pub mod neon {
pub mod f32 {
use super::super::v128_common::f32::*;
use core::arch::aarch64::*;
use core::mem::transmute;
#[cfg(miri)]
unsafe fn vfmaq_f32(c: float32x4_t, a: float32x4_t, b: float32x4_t) -> float32x4_t {
let c: [f32; 4] = transmute(c);
let a: [f32; 4] = transmute(a);
let b: [f32; 4] = transmute(b);
transmute([
f32::mul_add(a[0], b[0], c[0]),
f32::mul_add(a[1], b[1], c[1]),
f32::mul_add(a[2], b[2], c[2]),
f32::mul_add(a[3], b[3], c[3]),
])
}
#[inline(always)]
pub unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
transmute(vmulq_f32(transmute(lhs), transmute(rhs)))
}
#[inline(always)]
pub unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
transmute(vaddq_f32(transmute(lhs), transmute(rhs)))
}
#[inline(always)]
pub unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
transmute(vfmaq_f32(transmute(c), transmute(a), transmute(b)))
}
#[inline(always)]
pub unsafe fn mul_add_lane<const LANE: i32>(a: Pack, b: Pack, c: Pack) -> Pack {
transmute(vfmaq_laneq_f32::<LANE>(
transmute(c),
transmute(a),
transmute(b),
))
}
#[inline(always)]
pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
lhs * rhs
}
#[inline(always)]
pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
lhs + rhs
}
#[inline(always)]
pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
gemm_common::simd::neon_fmaf(a, b, c)
}
#[inline(always)]
pub unsafe fn load<const MR_DIV_N: usize>(dst: *mut Pack, ptr: *const f32) {
match MR_DIV_N {
1 => *(dst as *mut [Pack; 1]) = transmute(vld1q_f32(ptr)),
2 => *(dst as *mut [Pack; 2]) = transmute(vld1q_f32_x2(ptr)),
3 => *(dst as *mut [Pack; 3]) = transmute(vld1q_f32_x3(ptr)),
4 => *(dst as *mut [Pack; 4]) = transmute(vld1q_f32_x4(ptr)),
_ => unreachable!(),
}
}
microkernel!(["neon"], 4, x1x1, 1, 1);
microkernel!(["neon"], 4, x1x2, 1, 2);
microkernel!(["neon"], 4, x1x3, 1, 3);
microkernel!(["neon"], 4, x1x4, 1, 4, 1, 4);
microkernel!(["neon"], 4, x2x1, 2, 1);
microkernel!(["neon"], 4, x2x2, 2, 2);
microkernel!(["neon"], 4, x2x3, 2, 3);
microkernel!(["neon"], 4, x2x4, 2, 4, 1, 4);
microkernel!(["neon"], 4, x3x1, 3, 1);
microkernel!(["neon"], 4, x3x2, 3, 2);
microkernel!(["neon"], 4, x3x3, 3, 3);
microkernel!(["neon"], 4, x3x4, 3, 4, 1, 4);
microkernel!(["neon"], 4, x4x1, 4, 1);
microkernel!(["neon"], 4, x4x2, 4, 2);
microkernel!(["neon"], 4, x4x3, 4, 3);
microkernel!(["neon"], 4, x4x4, 4, 4, 1, 4);
microkernel_fn_array! {
[x1x1, x1x2, x1x3, x1x4, ],
[x2x1, x2x2, x2x3, x2x4, ],
[x3x1, x3x2, x3x3, x3x4, ],
[x4x1, x4x2, x4x3, x4x4, ],
}
pub const H_M: usize = 0;
pub const H_N: usize = 0;
pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
}
}
#[cfg(target_arch = "aarch64")]
pub mod amx {
pub mod f32 {
pub type T = f32;
pub const N: usize = 16;
#[inline(always)]
pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
lhs * rhs
}
#[inline(always)]
pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
lhs + rhs
}
#[inline(always)]
pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
gemm_common::simd::neon_fmaf(a, b, c)
}
microkernel_amx!(f32, ["neon"], 4, x1x16, 1, 16, 1, 16);
microkernel_amx!(f32, ["neon"], 4, x1x32, 1, 32, 2, 16);
microkernel_amx!(f32, ["neon"], 4, x2x16, 2, 16, 1, 16);
microkernel_amx!(f32, ["neon"], 4, x2x32, 2, 32, 2, 16);
microkernel_fn_array! {
[
x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,
x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,
],
}
pub const H_M: usize = 0;
pub const H_N: usize = 0;
pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
}
}
#[cfg(target_arch = "wasm32")]
pub mod simd128 {
pub mod f32 {
use super::super::v128_common::f32::*;
use core::arch::wasm32::*;
use core::mem::transmute;
#[inline(always)]
pub unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
transmute(f32x4_mul(transmute(lhs), transmute(rhs)))
}
#[inline(always)]
pub unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
transmute(f32x4_add(transmute(lhs), transmute(rhs)))
}
#[inline(always)]
pub unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
add(c, mul(a, b))
}
#[inline(always)]
pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
lhs * rhs
}
#[inline(always)]
pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
lhs + rhs
}
#[inline(always)]
pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
a * b + c
}
microkernel!(["simd128"], 2, x1x1, 1, 1);
microkernel!(["simd128"], 2, x1x2, 1, 2);
microkernel!(["simd128"], 2, x1x3, 1, 3);
microkernel!(["simd128"], 2, x1x4, 1, 4);
microkernel!(["simd128"], 2, x2x1, 2, 1);
microkernel!(["simd128"], 2, x2x2, 2, 2);
microkernel!(["simd128"], 2, x2x3, 2, 3);
microkernel!(["simd128"], 2, x2x4, 2, 4);
microkernel!(["simd128"], 2, x3x1, 3, 1);
microkernel!(["simd128"], 2, x3x2, 3, 2);
microkernel!(["simd128"], 2, x3x3, 3, 3);
microkernel!(["simd128"], 2, x3x4, 3, 4);
microkernel_fn_array! {
[x1x1, x1x2, x1x3, x1x4,],
[x2x1, x2x2, x2x3, x2x4,],
[x3x1, x3x2, x3x3, x3x4,],
}
pub const H_M: usize = 0;
pub const H_N: usize = 0;
pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
}
}