/// Scalar (no-SIMD) fallback backend for `f64`: a "pack" is a single lane,
/// so the microkernel macros below expand to plain scalar arithmetic.
pub mod scalar {
    pub mod f64 {
        // Element type handled by this backend.
        type T = f64;
        // Number of lanes per pack; 1 => scalar fallback.
        const N: usize = 1;
        // A SIMD-register-shaped value; here just a one-element array.
        type Pack = [T; N];

        /// Broadcasts `value` into every lane of a pack (trivially one lane here).
        #[inline(always)]
        unsafe fn splat(value: T) -> Pack {
            [value]
        }

        /// Lane-wise multiply.
        #[inline(always)]
        unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            [lhs[0] * rhs[0]]
        }

        /// Lane-wise add.
        #[inline(always)]
        unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            [lhs[0] + rhs[0]]
        }

        /// Lane-wise multiply-add `a * b + c`. Not fused: implemented as a
        /// separate multiply followed by an add, so rounding matches two ops.
        #[inline(always)]
        unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            add(mul(a, b), c)
        }

        /// Scalar multiply, used by the microkernel macros for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            lhs * rhs
        }

        /// Scalar add, used by the microkernel macros for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            lhs + rhs
        }

        /// Scalar multiply-add `a * b + c` (unfused, two rounding steps).
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            a * b + c
        }

        // Instantiate the MRxNR register-tile microkernels. The macro is defined
        // elsewhere in the crate; the empty first argument means no
        // `#[target_feature]` gate. Remaining args appear to be
        // (unroll, name, mr_div_n, nr) — TODO confirm against the macro definition.
        microkernel!(, 2, x1x1, 1, 1);
        microkernel!(, 2, x1x2, 1, 2);
        microkernel!(, 2, x1x3, 1, 3);
        microkernel!(, 2, x1x4, 1, 4);

        microkernel!(, 2, x2x1, 2, 1);
        microkernel!(, 2, x2x2, 2, 2);
        microkernel!(, 2, x2x3, 2, 3);
        microkernel!(, 2, x2x4, 2, 4);

        // Builds the 2D dispatch table of the kernels above, presumably indexed
        // by (row-tile, column-tile) — confirm against the macro definition.
        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4,],
            [x2x1, x2x2, x2x3, x2x4,],
        }

        // No horizontal microkernels for the scalar backend: expose an empty
        // 0x0 table so the dispatch interface stays uniform across backends.
        pub const H_M: usize = 0;
        pub const H_N: usize = 0;
        pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
    }
}
62
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// AVX2 + FMA backend for `f64`: 256-bit packs of 4 lanes.
pub mod fma {
    pub mod f64 {
        #[cfg(target_arch = "x86")]
        use core::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;
        use core::mem::transmute;

        use gemm_common::pulp::u64x4;

        type T = f64;
        // 4 f64 lanes fill one 256-bit AVX register.
        const N: usize = 4;
        type Pack = [T; N];

        /// Broadcasts `value` to all 4 lanes.
        ///
        /// # Safety
        /// AVX must be available at runtime — presumably guaranteed by the
        /// `#[target_feature]` gate the `microkernel!` macro applies; confirm.
        #[inline(always)]
        unsafe fn splat(value: T) -> Pack {
            transmute(_mm256_set1_pd(value))
        }

        /// Lane-wise multiply. # Safety: requires AVX (see `splat`).
        #[inline(always)]
        unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            transmute(_mm256_mul_pd(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise add. # Safety: requires AVX (see `splat`).
        #[inline(always)]
        unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            transmute(_mm256_add_pd(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise fused multiply-add `a * b + c` (single rounding).
        /// # Safety: requires the FMA feature (see `splat`).
        #[inline(always)]
        unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(_mm256_fmadd_pd(transmute(a), transmute(b), transmute(c)))
        }

        /// Scalar multiply for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            lhs * rhs
        }

        /// Scalar add for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            lhs + rhs
        }

        /// Scalar fused multiply-add, delegated to the crate's v3 (AVX2+FMA
        /// feature level) helper so a hardware FMA instruction is used.
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            gemm_common::simd::v3_fma(a, b, c)
        }

        // Per-length lane masks for masked loads: entry `len` has the low `len`
        // lanes set (all-ones per 64-bit lane), the rest zero.
        static U64_MASKS: [u64x4; 5] = [
            u64x4(0, 0, 0, 0),
            u64x4(!0, 0, 0, 0),
            u64x4(!0, !0, 0, 0),
            u64x4(!0, !0, !0, 0),
            u64x4(!0, !0, !0, !0),
        ];

        /// Loads the first `len` elements from `ptr`, zeroing the remaining
        /// lanes. The masked load never touches memory past `ptr + len`.
        ///
        /// # Safety
        /// `len <= 4` (indexes `U64_MASKS`) and `ptr` must be valid for `len` reads.
        #[inline(always)]
        pub unsafe fn partial_load(ptr: *const T, len: usize) -> [T; N] {
            transmute(_mm256_maskload_pd(
                ptr,
                transmute(*(U64_MASKS.as_ptr().add(len))),
            ))
        }

        /// Horizontal sum of all 4 lanes: fold the 256-bit value to 128 bits,
        /// then add the upper 64-bit half (moved down via the `movehl_ps`
        /// bit-pattern trick) to the lower, and extract the scalar.
        #[inline(always)]
        pub unsafe fn reduce_sum(x: [T; N]) -> T {
            let x: __m256d = transmute(x);
            let x = _mm_add_pd(_mm256_castpd256_pd128(x), _mm256_extractf128_pd::<1>(x));
            let hi = transmute(_mm_movehl_ps(transmute(x), transmute(x)));
            let r = _mm_add_sd(x, hi);
            _mm_cvtsd_f64(r)
        }

        // MRxNR microkernels gated on the `fma` target feature; args appear to
        // be (features, unroll, name, mr_div_n, nr) — TODO confirm.
        microkernel!(["fma"], 2, x1x1, 1, 1);
        microkernel!(["fma"], 2, x1x2, 1, 2);
        microkernel!(["fma"], 2, x1x3, 1, 3);
        microkernel!(["fma"], 2, x1x4, 1, 4);
        microkernel!(["fma"], 2, x1x5, 1, 5);
        microkernel!(["fma"], 2, x1x6, 1, 6);

        microkernel!(["fma"], 2, x2x1, 2, 1);
        microkernel!(["fma"], 2, x2x2, 2, 2);
        microkernel!(["fma"], 2, x2x3, 2, 3);
        microkernel!(["fma"], 2, x2x4, 2, 4);
        microkernel!(["fma"], 2, x2x5, 2, 5);
        microkernel!(["fma"], 2, x2x6, 2, 6);

        // 2x6 dispatch table over the kernels above.
        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4, x1x5, x1x6,],
            [x2x1, x2x2, x2x3, x2x4, x2x5, x2x6,],
        }

        // Horizontal (dot-product style) kernels and their 2x2 dispatch table;
        // these rely on `partial_load` / `reduce_sum` above — confirm against
        // the `horizontal_kernel!` macro definition.
        horizontal_kernel!(["fma"], hx1x1, 1, 1);
        horizontal_kernel!(["fma"], hx1x2, 1, 2);
        horizontal_kernel!(["fma"], hx2x1, 2, 1);
        horizontal_kernel!(["fma"], hx2x2, 2, 2);
        hmicrokernel_fn_array! {
            [hx1x1, hx1x2,],
            [hx2x1, hx2x2,],
        }
    }
}
167
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
/// AVX-512 backend for `f64` (nightly-only): 512-bit packs of 8 lanes.
pub mod avx512f {
    pub mod f64 {
        #[cfg(target_arch = "x86")]
        use core::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;
        use core::mem::transmute;

        type T = f64;
        // 8 f64 lanes fill one 512-bit ZMM register.
        const N: usize = 8;
        type Pack = [T; N];

        /// Broadcasts `value` to all 8 lanes.
        ///
        /// # Safety
        /// AVX-512F must be available at runtime — presumably guaranteed by the
        /// `#[target_feature]` gate the `microkernel!` macro applies; confirm.
        #[inline(always)]
        unsafe fn splat(value: T) -> Pack {
            transmute(_mm512_set1_pd(value))
        }

        /// Lane-wise multiply. # Safety: requires AVX-512F (see `splat`).
        #[inline(always)]
        unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            transmute(_mm512_mul_pd(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise add. # Safety: requires AVX-512F (see `splat`).
        #[inline(always)]
        unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            transmute(_mm512_add_pd(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise fused multiply-add `a * b + c` (single rounding).
        #[inline(always)]
        unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(_mm512_fmadd_pd(transmute(a), transmute(b), transmute(c)))
        }

        /// Scalar multiply for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            lhs * rhs
        }

        /// Scalar add for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            lhs + rhs
        }

        /// Scalar fused multiply-add.
        // NOTE(review): delegates to the v3 (AVX2+FMA) helper even in the
        // avx512 module — presumably fine because a scalar FMA only needs the
        // `fma` feature; confirm this is intentional.
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            gemm_common::simd::v3_fma(a, b, c)
        }

        // MRxNR microkernels gated on `avx512f`, with a deeper unroll (4) than
        // the 256-bit path; args appear to be (features, unroll, name,
        // mr_div_n, nr) — TODO confirm.
        microkernel!(["avx512f"], 4, x1x1, 1, 1);
        microkernel!(["avx512f"], 4, x1x2, 1, 2);
        microkernel!(["avx512f"], 4, x1x3, 1, 3);
        microkernel!(["avx512f"], 4, x1x4, 1, 4);
        microkernel!(["avx512f"], 4, x1x5, 1, 5);
        microkernel!(["avx512f"], 4, x1x6, 1, 6);

        microkernel!(["avx512f"], 4, x2x1, 2, 1);
        microkernel!(["avx512f"], 4, x2x2, 2, 2);
        microkernel!(["avx512f"], 4, x2x3, 2, 3);
        microkernel!(["avx512f"], 4, x2x4, 2, 4);
        microkernel!(["avx512f"], 4, x2x5, 2, 5);
        microkernel!(["avx512f"], 4, x2x6, 2, 6);

        microkernel!(["avx512f"], 4, x3x1, 3, 1);
        microkernel!(["avx512f"], 4, x3x2, 3, 2);
        microkernel!(["avx512f"], 4, x3x3, 3, 3);
        microkernel!(["avx512f"], 4, x3x4, 3, 4);
        microkernel!(["avx512f"], 4, x3x5, 3, 5);
        microkernel!(["avx512f"], 4, x3x6, 3, 6);

        microkernel!(["avx512f"], 4, x4x1, 4, 1);
        microkernel!(["avx512f"], 4, x4x2, 4, 2);
        microkernel!(["avx512f"], 4, x4x3, 4, 3);
        microkernel!(["avx512f"], 4, x4x4, 4, 4);
        microkernel!(["avx512f"], 4, x4x5, 4, 5);
        microkernel!(["avx512f"], 4, x4x6, 4, 6);

        // 4x6 dispatch table over the kernels above.
        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4, x1x5, x1x6,],
            [x2x1, x2x2, x2x3, x2x4, x2x5, x2x6,],
            [x3x1, x3x2, x3x3, x3x4, x3x5, x3x6,],
            [x4x1, x4x2, x4x3, x4x4, x4x5, x4x6,],
        }

        // Per-length load masks for `_mm512_maskz_loadu_pd`: entry `len` has
        // the low `len` bits set (one bit per 64-bit lane), for len in 0..=8.
        static U64_MASKS: [u8; 9] = [
            0b00000000, 0b00000001, 0b00000011, 0b00000111, 0b00001111,
            0b00011111, 0b00111111, 0b01111111, 0b11111111,
        ];

        /// Loads the first `len` elements from `ptr`, zeroing the remaining
        /// lanes via a zero-masked load.
        ///
        /// # Safety
        /// `len <= 8` (indexes `U64_MASKS`) and `ptr` must be valid for `len` reads.
        #[inline(always)]
        pub unsafe fn partial_load(ptr: *const T, len: usize) -> [T; N] {
            transmute(_mm512_maskz_loadu_pd(*(U64_MASKS.as_ptr().add(len)), ptr))
        }

        /// Horizontal sum of all 8 lanes: fold 512 -> 256 -> 128 bits, then add
        /// the upper 64-bit half (via the `movehl_ps` bit-pattern trick) to the
        /// lower, and extract the scalar.
        #[inline(always)]
        pub unsafe fn reduce_sum(x: [T; N]) -> T {
            let x = transmute(x);
            let x = _mm256_add_pd(_mm512_castpd512_pd256(x), _mm512_extractf64x4_pd::<1>(x));
            let x = _mm_add_pd(_mm256_castpd256_pd128(x), _mm256_extractf128_pd::<1>(x));
            let hi = transmute(_mm_movehl_ps(transmute(x), transmute(x)));
            let r = _mm_add_sd(x, hi);
            _mm_cvtsd_f64(r)
        }

        // Horizontal kernels and their 4x4 dispatch table.
        horizontal_kernel!(["avx512f"], hx1x1, 1, 1);
        horizontal_kernel!(["avx512f"], hx1x2, 1, 2);
        horizontal_kernel!(["avx512f"], hx1x3, 1, 3);
        horizontal_kernel!(["avx512f"], hx1x4, 1, 4);
        horizontal_kernel!(["avx512f"], hx2x1, 2, 1);
        horizontal_kernel!(["avx512f"], hx2x2, 2, 2);
        horizontal_kernel!(["avx512f"], hx2x3, 2, 3);
        horizontal_kernel!(["avx512f"], hx2x4, 2, 4);
        horizontal_kernel!(["avx512f"], hx3x1, 3, 1);
        horizontal_kernel!(["avx512f"], hx3x2, 3, 2);
        horizontal_kernel!(["avx512f"], hx3x3, 3, 3);
        horizontal_kernel!(["avx512f"], hx3x4, 3, 4);
        horizontal_kernel!(["avx512f"], hx4x1, 4, 1);
        horizontal_kernel!(["avx512f"], hx4x2, 4, 2);
        horizontal_kernel!(["avx512f"], hx4x3, 4, 3);
        horizontal_kernel!(["avx512f"], hx4x4, 4, 4);
        hmicrokernel_fn_array! {
            [hx1x1, hx1x2, hx1x3, hx1x4,],
            [hx2x1, hx2x2, hx2x3, hx2x4,],
            [hx3x1, hx3x2, hx3x3, hx3x4,],
            [hx4x1, hx4x2, hx4x3, hx4x4,],
        }
    }
}
302
#[allow(dead_code)] // only some items are used, depending on target arch
/// Definitions shared by the 128-bit backends (`neon`, `simd128`):
/// 2-lane f64 packs and the arch-independent `splat`.
mod v128_common {
    pub mod f64 {
        pub type T = f64;
        // 2 f64 lanes fill one 128-bit vector register.
        pub const N: usize = 2;
        pub type Pack = [T; N];

        /// Broadcasts `value` into both lanes of a pack.
        #[inline(always)]
        pub unsafe fn splat(value: T) -> Pack {
            [value, value]
        }
    }
}
316
#[cfg(target_arch = "aarch64")]
/// NEON backend for `f64` (aarch64): 128-bit packs of 2 lanes.
pub mod neon {
    pub mod f64 {
        use super::super::v128_common::f64::*;
        use core::arch::aarch64::*;
        use core::mem::transmute;

        // Under Miri the NEON FMA intrinsic is not supported, so this pure-Rust
        // equivalent shadows the intrinsic imported via the glob above.
        // Computes c + a * b lane-wise with a single rounding (`f64::mul_add`).
        #[cfg(miri)]
        unsafe fn vfmaq_f64(c: float64x2_t, a: float64x2_t, b: float64x2_t) -> float64x2_t {
            let c: [f64; 2] = transmute(c);
            let a: [f64; 2] = transmute(a);
            let b: [f64; 2] = transmute(b);

            transmute([
                f64::mul_add(a[0], b[0], c[0]),
                f64::mul_add(a[1], b[1], c[1]),
            ])
        }

        /// Lane-wise multiply.
        #[inline(always)]
        pub unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            transmute(vmulq_f64(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise add.
        #[inline(always)]
        pub unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            transmute(vaddq_f64(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise fused multiply-add `a * b + c` (single rounding).
        /// Note the intrinsic takes the addend first.
        #[inline(always)]
        pub unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(vfmaq_f64(transmute(c), transmute(a), transmute(b)))
        }

        /// Fused multiply-add against a single lane of `b`: per the
        /// `vfmaq_laneq_f64` intrinsic this computes `c + a * b[LANE]`
        /// lane-wise — confirm against the ARM intrinsics reference.
        #[inline(always)]
        pub unsafe fn mul_add_lane<const LANE: i32>(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(vfmaq_laneq_f64::<LANE>(
                transmute(c),
                transmute(a),
                transmute(b),
            ))
        }

        /// Scalar multiply for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            lhs * rhs
        }

        /// Scalar add for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            lhs + rhs
        }

        /// Scalar fused multiply-add via the crate's NEON helper.
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            gemm_common::simd::neon_fma(a, b, c)
        }

        // MRxNR microkernels gated on `neon`. Some invocations carry two extra
        // trailing arguments — presumably selecting the lane-FMA
        // (`mul_add_lane`) variant and its grouping; TODO confirm against the
        // macro definition.
        microkernel!(["neon"], 2, x1x1, 1, 1);
        microkernel!(["neon"], 2, x1x2, 1, 2, 1, 2);
        microkernel!(["neon"], 2, x1x3, 1, 3);
        microkernel!(["neon"], 2, x1x4, 1, 4, 2, 2);

        microkernel!(["neon"], 2, x2x1, 2, 1);
        microkernel!(["neon"], 2, x2x2, 2, 2, 1, 2);
        microkernel!(["neon"], 2, x2x3, 2, 3);
        microkernel!(["neon"], 2, x2x4, 2, 4, 2, 2);

        microkernel!(["neon"], 2, x3x1, 3, 1);
        microkernel!(["neon"], 2, x3x2, 3, 2, 1, 2);
        microkernel!(["neon"], 2, x3x3, 3, 3);
        microkernel!(["neon"], 2, x3x4, 3, 4, 2, 2);

        microkernel!(["neon"], 2, x4x1, 4, 1);
        microkernel!(["neon"], 2, x4x2, 4, 2, 1, 2);
        microkernel!(["neon"], 2, x4x3, 4, 3);
        microkernel!(["neon"], 2, x4x4, 4, 4, 2, 2);

        /// Loads `MR_DIV_N` consecutive 2-lane packs from `ptr` into `dst`,
        /// using the multi-register load intrinsics for 2/3/4 packs.
        ///
        /// # Safety
        /// `ptr` must be valid for `MR_DIV_N * 2` reads and `dst` for
        /// `MR_DIV_N` pack writes; `MR_DIV_N` must be in 1..=4.
        #[inline(always)]
        pub unsafe fn load<const MR_DIV_N: usize>(dst: *mut Pack, ptr: *const f64) {
            match MR_DIV_N {
                1 => *(dst as *mut [Pack; 1]) = transmute(vld1q_f64(ptr)),
                2 => *(dst as *mut [Pack; 2]) = transmute(vld1q_f64_x2(ptr)),
                3 => *(dst as *mut [Pack; 3]) = transmute(vld1q_f64_x3(ptr)),
                4 => *(dst as *mut [Pack; 4]) = transmute(vld1q_f64_x4(ptr)),
                _ => unreachable!(),
            }
        }

        // 4x4 dispatch table over the kernels above.
        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4, ],
            [x2x1, x2x2, x2x3, x2x4, ],
            [x3x1, x3x2, x3x3, x3x4, ],
            [x4x1, x4x2, x4x3, x4x4, ],
        }

        // No horizontal microkernels for the NEON backend: empty 0x0 table
        // keeps the dispatch interface uniform across backends.
        pub const H_M: usize = 0;
        pub const H_N: usize = 0;
        pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
    }
}
418
#[cfg(target_arch = "aarch64")]
/// AMX backend for `f64` (Apple-silicon matrix coprocessor, aarch64 only).
pub mod amx {
    pub mod f64 {
        pub type T = f64;
        // 8 f64 lanes per pack for the AMX kernels.
        pub const N: usize = 8;

        // AMX microkernels; args mirror the generic `microkernel!` form with
        // the element type prepended — TODO confirm against `microkernel_amx!`.
        microkernel_amx!(f64, ["neon"], 4, x1x8, 1, 8, 1, 8);
        microkernel_amx!(f64, ["neon"], 4, x1x16, 1, 16, 2, 8);
        microkernel_amx!(f64, ["neon"], 4, x2x8, 2, 8, 1, 8);
        microkernel_amx!(f64, ["neon"], 4, x2x16, 2, 16, 2, 8);

        // Each kernel name is repeated once per column count it serves:
        // presumably entries 1..=8 of a row dispatch to the x?x8 kernel and
        // 9..=16 to the x?x16 kernel — confirm against how the table is indexed.
        microkernel_fn_array! {
            [
                x1x8,x1x8,x1x8,x1x8,x1x8,x1x8,x1x8,x1x8,
                x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,x1x16,
            ],
            [
                x2x8,x2x8,x2x8,x2x8,x2x8,x2x8,x2x8,x2x8,
                x2x16,x2x16,x2x16,x2x16,x2x16,x2x16,x2x16,x2x16,
            ],
        }

        // No horizontal microkernels for the AMX backend: empty 0x0 table
        // keeps the dispatch interface uniform across backends.
        pub const H_M: usize = 0;
        pub const H_N: usize = 0;
        pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
    }
}
446
#[cfg(target_arch = "wasm32")]
/// WebAssembly SIMD backend for `f64`: 128-bit packs of 2 lanes.
pub mod simd128 {
    pub mod f64 {
        use super::super::v128_common::f64::*;
        use core::arch::wasm32::*;
        use core::mem::transmute;

        /// Lane-wise multiply.
        #[inline(always)]
        pub unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            transmute(f64x2_mul(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise add.
        #[inline(always)]
        pub unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            transmute(f64x2_add(transmute(lhs), transmute(rhs)))
        }

        /// Lane-wise multiply-add `a * b + c`. Not fused: wasm simd128 has no
        /// FMA here, so this is a multiply followed by an add (two roundings).
        #[inline(always)]
        pub unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            add(c, mul(a, b))
        }

        /// Scalar multiply for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_mul(lhs: T, rhs: T) -> T {
            lhs * rhs
        }

        /// Scalar add for edge/cleanup code.
        #[inline(always)]
        pub unsafe fn scalar_add(lhs: T, rhs: T) -> T {
            lhs + rhs
        }

        /// Scalar multiply-add (unfused, matching `mul_add` above).
        #[inline(always)]
        pub unsafe fn scalar_mul_add(a: T, b: T, c: T) -> T {
            a * b + c
        }

        // MRxNR microkernels gated on the `simd128` target feature; args appear
        // to be (features, unroll, name, mr_div_n, nr) — TODO confirm.
        microkernel!(["simd128"], 2, x1x1, 1, 1);
        microkernel!(["simd128"], 2, x1x2, 1, 2);
        microkernel!(["simd128"], 2, x1x3, 1, 3);
        microkernel!(["simd128"], 2, x1x4, 1, 4);

        microkernel!(["simd128"], 2, x2x1, 2, 1);
        microkernel!(["simd128"], 2, x2x2, 2, 2);
        microkernel!(["simd128"], 2, x2x3, 2, 3);
        microkernel!(["simd128"], 2, x2x4, 2, 4);

        microkernel!(["simd128"], 2, x3x1, 3, 1);
        microkernel!(["simd128"], 2, x3x2, 3, 2);
        microkernel!(["simd128"], 2, x3x3, 3, 3);
        microkernel!(["simd128"], 2, x3x4, 3, 4);

        // 3x4 dispatch table over the kernels above.
        microkernel_fn_array! {
            [x1x1, x1x2, x1x3, x1x4,],
            [x2x1, x2x2, x2x3, x2x4,],
            [x3x1, x3x2, x3x3, x3x4,],
        }

        // No horizontal microkernels for the wasm backend: empty 0x0 table
        // keeps the dispatch interface uniform across backends.
        pub const H_M: usize = 0;
        pub const H_N: usize = 0;
        pub const H_UKR: [[gemm_common::microkernel::HMicroKernelFn<T>; H_N]; H_M] = [];
    }
}