1#[allow(dead_code)]
2mod v128_common {
3 pub mod f16 {
4 pub type T = half::f16;
5 pub const N: usize = 8;
6 pub type Pack = [T; N];
7
8 #[inline(always)]
9 pub unsafe fn splat(value: T) -> Pack {
10 [value, value, value, value, value, value, value, value]
11 }
12 }
13}
14
/// Half-precision helpers and kernel instantiations for aarch64, built on the
/// intrinsics exported by `gemm_common::simd::aarch64`.
#[cfg(target_arch = "aarch64")]
pub mod neonfp16 {
    pub mod f16 {
        pub use super::super::v128_common::f16::*;
        use core::mem::transmute;
        use gemm_common::simd::aarch64::*;

        /// Multiplies two packs by forwarding them to `vmulq_f16`.
        ///
        /// # Safety
        ///
        /// The arguments are reinterpreted with `transmute` as the intrinsic's
        /// vector type. NOTE(review): the caller presumably must guarantee
        /// that the CPU supports the half-precision NEON instructions (the
        /// kernels below are instantiated with `"neon,fp16"`) — confirm.
        #[inline(always)]
        pub unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            transmute(vmulq_f16(transmute(lhs), transmute(rhs)))
        }

        /// Adds two packs by forwarding them to `vaddq_f16`.
        ///
        /// # Safety
        ///
        /// Same requirements as [`mul`].
        #[inline(always)]
        pub unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            transmute(vaddq_f16(transmute(lhs), transmute(rhs)))
        }

        /// Fused multiply-add on packs: calls `vfmaq_f16(c, a, b)`, i.e. the
        /// accumulator `c` is passed first, followed by the two factors.
        ///
        /// # Safety
        ///
        /// Same requirements as [`mul`].
        #[inline(always)]
        pub unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(vfmaq_f16(transmute(c), transmute(a), transmute(b)))
        }

        /// Multiplies two scalars by forwarding them to `multiply_f16_fp16`.
        ///
        /// # Safety
        ///
        /// Same requirements as [`mul`].
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            transmute(multiply_f16_fp16(transmute(lhs), transmute(rhs)))
        }

        /// Adds two scalars by forwarding them to `add_f16_fp16`.
        ///
        /// # Safety
        ///
        /// Same requirements as [`mul`].
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            transmute(add_f16_fp16(transmute(lhs), transmute(rhs)))
        }

        /// Scalar fused multiply-add: calls `fmaq_f16(c, a, b)`, with the
        /// accumulator `c` passed first.
        ///
        /// # Safety
        ///
        /// Same requirements as [`mul`].
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            transmute(fmaq_f16(transmute(c), transmute(a), transmute(b)))
        }

        /// Lane-selecting fused multiply-add: calls
        /// `vfmaq_laneq_f16::<LANE>(c, a, b)`, with the accumulator `c` passed
        /// first. NOTE(review): by the usual NEON convention `LANE` selects
        /// the element of `b` that scales `a` — confirm against the wrapper.
        ///
        /// # Safety
        ///
        /// Same requirements as [`mul`].
        #[inline(always)]
        pub unsafe fn mul_add_lane<const LANE: i32>(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(vfmaq_laneq_f16::<LANE>(
                transmute(c),
                transmute(a),
                transmute(b),
            ))
        }

        /// Loads `MR_DIV_N` consecutive packs starting at `ptr` and stores
        /// them to `dst[0..MR_DIV_N]`.
        ///
        /// Supported values of `MR_DIV_N` are 1 through 4; any other value
        /// hits `unreachable!()`.
        ///
        /// # Safety
        ///
        /// * `ptr` must be valid for reads of `MR_DIV_N * N` elements.
        /// * `dst` must be valid for writes of `MR_DIV_N` packs.
        /// * NOTE(review): the source pointer is reinterpreted as `*const f32`
        ///   only to reach the 128-bit load intrinsics; confirm that these
        ///   loads are acceptable for 2-byte-aligned `f16` data.
        // Forced inline like every other helper in this module.
        #[inline(always)]
        pub unsafe fn load<const MR_DIV_N: usize>(dst: *mut Pack, ptr: *const half::f16) {
            use core::arch::aarch64::*;
            let ptr = ptr as *const f32;
            // `MR_DIV_N` is a const generic, so this dispatch has a single
            // live arm after monomorphization.
            match MR_DIV_N {
                1 => *(dst as *mut [Pack; 1]) = transmute(vld1q_f32(ptr)),
                2 => *(dst as *mut [Pack; 2]) = transmute(vld1q_f32_x2(ptr)),
                3 => *(dst as *mut [Pack; 3]) = transmute(vld1q_f32_x3(ptr)),
                4 => *(dst as *mut [Pack; 4]) = transmute(vld1q_f32_x4(ptr)),
                _ => unreachable!(),
            }
        }

        // Kernel instantiations whose fourth argument is 1.
        // NOTE(review): the xNx8 variants pass two extra arguments (`1, 8`);
        // their meaning is defined by `microkernel!`, which is not visible
        // here.
        microkernel!(["neon,fp16"], 4, x1x1, 1, 1);
        microkernel!(["neon,fp16"], 4, x1x2, 1, 2);
        microkernel!(["neon,fp16"], 4, x1x3, 1, 3);
        microkernel!(["neon,fp16"], 4, x1x4, 1, 4);
        microkernel!(["neon,fp16"], 4, x1x5, 1, 5);
        microkernel!(["neon,fp16"], 4, x1x6, 1, 6);
        microkernel!(["neon,fp16"], 4, x1x7, 1, 7);
        microkernel!(["neon,fp16"], 4, x1x8, 1, 8, 1, 8);

        // Kernel instantiations whose fourth argument is 2.
        microkernel!(["neon,fp16"], 4, x2x1, 2, 1);
        microkernel!(["neon,fp16"], 4, x2x2, 2, 2);
        microkernel!(["neon,fp16"], 4, x2x3, 2, 3);
        microkernel!(["neon,fp16"], 4, x2x4, 2, 4);
        microkernel!(["neon,fp16"], 4, x2x5, 2, 5);
        microkernel!(["neon,fp16"], 4, x2x6, 2, 6);
        microkernel!(["neon,fp16"], 4, x2x7, 2, 7);
        microkernel!(["neon,fp16"], 4, x2x8, 2, 8, 1, 8);

        // Table of the kernels above: one row per x1/x2 family, eight
        // entries per row.
        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4, x1x5, x1x6, x1x7, x1x8, ],
            [x2x1, x2x2, x2x3, x2x4, x2x5, x2x6, x2x7, x2x8, ],
        }
    }
}
97
/// Kernels generated by `microkernel_amx!`; compiled only for aarch64 targets.
#[cfg(target_arch = "aarch64")]
pub mod amx {
    pub mod f16 {
        /// Scalar element type handled by this module.
        pub type T = half::f16;
        /// NOTE(review): presumably the number of `T` elements processed per
        /// vector by the generated kernel — confirm against `microkernel_amx!`.
        pub const N: usize = 32;

        // Instantiates the single kernel `x1x32`. The meaning of the remaining
        // arguments is defined by `microkernel_amx!`, which is not visible here.
        microkernel_amx!(f16, ["neon"], 4, x1x32, 1, 32, 1, 32);

        // One-row table in which every entry is the same `x1x32` kernel.
        // NOTE(review): presumably this makes every column count dispatch to
        // that one kernel — confirm against `microkernel_fn_array!`.
        microkernel_fn_array! {
            [
                x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,x1x32,
            ],
        }
    }
}