// gemm_f16/microkernel.rs

1#[allow(dead_code)]
2mod v128_common {
3    pub mod f16 {
4        pub type T = half::f16;
5        pub const N: usize = 8;
6        pub type Pack = [T; N];
7
8        #[inline(always)]
9        pub unsafe fn splat(value: T) -> Pack {
10            [value, value, value, value, value, value, value, value]
11        }
12    }
13}
14
#[cfg(target_arch = "aarch64")]
pub mod neonfp16 {
    /// Half-precision (`f16`) pack helpers and microkernels for NEON with `fp16`.
    pub mod f16 {
        pub use super::super::v128_common::f16::*;
        use core::mem::transmute;
        use gemm_common::simd::aarch64::*;

        // Every vector helper reinterprets a `Pack` (eight `half::f16`, 16 bytes)
        // as a 128-bit NEON register with `transmute`, and converts the result back.
        // NOTE(review): the `*_f16` intrinsics and the scalar `*_fp16` helpers are
        // resolved through the `gemm_common::simd::aarch64` glob import. They
        // presumably require the `fp16` target feature, which is why the kernels
        // below are generated with `"neon,fp16"` — confirm in gemm_common.

        /// Lane-wise product `lhs * rhs`.
        ///
        /// # Safety
        ///
        /// The executing CPU must support the NEON `fp16` extension.
        #[inline(always)]
        pub unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            transmute(vmulq_f16(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise sum `lhs + rhs`.
        ///
        /// # Safety
        ///
        /// The executing CPU must support the NEON `fp16` extension.
        #[inline(always)]
        pub unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            transmute(vaddq_f16(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise fused multiply-add `a * b + c`.
        ///
        /// `vfmaq_f16` takes the accumulator first, hence the `(c, a, b)` order.
        ///
        /// # Safety
        ///
        /// The executing CPU must support the NEON `fp16` extension.
        #[inline(always)]
        pub unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(vfmaq_f16(transmute(c), transmute(a), transmute(b)))
        }

        /// Scalar product, delegated to `multiply_f16_fp16`.
        ///
        /// # Safety
        ///
        /// Presumably requires the `fp16` extension, like the vector helpers.
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            transmute(multiply_f16_fp16(transmute(lhs), transmute(rhs)))
        }

        /// Scalar sum, delegated to `add_f16_fp16`.
        ///
        /// # Safety
        ///
        /// Presumably requires the `fp16` extension, like the vector helpers.
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            transmute(add_f16_fp16(transmute(lhs), transmute(rhs)))
        }

        /// Scalar multiply-add, delegated to `fmaq_f16(c, a, b)`.
        ///
        /// NOTE(review): assumed to compute `a * b + c` with the same
        /// accumulator-first convention as `vfmaq_f16` — confirm in gemm_common.
        ///
        /// # Safety
        ///
        /// Presumably requires the `fp16` extension, like the vector helpers.
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            transmute(fmaq_f16(transmute(c), transmute(a), transmute(b)))
        }

        /// Fused multiply-add against a single lane of `b`: lane `i` of the result
        /// is `a[i] * b[LANE] + c[i]`.
        ///
        /// `LANE` selects one of the `N` (8) lanes of `b`, so it must be in `0..8`.
        ///
        /// # Safety
        ///
        /// The executing CPU must support the NEON `fp16` extension.
        #[inline(always)]
        pub unsafe fn mul_add_lane<const LANE: i32>(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(vfmaq_laneq_f16::<LANE>(
                transmute(c),
                transmute(a),
                transmute(b),
            ))
        }

        /// Copies `MR_DIV_N` consecutive packs (`MR_DIV_N * N` `f16` values) from
        /// `ptr` into `dst`.
        ///
        /// The bytes are moved with the `f32` multi-register loads
        /// (`vld1q_f32`, `vld1q_f32_x2/_x3/_x4`). Each of these copies 16 bytes per
        /// register, so the element type does not matter for a plain copy.
        ///
        /// # Safety
        ///
        /// - `ptr` must be valid for reads of `MR_DIV_N * N` `half::f16` values.
        /// - `dst` must be valid for writes of `MR_DIV_N` consecutive `Pack`s.
        /// - NOTE(review): `ptr` is reinterpreted as `*const f32`. Confirm that the
        ///   `vld1q_f32*` intrinsics tolerate the 2-byte alignment of `half::f16`.
        ///
        /// # Panics
        ///
        /// Panics if `MR_DIV_N` is not in `1..=4`.
        // `#[inline(always)]` matches every other helper in this module. `load`
        // sits on the same hot path and should fold into the generated kernels.
        #[inline(always)]
        pub unsafe fn load<const MR_DIV_N: usize>(dst: *mut Pack, ptr: *const half::f16) {
            use core::arch::aarch64::*;
            let src = ptr.cast::<f32>();
            // `MR_DIV_N` is a const generic, so each instantiation keeps one arm.
            match MR_DIV_N {
                1 => dst.cast::<[Pack; 1]>().write(transmute(vld1q_f32(src))),
                2 => dst.cast::<[Pack; 2]>().write(transmute(vld1q_f32_x2(src))),
                3 => dst.cast::<[Pack; 3]>().write(transmute(vld1q_f32_x3(src))),
                4 => dst.cast::<[Pack; 4]>().write(transmute(vld1q_f32_x4(src))),
                _ => unreachable!(),
            }
        }

        // Generated kernels are named `xMxK`, and the two arguments after the name
        // repeat `M` and `K`. The `x*x8` kernels pass two further arguments
        // (`1, 8`) whose meaning is defined by `microkernel!` — TODO confirm.
        microkernel!(["neon,fp16"], 4, x1x1, 1, 1);
        microkernel!(["neon,fp16"], 4, x1x2, 1, 2);
        microkernel!(["neon,fp16"], 4, x1x3, 1, 3);
        microkernel!(["neon,fp16"], 4, x1x4, 1, 4);
        microkernel!(["neon,fp16"], 4, x1x5, 1, 5);
        microkernel!(["neon,fp16"], 4, x1x6, 1, 6);
        microkernel!(["neon,fp16"], 4, x1x7, 1, 7);
        microkernel!(["neon,fp16"], 4, x1x8, 1, 8, 1, 8);

        microkernel!(["neon,fp16"], 4, x2x1, 2, 1);
        microkernel!(["neon,fp16"], 4, x2x2, 2, 2);
        microkernel!(["neon,fp16"], 4, x2x3, 2, 3);
        microkernel!(["neon,fp16"], 4, x2x4, 2, 4);
        microkernel!(["neon,fp16"], 4, x2x5, 2, 5);
        microkernel!(["neon,fp16"], 4, x2x6, 2, 6);
        microkernel!(["neon,fp16"], 4, x2x7, 2, 7);
        microkernel!(["neon,fp16"], 4, x2x8, 2, 8, 1, 8);

        // Kernel table: one row per `M` (1 and 2), columns ordered by `K` (1..=8).
        // Presumably indexed as `[M - 1][K - 1]` by the caller — confirm against
        // `microkernel_fn_array!`.
        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4, x1x5, x1x6, x1x7, x1x8, ],
            [x2x1, x2x2, x2x3, x2x4, x2x5, x2x6, x2x7, x2x8, ],
        }
    }
}

#[cfg(target_arch = "aarch64")]
pub mod amx {
    /// `f16` kernel generated by `microkernel_amx!`.
    ///
    /// NOTE(review): presumably targets Apple's AMX coprocessor, as the module name
    /// suggests — confirm against `microkernel_amx!`.
    pub mod f16 {
        /// Scalar element type.
        pub type T = half::f16;

        /// Matches the `32` in the kernel name `x1x32`. Its exact role is defined
        /// by the kernel macros — TODO confirm.
        pub const N: usize = 32;

        microkernel_amx!(f16, ["neon"], 4, x1x32, 1, 32, 1, 32);

        // A single 32-entry row in which every slot refers to the one `x1x32`
        // kernel generated above.
        microkernel_fn_array! {
            [
                x1x32, x1x32, x1x32, x1x32, x1x32, x1x32, x1x32, x1x32,
                x1x32, x1x32, x1x32, x1x32, x1x32, x1x32, x1x32, x1x32,
                x1x32, x1x32, x1x32, x1x32, x1x32, x1x32, x1x32, x1x32,
                x1x32, x1x32, x1x32, x1x32, x1x32, x1x32, x1x32, x1x32,
            ],
        }
    }
}