use dyn_stack::{DynStack, MemBuffer, StackReq};
#[cfg(feature = "std")]
use gemm_common::gemm::L2_SLAB;
#[cfg(feature = "rayon")]
use gemm_common::gemm::{get_threading_threshold, par_for_each};

use gemm_common::{
    cache::{kernel_params, DivCeil, KernelParams},
    gemm::CACHELINE_ALIGN,
    gemv, gevv,
    microkernel::MicroKernelFn,
    pack_operands::quick_zero,
    simd::{MixedSimd, NullaryFnOnce},
    Parallelism, Ptr,
};
type T = half::f16;

#[allow(unused_imports)]
use gemm_common::simd::*;

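// Packs a `src_width x k` panel of `src` (f16, row/column strides `src_rs`/`src_cs`) into `dst`,
// converting to f32 on the fly. Each of the `k` columns is stored as `DST_WIDTH` contiguous f32
// values, zero-padded past `src_width`. When the rows are contiguous (`src_rs == 1`), specialized
// SIMD conversion paths are taken on x86/x86_64.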
#[inline(always)]
unsafe fn pack_generic_inner_loop<
    const N: usize,
    const DST_WIDTH: usize,
    S: MixedSimd<T, T, T, f32>,
>(
    simd: S,
    mut dst: *mut f32,
    mut src: *const T,
    src_rs: isize,
    src_cs: isize,
    src_width: usize,
    k: usize,
) {
    assert_eq!(N, S::SIMD_WIDTH);

    if src_rs == 1 {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            use core::any::TypeId;

            let id = TypeId::of::<S>();
            if id == TypeId::of::<V3>() {
                let half_simd = V3Half::try_new().unwrap();

                if src_width == 4 {
                    for _ in 0..k {
                        *(dst as *mut [f32; 4]) = half_simd.simd_from_dst(*(src as *const [T; 4]));
                        quick_zero::<f32>(core::slice::from_raw_parts_mut(
                            dst.add(src_width) as _,
                            DST_WIDTH - src_width,
                        ));
                        src = src.wrapping_offset(src_cs);
                        dst = dst.add(DST_WIDTH);
                    }
                    return;
                }
            }

            #[cfg(feature = "nightly")]
            if id == TypeId::of::<V4>() {
                let quarter_simd = V3Half::try_new().unwrap();
                let half_simd = V3::try_new().unwrap_unchecked();

                if src_width == 4 {
                    for _ in 0..k {
                        *(dst as *mut [f32; 4]) =
                            <V3Half as MixedSimd<T, T, T, f32>>::simd_from_dst(
                                quarter_simd,
                                *(src as *const [T; 4]),
                            );
                        quick_zero::<f32>(core::slice::from_raw_parts_mut(
                            dst.add(src_width) as _,
                            DST_WIDTH - src_width,
                        ));
                        src = src.wrapping_offset(src_cs);
                        dst = dst.add(DST_WIDTH);
                    }
                    return;
                }

                if src_width == 8 {
                    for _ in 0..k {
                        *(dst as *mut [f32; 8]) = <V3 as MixedSimd<T, T, T, f32>>::simd_from_dst(
                            half_simd,
                            *(src as *const [T; 8]),
                        );
                        quick_zero::<f32>(core::slice::from_raw_parts_mut(
                            dst.add(src_width) as _,
                            DST_WIDTH - src_width,
                        ));
                        src = src.wrapping_offset(src_cs);
                        dst = dst.add(DST_WIDTH);
                    }
                    return;
                }
            }
        }

        if src_width % N == 0 {
            for _ in 0..k {
                for j in 0..src_width / N {
                    let j = j * N;
                    let dst = dst.add(j) as *mut S::AccN;
                    *dst = simd.simd_from_dst(*(src.offset(j as isize * src_rs) as *const S::DstN));
                }
                src = src.wrapping_offset(src_cs);
                dst = dst.add(DST_WIDTH);
            }
            return;
        }
    }

    for _ in 0..k {
        for j in 0..src_width {
            *dst.add(j) = simd.from_lhs(*src.offset(j as isize * src_rs));
        }
        quick_zero::<f32>(core::slice::from_raw_parts_mut(
            dst.add(src_width) as _,
            DST_WIDTH - src_width,
        ));
        src = src.wrapping_offset(src_cs);
        dst = dst.add(DST_WIDTH);
    }
}

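// Packs an `m x k` block by slicing it into `DST_WIDTH`-row panels and packing each panel with
// `pack_generic_inner_loop`; consecutive panels are `dst_stride` f32 elements apart, and the
// remainder panel (if any) is zero-padded up to `DST_WIDTH`.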
#[inline(always)]
unsafe fn pack_generic<const N: usize, const DST_WIDTH: usize, S: MixedSimd<T, T, T, f32>>(
    simd: S,
    m: usize,
    k: usize,
    mut dst: *mut f32,
    mut src: *const T,
    src_cs: isize,
    src_rs: isize,
    dst_stride: usize,
) {
    let m_width = m / DST_WIDTH * DST_WIDTH;

    let mut i = 0;
    while i < m_width {
        pack_generic_inner_loop::<N, DST_WIDTH, _>(simd, dst, src, src_rs, src_cs, DST_WIDTH, k);
        src = src.wrapping_offset(src_rs * DST_WIDTH as isize);
        dst = dst.add(dst_stride);

        i += DST_WIDTH;
    }
    if i < m {
        pack_generic_inner_loop::<N, DST_WIDTH, _>(simd, dst, src, src_rs, src_cs, m - i, k);
    }
}

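// Packs the left-hand side into MR-row f32 panels. The packing loop is wrapped in a
// `NullaryFnOnce` and run through `simd.vectorize` so it is compiled for the selected SIMD
// instruction set.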
#[inline(never)]
pub unsafe fn pack_lhs<const N: usize, const MR: usize, S: MixedSimd<T, T, T, f32>>(
    simd: S,
    m: usize,
    k: usize,
    dst: Ptr<f32>,
    src: Ptr<T>,
    src_cs: isize,
    src_rs: isize,
    dst_stride: usize,
) {
    let dst = dst.0;
    let src = src.0;
    struct Impl<const N: usize, const MR: usize, S> {
        simd: S,
        m: usize,
        k: usize,
        dst: *mut f32,
        src: *mut T,
        src_cs: isize,
        src_rs: isize,
        dst_stride: usize,
    }
    impl<const N: usize, const MR: usize, S: MixedSimd<T, T, T, f32>> NullaryFnOnce for Impl<N, MR, S> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                m,
                k,
                dst,
                src,
                src_cs,
                src_rs,
                dst_stride,
            } = self;
            unsafe { pack_generic::<N, MR, _>(simd, m, k, dst, src, src_cs, src_rs, dst_stride) };
        }
    }

    simd.vectorize(Impl::<N, MR, _> {
        simd,
        m,
        k,
        dst,
        src,
        src_cs,
        src_rs,
        dst_stride,
    });
}

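// Packs the right-hand side into NR-column f32 panels. This is identical to `pack_lhs` except
// that the row and column strides are passed to `pack_generic` in swapped order.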
#[inline(never)]
pub unsafe fn pack_rhs<const N: usize, const NR: usize, S: MixedSimd<T, T, T, f32>>(
    simd: S,
    n: usize,
    k: usize,
    dst: Ptr<f32>,
    src: Ptr<T>,
    src_cs: isize,
    src_rs: isize,
    dst_stride: usize,
) {
    let dst = dst.0;
    let src = src.0;

    struct Impl<const N: usize, const NR: usize, S> {
        simd: S,
        n: usize,
        k: usize,
        dst: *mut f32,
        src: *mut T,
        src_cs: isize,
        src_rs: isize,
        dst_stride: usize,
    }
    impl<const N: usize, const NR: usize, S: MixedSimd<T, T, T, f32>> NullaryFnOnce for Impl<N, NR, S> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                n,
                k,
                dst,
                src,
                src_cs,
                src_rs,
                dst_stride,
            } = self;
            unsafe { pack_generic::<N, NR, _>(simd, n, k, dst, src, src_rs, src_cs, dst_stride) };
        }
    }

    simd.vectorize(Impl::<N, NR, _> {
        simd,
        n,
        k,
        dst,
        src,
        src_cs,
        src_rs,
        dst_stride,
    });
}

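// Mixed-precision matrix multiply: f16 operands are packed into and accumulated in f32 buffers,
// and the result is converted back to f16 when written out, computing
// dst = alpha * dst + beta * (lhs * rhs). Trivial shapes, very small k (outer product), and
// matrix-vector shapes are dispatched to dedicated routines before the blocked kernel runs.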
#[inline(always)]
pub unsafe fn gemm_basic_generic<
    const N: usize,
    const MR: usize,
    const NR: usize,
    const MR_DIV_N: usize,
    S: MixedSimd<T, T, T, f32>,
>(
    simd: S,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_cs: isize,
    dst_rs: isize,
    read_dst: bool,
    lhs: *const T,
    lhs_cs: isize,
    lhs_rs: isize,
    rhs: *const T,
    rhs_cs: isize,
    rhs_rs: isize,
    mut alpha: T,
    beta: T,
    dispatcher: &[[MicroKernelFn<f32>; NR]; MR_DIV_N],
    parallelism: Parallelism,
) {
    if m == 0 || n == 0 {
        return;
    }
    if !read_dst {
        alpha = T::ZERO;
    }

    if k == 0 {
        if alpha == T::ZERO {
            for j in 0..n {
                for i in 0..m {
                    *dst.offset(i as isize * dst_rs + j as isize * dst_cs) = T::ZERO;
                }
            }
            return;
        }
        if alpha == T::ONE {
            return;
        }

        for j in 0..n {
            for i in 0..m {
                let dst = dst.offset(i as isize * dst_rs + j as isize * dst_cs);
                *dst = alpha * *dst;
            }
        }
        return;
    }

    {
        if k <= 2 {
            gevv::gevv(
                simd,
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                |a, b, c| {
                    simd.into_dst(simd.mult_add(
                        simd.from_dst(a),
                        simd.from_dst(b),
                        simd.from_dst(c),
                    ))
                },
            );
            return;
        }

        let alpha = simd.from_dst(alpha);
        let beta = simd.from_dst(beta);
        if n <= 1 && lhs_rs == 1 && dst_rs == 1 {
            gemv::mixed_gemv_colmajor(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta,
            );
            return;
        }
        if n <= 1 && lhs_cs == 1 && rhs_rs == 1 {
            gemv::mixed_gemv_rowmajor(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta,
            );
            return;
        }

        if m <= 1 && rhs_cs == 1 && dst_cs == 1 {
            gemv::mixed_gemv_colmajor(
                simd, n, m, k, dst, dst_rs, dst_cs, rhs, rhs_rs, rhs_cs, lhs, lhs_rs, lhs_cs,
                alpha, beta,
            );
            return;
        }
        if m <= 1 && rhs_rs == 1 && lhs_cs == 1 {
            gemv::mixed_gemv_rowmajor(
                simd, n, m, k, dst, dst_rs, dst_cs, rhs, rhs_rs, rhs_cs, lhs, lhs_rs, lhs_cs,
                alpha, beta,
            );
            return;
        }
    }

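    // Choose the cache blocking sizes for the packed f32 operands: kc x MR lhs panels and
    // kc x NR rhs panels, grouped into mc x kc and kc x nc blocks. When `kernel_params` reports
    // no column blocking (nc == 0), fall back to a fixed multiple of NR, or to the full width
    // rounded up to NR when running under rayon.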
    let KernelParams { kc, mc, nc } = kernel_params(m, n, k, MR, NR, core::mem::size_of::<f32>());
    let nc = if nc > 0 {
        nc
    } else {
        match parallelism {
            Parallelism::None => 128 * NR,
            #[cfg(feature = "rayon")]
            Parallelism::Rayon(_) => n.msrv_next_multiple_of(NR),
        }
    };

    let simd_align = CACHELINE_ALIGN;

    let packed_rhs_stride = kc * NR;
    let packed_lhs_stride = kc * MR;

    let dst = Ptr(dst);
    let lhs = Ptr(lhs as *mut T);
    let rhs = Ptr(rhs as *mut T);

    let do_prepack_lhs = m <= 2 * mc && ((m % N != 0) || lhs_rs != 1);

    let rhs_req = StackReq::new_aligned::<f32>(packed_rhs_stride * (nc / NR), simd_align);
    let lhs_req = StackReq::new_aligned::<f32>(
        if do_prepack_lhs {
            packed_lhs_stride * (m.msrv_next_multiple_of(MR) / MR)
        } else {
            0
        },
        simd_align,
    );

    let mut mem = MemBuffer::new(rhs_req.and(lhs_req));
    #[cfg(not(feature = "std"))]
    let mut l2_slab = MemBuffer::new(StackReq::new_aligned::<f32>(
        packed_lhs_stride * (mc / MR),
        simd_align,
    ));

    let stack = DynStack::new(&mut mem);
    let (packed_rhs_storage, stack) =
        stack.make_aligned_uninit::<f32>(packed_rhs_stride * (nc / NR), simd_align);

    let packed_lhs_storage = stack
        .make_aligned_uninit::<f32>(
            if do_prepack_lhs {
                packed_lhs_stride * (m.msrv_next_multiple_of(MR) / MR)
            } else {
                0
            },
            simd_align,
        )
        .0;

    let packed_rhs = Ptr(packed_rhs_storage.as_mut_ptr() as *mut f32);
    let prepacked_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut f32);

    let packed_rhs_rs = NR as isize;
    let packed_rhs_cs = 1;

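    // Blocked main loop: iterate over nc-wide column chunks and kc-deep depth chunks, pack the
    // rhs block (and optionally prepack the lhs), then split the grid of MR x NR micro-tiles
    // into jobs that are distributed over the worker threads.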
    let mut col_outer = 0;
    while col_outer != n {
        let n_chunk = nc.min(n - col_outer);

        let mut alpha = simd.from_lhs(alpha);

        let mut depth_outer = 0;
        while depth_outer != k {
            let k_chunk = kc.min(k - depth_outer);
            let alpha_status = if alpha == 0.0 {
                0
            } else if alpha == 1.0 {
                1
            } else {
                2
            };

            let n_threads = match parallelism {
                Parallelism::None => 1,
                #[cfg(feature = "rayon")]
                Parallelism::Rayon(n_threads) => {
                    let threading_threshold = get_threading_threshold();
                    let total_work = (m * n_chunk).saturating_mul(k_chunk);
                    if total_work < threading_threshold {
                        1
                    } else {
                        if n_threads == 0 {
                            rayon::current_num_threads()
                        } else {
                            n_threads
                        }
                    }
                }
            };

            // pack the current rhs block, splitting it into NR-column slices across threads
            // when more than one thread is used
            if n_threads <= 1 {
                pack_rhs::<N, NR, _>(
                    simd,
                    n_chunk,
                    k_chunk,
                    packed_rhs,
                    rhs.wrapping_offset(
                        depth_outer as isize * rhs_rs + col_outer as isize * rhs_cs,
                    ),
                    rhs_cs,
                    rhs_rs,
                    packed_rhs_stride,
                );
            } else {
                #[cfg(feature = "rayon")]
                {
                    let n_tasks = n_chunk.msrv_div_ceil(NR);
                    let base = n_tasks / n_threads;
                    let rem = n_tasks % n_threads;

                    let tid_to_col_inner = |tid: usize| {
                        if tid == n_threads {
                            return n_chunk;
                        }

                        let col = if tid < rem {
                            NR * tid * (base + 1)
                        } else {
                            NR * (rem + tid * base)
                        };

                        col.min(n_chunk)
                    };

                    let func = |tid: usize| {
                        let col_inner = tid_to_col_inner(tid);
                        let ncols = tid_to_col_inner(tid + 1) - col_inner;
                        let j = col_inner / NR;

                        if ncols > 0 {
                            pack_rhs::<N, NR, _>(
                                simd,
                                ncols,
                                k_chunk,
                                packed_rhs.wrapping_add(j * packed_rhs_stride),
                                rhs.wrapping_offset(
                                    depth_outer as isize * rhs_rs
                                        + (col_outer + col_inner) as isize * rhs_cs,
                                ),
                                rhs_cs,
                                rhs_rs,
                                packed_rhs_stride,
                            );
                        }
                    };
                    par_for_each(n_threads, func);
                }

                #[cfg(not(feature = "rayon"))]
                unreachable!();
            }
            if do_prepack_lhs {
                pack_lhs::<N, MR, _>(
                    simd,
                    m,
                    k_chunk,
                    prepacked_lhs,
                    lhs.wrapping_offset(depth_outer as isize * lhs_cs),
                    lhs_cs,
                    lhs_rs,
                    packed_lhs_stride,
                );
            }

            let n_col_mini_chunks = (n_chunk + (NR - 1)) / NR;

            let mut n_jobs = 0;
            let mut row_outer = 0;
            while row_outer != m {
                let mut m_chunk = mc.min(m - row_outer);
                if m_chunk > N && !do_prepack_lhs {
                    m_chunk = m_chunk / N * N;
                }
                let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR;
                n_jobs += n_col_mini_chunks * n_row_mini_chunks;
                row_outer += m_chunk;
            }

            let func = move |tid, packed_lhs: Ptr<f32>| {
                let min_jobs_per_thread = n_jobs / n_threads;
                let rem = n_jobs - n_threads * min_jobs_per_thread;

                let (job_start, job_end) = if tid < rem {
                    let start = tid * (min_jobs_per_thread + 1);
                    (start, start + min_jobs_per_thread + 1)
                } else {
                    let start = tid * min_jobs_per_thread + rem;
                    (start, start + min_jobs_per_thread)
                };

                let mut row_outer = 0;
                let mut job_id = 0;
                while row_outer != m {
                    let mut m_chunk = mc.min(m - row_outer);
                    if m_chunk > N && !do_prepack_lhs {
                        m_chunk = m_chunk / N * N;
                    }
                    let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR;

                    let n_mini_jobs = n_col_mini_chunks * n_row_mini_chunks;

                    if job_id >= job_end {
                        return;
                    }
                    if job_id + n_mini_jobs < job_start {
                        row_outer += m_chunk;
                        job_id += n_mini_jobs;
                        continue;
                    }

                    let packed_lhs_cs = MR as isize;

                    if !do_prepack_lhs {
                        pack_lhs::<N, MR, _>(
                            simd,
                            m_chunk,
                            k_chunk,
                            packed_lhs,
                            lhs.wrapping_offset(
                                row_outer as isize * lhs_rs + depth_outer as isize * lhs_cs,
                            ),
                            lhs_cs,
                            lhs_rs,
                            packed_lhs_stride,
                        );
                    }

                    let mut j = 0;
                    while j < n_col_mini_chunks {
                        let mut i = 0;
                        while i < n_row_mini_chunks {
                            let col_inner = NR * j;
                            let n_chunk_inner = NR.min(n_chunk - col_inner);

                            let row_inner = MR * i;
                            let m_chunk_inner = MR.min(m_chunk - row_inner);

                            let inner_idx = &mut i;
                            if job_id < job_start || job_id >= job_end {
                                job_id += 1;
                                *inner_idx += 1;
                                continue;
                            }
                            job_id += 1;

                            let dst = dst.wrapping_offset(
                                (row_outer + row_inner) as isize * dst_rs
                                    + (col_outer + col_inner) as isize * dst_cs,
                            );

                            let func =
                                dispatcher[(m_chunk_inner + (N - 1)) / N - 1][n_chunk_inner - 1];

                            let mut tmp = [[0.0f32; MR]; NR];

                            func(
                                m_chunk_inner,
                                n_chunk_inner,
                                k_chunk,
                                tmp.as_mut_ptr() as *mut f32,
                                if do_prepack_lhs {
                                    packed_lhs
                                        .wrapping_add((i + row_outer / MR) * packed_lhs_stride)
                                        .0
                                } else {
                                    packed_lhs.wrapping_add(i * packed_lhs_stride).0
                                },
                                packed_rhs.wrapping_add(j * packed_rhs_stride).0,
                                MR as isize,
                                1,
                                packed_lhs_cs,
                                packed_rhs_rs,
                                packed_rhs_cs,
                                0.0,
                                beta.into(),
                                0,
                                false,
                                false,
                                false,
                                packed_lhs.wrapping_add((i + 1) * packed_lhs_stride).0,
                            );

                            match alpha_status {
                                0 => {
                                    for j in 0..n_chunk_inner {
                                        for i in 0..m_chunk_inner {
                                            let dst = dst
                                                .wrapping_offset(j as isize * dst_cs)
                                                .wrapping_offset(i as isize * dst_rs)
                                                .0;
                                            *dst = simd.into_dst(tmp[j][i]);
                                        }
                                    }
                                }
                                1 => {
                                    for j in 0..n_chunk_inner {
                                        for i in 0..m_chunk_inner {
                                            let dst = dst
                                                .wrapping_offset(j as isize * dst_cs)
                                                .wrapping_offset(i as isize * dst_rs)
                                                .0;
                                            *dst = simd.into_dst(simd.from_dst(*dst) + tmp[j][i]);
                                        }
                                    }
                                }
                                _ => {
                                    for j in 0..n_chunk_inner {
                                        for i in 0..m_chunk_inner {
                                            let dst = dst
                                                .wrapping_offset(j as isize * dst_cs)
                                                .wrapping_offset(i as isize * dst_rs)
                                                .0;
                                            *dst = simd
                                                .into_dst(alpha * simd.from_dst(*dst) + tmp[j][i]);
                                        }
                                    }
                                }
                            }

                            i += 1;
                        }
                        j += 1;
                    }

                    row_outer += m_chunk;
                }
            };

            if do_prepack_lhs {
                match parallelism {
                    Parallelism::None => func(0, prepacked_lhs),
                    #[cfg(feature = "rayon")]
                    Parallelism::Rayon(_) => {
                        if n_threads == 1 {
                            func(0, prepacked_lhs);
                        } else {
                            par_for_each(n_threads, |tid| func(tid, prepacked_lhs));
                        }
                    }
                }
            } else {
                #[cfg(feature = "std")]
                let func = |tid: usize| {
                    L2_SLAB.with(|mem| {
                        let mut mem = mem.borrow_mut();
                        let stack = DynStack::new(&mut mem);
                        let (packed_lhs_storage, _) = stack
                            .make_aligned_uninit::<f32>(packed_lhs_stride * (mc / MR), simd_align);
                        let packed_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut f32);
                        func(tid, packed_lhs);
                    });
                };

                #[cfg(not(feature = "std"))]
                let mut func = |tid: usize| {
                    let stack = DynStack::new(&mut l2_slab);
                    let (mut packed_lhs_storage, _) =
                        stack.make_aligned_uninit::<f32>(packed_lhs_stride * (mc / MR), simd_align);
                    let packed_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut f32);
                    func(tid, packed_lhs);
                };

                match parallelism {
                    Parallelism::None => func(0),
                    #[cfg(feature = "rayon")]
                    Parallelism::Rayon(_) => {
                        if n_threads == 1 {
                            func(0);
                        } else {
                            par_for_each(n_threads, func);
                        }
                    }
                }
            }

            alpha = 1.0;
            depth_outer += k_chunk;
        }
        col_outer += n_chunk;
    }
}

pub mod f16 {
    use super::gemm_basic_generic;
    use gemm_common::Parallelism;

    type T = half::f16;
    type GemmTy = unsafe fn(
        usize,
        usize,
        usize,
        *mut T,
        isize,
        isize,
        bool,
        *const T,
        isize,
        isize,
        *const T,
        isize,
        isize,
        T,
        T,
        bool,
        bool,
        bool,
        Parallelism,
    );

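    // Selects the implementation at runtime based on the detected CPU features; the chosen
    // function pointer is cached in `GEMM_PTR` by `init_gemm_ptr` so detection runs only once.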
    fn init_gemm_fn() -> GemmTy {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "nightly")]
            if gemm_common::feature_detected!("avx512f") {
                return avx512f::gemm_basic;
            }
            if gemm_common::feature_detected!("fma") {
                fma::gemm_basic
            } else {
                scalar::gemm_basic
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            if gemm_common::feature_detected!("neon") {
                #[cfg(feature = "experimental-apple-amx")]
                if gemm_common::cache::HasAmx::get() {
                    return amx::gemm_basic;
                }
                if gemm_common::feature_detected!("fp16") {
                    neonfp16::gemm_basic
                } else {
                    neon::gemm_basic
                }
            } else {
                scalar::gemm_basic
            }
        }

        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
        {
            scalar::gemm_basic
        }
    }

    static GEMM_PTR: ::core::sync::atomic::AtomicPtr<()> =
        ::core::sync::atomic::AtomicPtr::new(::core::ptr::null_mut());

    #[inline(never)]
    fn init_gemm_ptr() -> GemmTy {
        let gemm_fn = init_gemm_fn();
        GEMM_PTR.store(gemm_fn as *mut (), ::core::sync::atomic::Ordering::Relaxed);
        gemm_fn
    }

    #[inline(always)]
    pub fn get_gemm_fn() -> GemmTy {
        let mut gemm_fn = GEMM_PTR.load(::core::sync::atomic::Ordering::Relaxed);
        if gemm_fn.is_null() {
            gemm_fn = init_gemm_ptr() as *mut ();
        }
        unsafe { ::core::mem::transmute(gemm_fn) }
    }

    mod scalar {
        use super::*;
        use gemm_common::simd::Scalar;
        use gemm_f32::microkernel::scalar::f32::*;
        const N: usize = 1;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            gemm_basic_generic::<N, { MR_DIV_N * N }, NR, MR_DIV_N, _>(
                Scalar,
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                &UKR,
                parallelism,
            );
        }
    }

    #[cfg(target_arch = "aarch64")]
    mod neon {
        use super::*;
        use gemm_common::simd::MixedSimd;
        use gemm_f32::microkernel::neon::f32::*;
        const N: usize = 4;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            gemm_basic_generic::<N, { MR_DIV_N * N }, NR, MR_DIV_N, _>(
                gemm_common::simd::Neon::try_new().unwrap(),
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                &UKR,
                parallelism,
            );
        }
    }

    #[cfg(target_arch = "aarch64")]
    mod neonfp16 {
        use crate::microkernel::neonfp16::f16::*;
        use gemm_common::simd::{MixedSimd, NeonFp16};
        type T = half::f16;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            let simd = <NeonFp16 as MixedSimd<T, T, T, T>>::try_new().unwrap();

            gemm_common::gemm::gemm_basic_generic::<_, _, N, { MR_DIV_N * N }, NR, MR_DIV_N, 0, 0>(
                simd,
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                false,
                false,
                false,
                move |a, b, c| <NeonFp16 as MixedSimd<T, T, T, T>>::mult_add(simd, a, b, c),
                &UKR,
                &[],
                false,
                parallelism,
            );
        }
    }

    #[cfg(target_arch = "aarch64")]
    #[cfg(feature = "experimental-apple-amx")]
    mod amx {
        use crate::microkernel::amx::f16::*;
        use gemm_common::simd::{MixedSimd, NeonFp16};
        type T = half::f16;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            let simd = <NeonFp16 as MixedSimd<T, T, T, T>>::try_new().unwrap();

            gemm_common::gemm::gemm_basic_generic::<_, _, N, { MR_DIV_N * N }, NR, MR_DIV_N, 0, 0>(
                simd,
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                false,
                false,
                false,
                move |a, b, c| <NeonFp16 as MixedSimd<T, T, T, T>>::mult_add(simd, a, b, c),
                &UKR,
                &[],
                true,
                parallelism,
            );
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    mod fma {
        use super::*;
        use gemm_common::simd::V3;
        use gemm_f32::microkernel::fma::f32::*;
        const N: usize = 8;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            gemm_basic_generic::<N, { MR_DIV_N * N }, NR, MR_DIV_N, _>(
                V3::try_new().unwrap(),
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                &UKR,
                parallelism,
            );
        }
    }

    #[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
    mod avx512f {
        use super::*;
        use gemm_common::simd::V4;
        use gemm_f32::microkernel::avx512f::f32::*;
        const N: usize = 16;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            gemm_basic_generic::<N, { MR_DIV_N * N }, NR, MR_DIV_N, _>(
                V4::try_new().unwrap(),
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                &UKR,
                parallelism,
            );
        }
    }
}