// gemm_f64/microkernel.rs

1pub mod scalar {
2    pub mod f64 {
3        type T = f64;
4        const N: usize = 1;
5        type Pack = [T; N];
6
7        #[inline(always)]
8        unsafe fn splat(value: T) -> Pack {
9            [value]
10        }
11
12        #[inline(always)]
13        unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
14            [lhs[0] * rhs[0]]
15        }
16
17        #[inline(always)]
18        unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
19            [lhs[0] + rhs[0]]
20        }
21
22        #[inline(always)]
23        unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
24            add(mul(a, b), c)
25        }
26
27        #[inline(always)]
28        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
29            lhs * rhs
30        }
31
32        #[inline(always)]
33        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
34            lhs + rhs
35        }
36
37        #[inline(always)]
38        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
39            a * b + c
40        }
41
42        microkernel!(, 2, x1x1, 1, 1);
43        microkernel!(, 2, x1x2, 1, 2);
44        microkernel!(, 2, x1x3, 1, 3);
45        microkernel!(, 2, x1x4, 1, 4);
46
47        microkernel!(, 2, x2x1, 2, 1);
48        microkernel!(, 2, x2x2, 2, 2);
49        microkernel!(, 2, x2x3, 2, 3);
50        microkernel!(, 2, x2x4, 2, 4);
51
52        microkernel_fn_array! {
53            [x1x1, x1x2, x1x3, x1x4,],
54            [x2x1, x2x2, x2x3, x2x4,],
55        }
56
57        pub const H_M: usize = 0;
58        pub const H_N: usize = 0;
59        pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
60    }
61}
62
63#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
64pub mod fma {
65    pub mod f64 {
66        #[cfg(target_arch = "x86")]
67        use core::arch::x86::*;
68        #[cfg(target_arch = "x86_64")]
69        use core::arch::x86_64::*;
70        use core::mem::transmute;
71
72        use gemm_common::pulp::u64x4;
73
74        type T = f64;
75        const N: usize = 4;
76        type Pack = [T; N];
77
78        #[inline(always)]
79        unsafe fn splat(value: T) -> Pack {
80            transmute(_mm256_set1_pd(value))
81        }
82
83        #[inline(always)]
84        unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
85            transmute(_mm256_mul_pd(transmute(lhs), transmute(rhs)))
86        }
87
88        #[inline(always)]
89        unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
90            transmute(_mm256_add_pd(transmute(lhs), transmute(rhs)))
91        }
92
93        #[inline(always)]
94        unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
95            transmute(_mm256_fmadd_pd(transmute(a), transmute(b), transmute(c)))
96        }
97
98        #[inline(always)]
99        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
100            lhs * rhs
101        }
102
103        #[inline(always)]
104        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
105            lhs + rhs
106        }
107
108        #[inline(always)]
109        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
110            gemm_common::simd::v3_fma(a, b, c)
111        }
112
113        static U64_MASKS: [u64x4; 5] = [
114            u64x4(0, 0, 0, 0),
115            u64x4(!0, 0, 0, 0),
116            u64x4(!0, !0, 0, 0),
117            u64x4(!0, !0, !0, 0),
118            u64x4(!0, !0, !0, !0),
119        ];
120
121        #[inline(always)]
122        pub unsafe fn partial_load(ptr: *const T, len: usize) -> [T; N] {
123            transmute(_mm256_maskload_pd(
124                ptr,
125                transmute(*(U64_MASKS.as_ptr().add(len))),
126            ))
127        }
128
129        #[inline(always)]
130        pub unsafe fn reduce_sum(x: [T; N]) -> T {
131            let x: __m256d = transmute(x);
132            let x = _mm_add_pd(_mm256_castpd256_pd128(x), _mm256_extractf128_pd::<1>(x));
133            let hi = transmute(_mm_movehl_ps(transmute(x), transmute(x)));
134            let r = _mm_add_sd(x, hi);
135            _mm_cvtsd_f64(r)
136        }
137
138        microkernel!(["fma"], 2, x1x1, 1, 1);
139        microkernel!(["fma"], 2, x1x2, 1, 2);
140        microkernel!(["fma"], 2, x1x3, 1, 3);
141        microkernel!(["fma"], 2, x1x4, 1, 4);
142        microkernel!(["fma"], 2, x1x5, 1, 5);
143        microkernel!(["fma"], 2, x1x6, 1, 6);
144
145        microkernel!(["fma"], 2, x2x1, 2, 1);
146        microkernel!(["fma"], 2, x2x2, 2, 2);
147        microkernel!(["fma"], 2, x2x3, 2, 3);
148        microkernel!(["fma"], 2, x2x4, 2, 4);
149        microkernel!(["fma"], 2, x2x5, 2, 5);
150        microkernel!(["fma"], 2, x2x6, 2, 6);
151
152        microkernel_fn_array! {
153            [x1x1, x1x2, x1x3, x1x4, x1x5, x1x6,],
154            [x2x1, x2x2, x2x3, x2x4, x2x5, x2x6,],
155        }
156
157        horizontal_kernel!(["fma"], hx1x1, 1, 1);
158        horizontal_kernel!(["fma"], hx1x2, 1, 2);
159        horizontal_kernel!(["fma"], hx2x1, 2, 1);
160        horizontal_kernel!(["fma"], hx2x2, 2, 2);
161        hmicrokernel_fn_array! {
162            [hx1x1, hx1x2,],
163            [hx2x1, hx2x2,],
164        }
165    }
166}
167
/// AVX-512F kernels (nightly only): `N == 8` lanes of `f64` per 512-bit register.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
pub mod avx512f {
    pub mod f64 {
        #[cfg(target_arch = "x86")]
        use core::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;
        use core::mem::transmute;

        type T = f64;
        const N: usize = 8;
        /// Eight `f64` lanes, bit-compatible with `__m512d`.
        type Pack = [T; N];

        /// Broadcasts `value` into all eight lanes.
        #[inline(always)]
        unsafe fn splat(value: T) -> Pack {
            transmute(_mm512_set1_pd(value))
        }

        /// Lane-wise product.
        #[inline(always)]
        unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            transmute(_mm512_mul_pd(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise sum.
        #[inline(always)]
        unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            transmute(_mm512_add_pd(transmute(lhs), transmute(rhs)))
        }

        /// Fused lane-wise `a * b + c` (single rounding).
        #[inline(always)]
        unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(_mm512_fmadd_pd(transmute(a), transmute(b), transmute(c)))
        }

        /// Scalar product.
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            lhs * rhs
        }

        /// Scalar sum.
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            lhs + rhs
        }

        /// Scalar fused multiply-add, delegated to `gemm_common`.
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            gemm_common::simd::v3_fma(a, b, c)
        }

        // Kernels with 1..=4 packs of rows and 1..=6 columns.
        microkernel!(["avx512f"], 4, x1x1, 1, 1);
        microkernel!(["avx512f"], 4, x1x2, 1, 2);
        microkernel!(["avx512f"], 4, x1x3, 1, 3);
        microkernel!(["avx512f"], 4, x1x4, 1, 4);
        microkernel!(["avx512f"], 4, x1x5, 1, 5);
        microkernel!(["avx512f"], 4, x1x6, 1, 6);

        microkernel!(["avx512f"], 4, x2x1, 2, 1);
        microkernel!(["avx512f"], 4, x2x2, 2, 2);
        microkernel!(["avx512f"], 4, x2x3, 2, 3);
        microkernel!(["avx512f"], 4, x2x4, 2, 4);
        microkernel!(["avx512f"], 4, x2x5, 2, 5);
        microkernel!(["avx512f"], 4, x2x6, 2, 6);

        microkernel!(["avx512f"], 4, x3x1, 3, 1);
        microkernel!(["avx512f"], 4, x3x2, 3, 2);
        microkernel!(["avx512f"], 4, x3x3, 3, 3);
        microkernel!(["avx512f"], 4, x3x4, 3, 4);
        microkernel!(["avx512f"], 4, x3x5, 3, 5);
        microkernel!(["avx512f"], 4, x3x6, 3, 6);

        microkernel!(["avx512f"], 4, x4x1, 4, 1);
        microkernel!(["avx512f"], 4, x4x2, 4, 2);
        microkernel!(["avx512f"], 4, x4x3, 4, 3);
        microkernel!(["avx512f"], 4, x4x4, 4, 4);
        microkernel!(["avx512f"], 4, x4x5, 4, 5);
        microkernel!(["avx512f"], 4, x4x6, 4, 6);

        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4, x1x5, x1x6,],
            [x2x1, x2x2, x2x3, x2x4, x2x5, x2x6,],
            [x3x1, x3x2, x3x3, x3x4, x3x5, x3x6,],
            [x4x1, x4x2, x4x3, x4x4, x4x5, x4x6,],
        }

        /// Despite the name, each entry is an 8-bit lane mask rather than a
        /// vector. `U64_MASKS[i]` has its low `i` bits set, selecting the first
        /// `i` f64 lanes for `_mm512_maskz_loadu_pd`.
        static U64_MASKS: [u8; 9] = [
            0b00000000, //
            0b00000001, //
            0b00000011, //
            0b00000111, //
            0b00001111, //
            0b00011111, //
            0b00111111, //
            0b01111111, //
            0b11111111, //
        ];

        /// Loads the first `len` lanes from `ptr`. Masked-off lanes come back as zero.
        ///
        /// # Safety
        /// `len` must be at most `N`, and the first `len` elements at `ptr` must
        /// be readable. The caller must ensure the CPU supports AVX-512F.
        #[inline(always)]
        pub unsafe fn partial_load(ptr: *const T, len: usize) -> [T; N] {
            // The mask table is indexed without a bounds check, so an
            // out-of-range `len` would read past the static. Catch it in debug builds.
            debug_assert!(len <= N);
            transmute(_mm512_maskz_loadu_pd(*U64_MASKS.as_ptr().add(len), ptr))
        }

        /// Horizontal sum of all eight lanes, reduced pairwise:
        /// 512 -> 256 -> 128 -> 64 bits.
        #[inline(always)]
        pub unsafe fn reduce_sum(x: [T; N]) -> T {
            let x: __m512d = transmute(x);
            let x = _mm256_add_pd(_mm512_castpd512_pd256(x), _mm512_extractf64x4_pd::<1>(x));
            let x = _mm_add_pd(_mm256_castpd256_pd128(x), _mm256_extractf128_pd::<1>(x));
            // Move the upper f64 into the low lane with a double-precision shuffle,
            // instead of reinterpreting through a single-precision one.
            let hi = _mm_unpackhi_pd(x, x);
            _mm_cvtsd_f64(_mm_add_sd(x, hi))
        }

        horizontal_kernel!(["avx512f"], hx1x1, 1, 1);
        horizontal_kernel!(["avx512f"], hx1x2, 1, 2);
        horizontal_kernel!(["avx512f"], hx1x3, 1, 3);
        horizontal_kernel!(["avx512f"], hx1x4, 1, 4);
        horizontal_kernel!(["avx512f"], hx2x1, 2, 1);
        horizontal_kernel!(["avx512f"], hx2x2, 2, 2);
        horizontal_kernel!(["avx512f"], hx2x3, 2, 3);
        horizontal_kernel!(["avx512f"], hx2x4, 2, 4);
        horizontal_kernel!(["avx512f"], hx3x1, 3, 1);
        horizontal_kernel!(["avx512f"], hx3x2, 3, 2);
        horizontal_kernel!(["avx512f"], hx3x3, 3, 3);
        horizontal_kernel!(["avx512f"], hx3x4, 3, 4);
        horizontal_kernel!(["avx512f"], hx4x1, 4, 1);
        horizontal_kernel!(["avx512f"], hx4x2, 4, 2);
        horizontal_kernel!(["avx512f"], hx4x3, 4, 3);
        horizontal_kernel!(["avx512f"], hx4x4, 4, 4);
        hmicrokernel_fn_array! {
            [hx1x1, hx1x2, hx1x3, hx1x4,],
            [hx2x1, hx2x2, hx2x3, hx2x4,],
            [hx3x1, hx3x2, hx3x3, hx3x4,],
            [hx4x1, hx4x2, hx4x3, hx4x4,],
        }
    }
}
302
// In this file only the aarch64 (`neon`) and wasm32 (`simd128`) backends glob-import
// this module, so on other targets its items would otherwise be reported as dead code.
#[allow(dead_code)]
mod v128_common {
    /// Definitions shared by the two-lane (128-bit) `f64` backends.
    pub mod f64 {
        pub type T = f64;
        pub const N: usize = 2;
        /// Two `f64` lanes, matching the size of a 128-bit vector register.
        pub type Pack = [T; N];

        /// Broadcasts `value` into both lanes of a pack.
        #[inline(always)]
        pub unsafe fn splat(value: T) -> Pack {
            [value; N]
        }
    }
}
316
/// NEON kernels for aarch64: `N == 2` lanes of `f64` per 128-bit register.
#[cfg(target_arch = "aarch64")]
pub mod neon {
    pub mod f64 {
        use super::super::v128_common::f64::*;
        use core::arch::aarch64::*;
        use core::mem::transmute;

        // Under Miri this local definition shadows the glob-imported NEON
        // `vfmaq_f64` (presumably because Miri cannot run the intrinsic). It
        // computes `c + a * b` lane by lane with a fused multiply-add.
        #[cfg(miri)]
        unsafe fn vfmaq_f64(c: float64x2_t, a: float64x2_t, b: float64x2_t) -> float64x2_t {
            let [c0, c1]: [f64; 2] = transmute(c);
            let [a0, a1]: [f64; 2] = transmute(a);
            let [b0, b1]: [f64; 2] = transmute(b);

            transmute([a0.mul_add(b0, c0), a1.mul_add(b1, c1)])
        }

        /// Lane-wise product.
        #[inline(always)]
        pub unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            let product = vmulq_f64(transmute(lhs), transmute(rhs));
            transmute(product)
        }

        /// Lane-wise sum.
        #[inline(always)]
        pub unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            let sum = vaddq_f64(transmute(lhs), transmute(rhs));
            transmute(sum)
        }

        /// Fused lane-wise `a * b + c`; `vfmaq_f64` takes the accumulator first.
        #[inline(always)]
        pub unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            let acc: float64x2_t = transmute(c);
            transmute(vfmaq_f64(acc, transmute(a), transmute(b)))
        }

        /// Fused `a * b[LANE] + c`: multiplies every lane of `a` by lane `LANE` of `b`.
        #[inline(always)]
        pub unsafe fn mul_add_lane<const LANE: i32>(a: Pack, b: Pack, c: Pack) -> Pack {
            let acc: float64x2_t = transmute(c);
            let factor: float64x2_t = transmute(a);
            let lanes: float64x2_t = transmute(b);
            transmute(vfmaq_laneq_f64::<LANE>(acc, factor, lanes))
        }

        /// Scalar product.
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            lhs * rhs
        }

        /// Scalar sum.
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            lhs + rhs
        }

        /// Scalar fused multiply-add, delegated to `gemm_common`.
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            gemm_common::simd::neon_fma(a, b, c)
        }

        // Even column counts pass two extra arguments to the macro; their
        // meaning is defined by `microkernel!` elsewhere in the crate.
        microkernel!(["neon"], 2, x1x1, 1, 1);
        microkernel!(["neon"], 2, x1x2, 1, 2, 1, 2);
        microkernel!(["neon"], 2, x1x3, 1, 3);
        microkernel!(["neon"], 2, x1x4, 1, 4, 2, 2);

        microkernel!(["neon"], 2, x2x1, 2, 1);
        microkernel!(["neon"], 2, x2x2, 2, 2, 1, 2);
        microkernel!(["neon"], 2, x2x3, 2, 3);
        microkernel!(["neon"], 2, x2x4, 2, 4, 2, 2);

        microkernel!(["neon"], 2, x3x1, 3, 1);
        microkernel!(["neon"], 2, x3x2, 3, 2, 1, 2);
        microkernel!(["neon"], 2, x3x3, 3, 3);
        microkernel!(["neon"], 2, x3x4, 3, 4, 2, 2);

        microkernel!(["neon"], 2, x4x1, 4, 1);
        microkernel!(["neon"], 2, x4x2, 4, 2, 1, 2);
        microkernel!(["neon"], 2, x4x3, 4, 3);
        microkernel!(["neon"], 2, x4x4, 4, 4, 2, 2);

        /// Loads `MR_DIV_N` consecutive packs from `ptr` into `dst` with a single
        /// multi-register `vld1q` load.
        ///
        /// # Safety
        /// `MR_DIV_N` must be in `1..=4`. `ptr` must be readable for
        /// `2 * MR_DIV_N` elements, and `dst` must be writable for `MR_DIV_N` packs.
        #[inline(always)]
        pub unsafe fn load<const MR_DIV_N: usize>(dst: *mut Pack, ptr: *const f64) {
            match MR_DIV_N {
                1 => *dst.cast::<[Pack; 1]>() = transmute(vld1q_f64(ptr)),
                2 => *dst.cast::<[Pack; 2]>() = transmute(vld1q_f64_x2(ptr)),
                3 => *dst.cast::<[Pack; 3]>() = transmute(vld1q_f64_x3(ptr)),
                4 => *dst.cast::<[Pack; 4]>() = transmute(vld1q_f64_x4(ptr)),
                _ => unreachable!(),
            }
        }

        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4, ],
            [x2x1, x2x2, x2x3, x2x4, ],
            [x3x1, x3x2, x3x3, x3x4, ],
            [x4x1, x4x2, x4x3, x4x4, ],
        }

        // This backend provides no horizontal kernels: the table is empty.
        pub const H_M: usize = 0;
        pub const H_N: usize = 0;
        pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
    }
}
418
/// aarch64 kernels generated by `microkernel_amx!` (presumably targeting Apple's
/// AMX unit; confirm against the macro definition).
#[cfg(target_arch = "aarch64")]
pub mod amx {
    pub mod f64 {
        pub type T = f64;
        pub const N: usize = 8;

        // 8-column kernels, for 1 and 2 row blocks.
        microkernel_amx!(f64, ["neon"], 4, x1x8, 1, 8, 1, 8);
        microkernel_amx!(f64, ["neon"], 4, x2x8, 2, 8, 1, 8);

        // 16-column kernels, for 1 and 2 row blocks.
        microkernel_amx!(f64, ["neon"], 4, x1x16, 1, 16, 2, 8);
        microkernel_amx!(f64, ["neon"], 4, x2x16, 2, 16, 2, 8);

        // Each row has 16 entries: the first 8 map to the 8-column kernel and
        // the last 8 map to the 16-column kernel.
        microkernel_fn_array! {
            [
                x1x8, x1x8, x1x8, x1x8,
                x1x8, x1x8, x1x8, x1x8,
                x1x16, x1x16, x1x16, x1x16,
                x1x16, x1x16, x1x16, x1x16,
            ],
            [
                x2x8, x2x8, x2x8, x2x8,
                x2x8, x2x8, x2x8, x2x8,
                x2x16, x2x16, x2x16, x2x16,
                x2x16, x2x16, x2x16, x2x16,
            ],
        }

        // This backend provides no horizontal kernels: the table is empty.
        pub const H_M: usize = 0;
        pub const H_N: usize = 0;
        pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
    }
}
446
/// WebAssembly SIMD128 kernels: `N == 2` lanes of `f64` per `v128`.
#[cfg(target_arch = "wasm32")]
pub mod simd128 {
    pub mod f64 {
        use super::super::v128_common::f64::*;
        use core::arch::wasm32::*;
        use core::mem::transmute;

        /// Lane-wise product.
        #[inline(always)]
        pub unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            let product = f64x2_mul(transmute(lhs), transmute(rhs));
            transmute(product)
        }

        /// Lane-wise sum.
        #[inline(always)]
        pub unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            let sum = f64x2_add(transmute(lhs), transmute(rhs));
            transmute(sum)
        }

        /// Lane-wise `c + a * b`. Not fused: the product is rounded before the add.
        #[inline(always)]
        pub unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            let product = mul(a, b);
            add(c, product)
        }

        /// Scalar product.
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            lhs * rhs
        }

        /// Scalar sum.
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            lhs + rhs
        }

        /// Unfused scalar `a * b + c`.
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            let product = a * b;
            product + c
        }

        // Kernels with 1..=3 packs of rows and 1..=4 columns.
        microkernel!(["simd128"], 2, x1x1, 1, 1);
        microkernel!(["simd128"], 2, x1x2, 1, 2);
        microkernel!(["simd128"], 2, x1x3, 1, 3);
        microkernel!(["simd128"], 2, x1x4, 1, 4);

        microkernel!(["simd128"], 2, x2x1, 2, 1);
        microkernel!(["simd128"], 2, x2x2, 2, 2);
        microkernel!(["simd128"], 2, x2x3, 2, 3);
        microkernel!(["simd128"], 2, x2x4, 2, 4);

        microkernel!(["simd128"], 2, x3x1, 3, 1);
        microkernel!(["simd128"], 2, x3x2, 3, 2);
        microkernel!(["simd128"], 2, x3x3, 3, 3);
        microkernel!(["simd128"], 2, x3x4, 3, 4);

        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4,],
            [x2x1, x2x2, x2x3, x2x4,],
            [x3x1, x3x2, x3x3, x3x4,],
        }

        // This backend provides no horizontal kernels: the table is empty.
        pub const H_M: usize = 0;
        pub const H_N: usize = 0;
        pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
    }
}