gemm_f16/gemm.rs

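// Half-precision (f16) matrix multiplication. Operands are converted to f32 during
// packing, accumulated in f32 by the microkernels, and converted back to f16 when
// written to the destination.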
use dyn_stack::{DynStack, MemBuffer, StackReq};
#[cfg(feature = "std")]
use gemm_common::gemm::L2_SLAB;
#[cfg(feature = "rayon")]
use gemm_common::gemm::{get_threading_threshold, par_for_each};

use gemm_common::{
    cache::{kernel_params, DivCeil, KernelParams},
    gemm::CACHELINE_ALIGN,
    gemv, gevv,
    microkernel::MicroKernelFn,
    pack_operands::quick_zero,
    simd::{MixedSimd, NullaryFnOnce},
    Parallelism, Ptr,
};
type T = half::f16;

#[allow(unused_imports)]
use gemm_common::simd::*;

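// Packs one `DST_WIDTH`-wide strip of the source operand: for each of the `k` depth
// iterations, `src_width` f16 elements are converted to f32 and the remaining
// `DST_WIDTH - src_width` entries are zeroed. On x86, 4- and 8-wide strips take a
// vectorized fast path when the source has unit row stride.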
#[inline(always)]
unsafe fn pack_generic_inner_loop<
    const N: usize,
    const DST_WIDTH: usize,
    S: MixedSimd<T, T, T, f32>,
>(
    simd: S,
    mut dst: *mut f32,
    mut src: *const T,
    src_rs: isize,
    src_cs: isize,
    src_width: usize,
    k: usize,
) {
    assert_eq!(N, S::SIMD_WIDTH);

    if src_rs == 1 {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            use core::any::TypeId;

            let id = TypeId::of::<S>();
            if id == TypeId::of::<V3>() {
                let half_simd = V3Half::try_new().unwrap();

                if src_width == 4 {
                    for _ in 0..k {
                        *(dst as *mut [f32; 4]) = half_simd.simd_from_dst(*(src as *const [T; 4]));
                        quick_zero::<f32>(core::slice::from_raw_parts_mut(
                            dst.add(src_width) as _,
                            DST_WIDTH - src_width,
                        ));
                        src = src.wrapping_offset(src_cs);
                        dst = dst.add(DST_WIDTH);
                    }
                    return;
                }
            }

            #[cfg(feature = "nightly")]
            if id == TypeId::of::<V4>() {
                let quarter_simd = V3Half::try_new().unwrap();
                let half_simd = V3::try_new().unwrap_unchecked();

                if src_width == 4 {
                    for _ in 0..k {
                        *(dst as *mut [f32; 4]) =
                            <V3Half as MixedSimd<T, T, T, f32>>::simd_from_dst(
                                quarter_simd,
                                *(src as *const [T; 4]),
                            );
                        quick_zero::<f32>(core::slice::from_raw_parts_mut(
                            dst.add(src_width) as _,
                            DST_WIDTH - src_width,
                        ));
                        src = src.wrapping_offset(src_cs);
                        dst = dst.add(DST_WIDTH);
                    }
                    return;
                }

                if src_width == 8 {
                    for _ in 0..k {
                        *(dst as *mut [f32; 8]) = <V3 as MixedSimd<T, T, T, f32>>::simd_from_dst(
                            half_simd,
                            *(src as *const [T; 8]),
                        );
                        quick_zero::<f32>(core::slice::from_raw_parts_mut(
                            dst.add(src_width) as _,
                            DST_WIDTH - src_width,
                        ));
                        src = src.wrapping_offset(src_cs);
                        dst = dst.add(DST_WIDTH);
                    }
                    return;
                }
            }
        }

        if src_width % N == 0 {
            for _ in 0..k {
                for j in 0..src_width / N {
                    let j = j * N;
                    let dst = dst.add(j) as *mut S::AccN;
                    *dst = simd.simd_from_dst(*(src.offset(j as isize * src_rs) as *const S::DstN));
                }
                src = src.wrapping_offset(src_cs);
                dst = dst.add(DST_WIDTH);
            }
            return;
        }
    }

    for _ in 0..k {
        for j in 0..src_width {
            *dst.add(j) = simd.from_lhs(*src.offset(j as isize * src_rs));
        }
        quick_zero::<f32>(core::slice::from_raw_parts_mut(
            dst.add(src_width) as _,
            DST_WIDTH - src_width,
        ));
        src = src.wrapping_offset(src_cs);
        dst = dst.add(DST_WIDTH);
    }
}

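// Packs an `m x k` block into `DST_WIDTH`-wide panels of stride `dst_stride`,
// handling the final partial panel when `m` is not a multiple of `DST_WIDTH`.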
#[inline(always)]
unsafe fn pack_generic<const N: usize, const DST_WIDTH: usize, S: MixedSimd<T, T, T, f32>>(
    simd: S,
    m: usize,
    k: usize,
    mut dst: *mut f32,
    mut src: *const T,
    src_cs: isize,
    src_rs: isize,
    dst_stride: usize,
) {
    let m_width = m / DST_WIDTH * DST_WIDTH;

    let mut i = 0;
    while i < m_width {
        pack_generic_inner_loop::<N, DST_WIDTH, _>(simd, dst, src, src_rs, src_cs, DST_WIDTH, k);
        src = src.wrapping_offset(src_rs * DST_WIDTH as isize);
        dst = dst.add(dst_stride);

        i += DST_WIDTH;
    }
    if i < m {
        pack_generic_inner_loop::<N, DST_WIDTH, _>(simd, dst, src, src_rs, src_cs, m - i, k);
    }
}

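// Packs an LHS block into `MR`-wide row panels. The work is wrapped in a
// `NullaryFnOnce` and run through `simd.vectorize` so that the packing loop can be
// compiled with the target features of the selected SIMD type.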
#[inline(never)]
pub unsafe fn pack_lhs<const N: usize, const MR: usize, S: MixedSimd<T, T, T, f32>>(
    simd: S,
    m: usize,
    k: usize,
    dst: Ptr<f32>,
    src: Ptr<T>,
    src_cs: isize,
    src_rs: isize,
    dst_stride: usize,
) {
    let dst = dst.0;
    let src = src.0;
    struct Impl<const N: usize, const MR: usize, S> {
        simd: S,
        m: usize,
        k: usize,
        dst: *mut f32,
        src: *mut T,
        src_cs: isize,
        src_rs: isize,
        dst_stride: usize,
    }
    impl<const N: usize, const MR: usize, S: MixedSimd<T, T, T, f32>> NullaryFnOnce for Impl<N, MR, S> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                m,
                k,
                dst,
                src,
                src_cs,
                src_rs,
                dst_stride,
            } = self;
            unsafe { pack_generic::<N, MR, _>(simd, m, k, dst, src, src_cs, src_rs, dst_stride) };
        }
    }

    simd.vectorize(Impl::<N, MR, _> {
        simd,
        m,
        k,
        dst,
        src,
        src_cs,
        src_rs,
        dst_stride,
    });
}

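// Packs an RHS block into `NR`-wide column panels. Note that the row and column
// strides are swapped when forwarding to `pack_generic`: the RHS panel's "width"
// dimension runs along columns and its depth along rows.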
#[inline(never)]
pub unsafe fn pack_rhs<const N: usize, const NR: usize, S: MixedSimd<T, T, T, f32>>(
    simd: S,
    n: usize,
    k: usize,
    dst: Ptr<f32>,
    src: Ptr<T>,
    src_cs: isize,
    src_rs: isize,
    dst_stride: usize,
) {
    let dst = dst.0;
    let src = src.0;

    struct Impl<const N: usize, const NR: usize, S> {
        simd: S,
        n: usize,
        k: usize,
        dst: *mut f32,
        src: *mut T,
        src_cs: isize,
        src_rs: isize,
        dst_stride: usize,
    }
    impl<const N: usize, const NR: usize, S: MixedSimd<T, T, T, f32>> NullaryFnOnce for Impl<N, NR, S> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                n,
                k,
                dst,
                src,
                src_cs,
                src_rs,
                dst_stride,
            } = self;
            unsafe { pack_generic::<N, NR, _>(simd, n, k, dst, src, src_rs, src_cs, dst_stride) };
        }
    }

    simd.vectorize(Impl::<N, NR, _> {
        simd,
        n,
        k,
        dst,
        src,
        src_cs,
        src_rs,
        dst_stride,
    });
}

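// Blocked mixed-precision GEMM driver: dst = alpha * dst + beta * lhs * rhs, with f16
// operands and f32 accumulation. Small or degenerate shapes are forwarded to the
// gevv/gemv kernels; everything else goes through the packing + microkernel pipeline.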
#[inline(always)]
pub unsafe fn gemm_basic_generic<
    const N: usize,
    const MR: usize,
    const NR: usize,
    const MR_DIV_N: usize,
    S: MixedSimd<T, T, T, f32>,
>(
    simd: S,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_cs: isize,
    dst_rs: isize,
    read_dst: bool,
    lhs: *const T,
    lhs_cs: isize,
    lhs_rs: isize,
    rhs: *const T,
    rhs_cs: isize,
    rhs_rs: isize,
    mut alpha: T,
    beta: T,
    dispatcher: &[[MicroKernelFn<f32>; NR]; MR_DIV_N],
    parallelism: Parallelism,
) {
    if m == 0 || n == 0 {
        return;
    }
    if !read_dst {
        alpha = T::ZERO;
    }

    if k == 0 {
        if alpha == T::ZERO {
            for j in 0..n {
                for i in 0..m {
                    *dst.offset(i as isize * dst_rs + j as isize * dst_cs) = T::ZERO;
                }
            }
            return;
        }
        if alpha == T::ONE {
            return;
        }

        for j in 0..n {
            for i in 0..m {
                let dst = dst.offset(i as isize * dst_rs + j as isize * dst_cs);
                *dst = alpha * *dst;
            }
        }
        return;
    }

    {
        if k <= 2 {
            gevv::gevv(
                simd,
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                |a, b, c| {
                    simd.into_dst(simd.mult_add(
                        simd.from_dst(a),
                        simd.from_dst(b),
                        simd.from_dst(c),
                    ))
                },
            );
            return;
        }

        let alpha = simd.from_dst(alpha);
        let beta = simd.from_dst(beta);
        if n <= 1 && lhs_rs == 1 && dst_rs == 1 {
            gemv::mixed_gemv_colmajor(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta,
            );
            return;
        }
        if n <= 1 && lhs_cs == 1 && rhs_rs == 1 {
            gemv::mixed_gemv_rowmajor(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta,
            );
            return;
        }

        if m <= 1 && rhs_cs == 1 && dst_cs == 1 {
            gemv::mixed_gemv_colmajor(
                simd, n, m, k, dst, dst_rs, dst_cs, rhs, rhs_rs, rhs_cs, lhs, lhs_rs, lhs_cs,
                alpha, beta,
            );
            return;
        }
        if m <= 1 && rhs_rs == 1 && lhs_cs == 1 {
            gemv::mixed_gemv_rowmajor(
                simd, n, m, k, dst, dst_rs, dst_cs, rhs, rhs_rs, rhs_cs, lhs, lhs_rs, lhs_cs,
                alpha, beta,
            );
            return;
        }
    }

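    // Cache blocking: the k dimension is split into chunks of `kc`, m into `mc`, and n
    // into `nc` (with a fallback `nc` when `kernel_params` leaves it unspecified).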
    let KernelParams { kc, mc, nc } = kernel_params(m, n, k, MR, NR, core::mem::size_of::<f32>());
    let nc = if nc > 0 {
        nc
    } else {
        match parallelism {
            Parallelism::None => 128 * NR,
            #[cfg(feature = "rayon")]
            Parallelism::Rayon(_) => n.msrv_next_multiple_of(NR),
        }
    };

    let simd_align = CACHELINE_ALIGN;

    let packed_rhs_stride = kc * NR;
    let packed_lhs_stride = kc * MR;

    let dst = Ptr(dst);
    let lhs = Ptr(lhs as *mut T);
    let rhs = Ptr(rhs as *mut T);

    let do_prepack_lhs = m <= 2 * mc && ((m % N != 0) || lhs_rs != 1);

    let rhs_req = StackReq::new_aligned::<f32>(packed_rhs_stride * (nc / NR), simd_align);
    let lhs_req = StackReq::new_aligned::<f32>(
        if do_prepack_lhs {
            packed_lhs_stride * (m.msrv_next_multiple_of(MR) / MR)
        } else {
            0
        },
        simd_align,
    );

    let mut mem = MemBuffer::new(rhs_req.and(lhs_req));
    #[cfg(not(feature = "std"))]
    let mut l2_slab = MemBuffer::new(StackReq::new_aligned::<f32>(
        packed_lhs_stride * (mc / MR),
        simd_align,
    ));

    let stack = DynStack::new(&mut mem);
    let (packed_rhs_storage, stack) =
        stack.make_aligned_uninit::<f32>(packed_rhs_stride * (nc / NR), simd_align);

    let packed_lhs_storage = stack
        .make_aligned_uninit::<f32>(
            if do_prepack_lhs {
                packed_lhs_stride * (m.msrv_next_multiple_of(MR) / MR)
            } else {
                0
            },
            simd_align,
        )
        .0;

    let packed_rhs = Ptr(packed_rhs_storage.as_mut_ptr() as *mut f32);
    let prepacked_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut f32);

    let packed_rhs_rs = NR as isize;
    let packed_rhs_cs = 1;

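    // Loop nest: for each `nc`-wide column chunk and each `kc`-deep panel, the RHS panel
    // is packed once (possibly in parallel), then `mc`-tall row blocks of the LHS are
    // packed and the MR x NR micro-tiles are dispatched to the f32 microkernels.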
    let mut col_outer = 0;
    while col_outer != n {
        let n_chunk = nc.min(n - col_outer);

        let mut alpha = simd.from_lhs(alpha);

        let mut depth_outer = 0;
        while depth_outer != k {
            let k_chunk = kc.min(k - depth_outer);
            let alpha_status = if alpha == 0.0 {
                0
            } else if alpha == 1.0 {
                1
            } else {
                2
            };

            let n_threads = match parallelism {
                Parallelism::None => 1,
                #[cfg(feature = "rayon")]
                Parallelism::Rayon(n_threads) => {
                    let threading_threshold = get_threading_threshold();
                    let total_work = (m * n_chunk).saturating_mul(k_chunk);
                    if total_work < threading_threshold {
                        1
                    } else {
                        if n_threads == 0 {
                            rayon::current_num_threads()
                        } else {
                            n_threads
                        }
                    }
                }
            };

            // pack rhs
            if n_threads <= 1 {
                pack_rhs::<N, NR, _>(
                    simd,
                    n_chunk,
                    k_chunk,
                    packed_rhs,
                    rhs.wrapping_offset(
                        depth_outer as isize * rhs_rs + col_outer as isize * rhs_cs,
                    ),
                    rhs_cs,
                    rhs_rs,
                    packed_rhs_stride,
                );
            } else {
                #[cfg(feature = "rayon")]
                {
                    let n_tasks = n_chunk.msrv_div_ceil(NR);
                    let base = n_tasks / n_threads;
                    let rem = n_tasks % n_threads;

                    let tid_to_col_inner = |tid: usize| {
                        if tid == n_threads {
                            return n_chunk;
                        }

                        let col = if tid < rem {
                            NR * tid * (base + 1)
                        } else {
                            NR * (rem + tid * base)
                        };

                        col.min(n_chunk)
                    };

                    let func = |tid: usize| {
                        let col_inner = tid_to_col_inner(tid);
                        let ncols = tid_to_col_inner(tid + 1) - col_inner;
                        let j = col_inner / NR;

                        if ncols > 0 {
                            pack_rhs::<N, NR, _>(
                                simd,
                                ncols,
                                k_chunk,
                                packed_rhs.wrapping_add(j * packed_rhs_stride),
                                rhs.wrapping_offset(
                                    depth_outer as isize * rhs_rs
                                        + (col_outer + col_inner) as isize * rhs_cs,
                                ),
                                rhs_cs,
                                rhs_rs,
                                packed_rhs_stride,
                            );
                        }
                    };
                    par_for_each(n_threads, func);
                }

                #[cfg(not(feature = "rayon"))]
                unreachable!();
            }
            if do_prepack_lhs {
                pack_lhs::<N, MR, _>(
                    simd,
                    m,
                    k_chunk,
                    prepacked_lhs,
                    lhs.wrapping_offset(depth_outer as isize * lhs_cs),
                    lhs_cs,
                    lhs_rs,
                    packed_lhs_stride,
                );
            }

            let n_col_mini_chunks = (n_chunk + (NR - 1)) / NR;

            let mut n_jobs = 0;
            let mut row_outer = 0;
            while row_outer != m {
                let mut m_chunk = mc.min(m - row_outer);
                if m_chunk > N && !do_prepack_lhs {
                    m_chunk = m_chunk / N * N;
                }
                let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR;
                n_jobs += n_col_mini_chunks * n_row_mini_chunks;
                row_outer += m_chunk;
            }

            // use a single thread for small workloads

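            // Each MR x NR micro-tile is one job; jobs are assigned to threads in
            // contiguous ranges of size `n_jobs / n_threads` (plus one for the first
            // `n_jobs % n_threads` threads).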
            let func = move |tid, packed_lhs: Ptr<f32>| {
                let min_jobs_per_thread = n_jobs / n_threads;
                let rem = n_jobs - n_threads * min_jobs_per_thread;

                // thread `tid` takes min_jobs_per_thread or min_jobs_per_thread + 1
                let (job_start, job_end) = if tid < rem {
                    let start = tid * (min_jobs_per_thread + 1);
                    (start, start + min_jobs_per_thread + 1)
                } else {
                    // start = rem * (min_jobs_per_thread + 1) + (tid - rem) * min_jobs_per_thread;
                    let start = tid * min_jobs_per_thread + rem;
                    (start, start + min_jobs_per_thread)
                };

                let mut row_outer = 0;
                let mut job_id = 0;
                while row_outer != m {
                    let mut m_chunk = mc.min(m - row_outer);
                    if m_chunk > N && !do_prepack_lhs {
                        m_chunk = m_chunk / N * N;
                    }
                    let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR;

                    let n_mini_jobs = n_col_mini_chunks * n_row_mini_chunks;

                    if job_id >= job_end {
                        return;
                    }
                    if job_id + n_mini_jobs < job_start {
                        row_outer += m_chunk;
                        job_id += n_mini_jobs;
                        continue;
                    }

                    let packed_lhs_cs = MR as isize;

                    if !do_prepack_lhs {
                        pack_lhs::<N, MR, _>(
                            simd,
                            m_chunk,
                            k_chunk,
                            packed_lhs,
                            lhs.wrapping_offset(
                                row_outer as isize * lhs_rs + depth_outer as isize * lhs_cs,
                            ),
                            lhs_cs,
                            lhs_rs,
                            packed_lhs_stride,
                        );
                    }

                    let mut j = 0;
                    while j < n_col_mini_chunks {
                        let mut i = 0;
                        while i < n_row_mini_chunks {
                            let col_inner = NR * j;
                            let n_chunk_inner = NR.min(n_chunk - col_inner);

                            let row_inner = MR * i;
                            let m_chunk_inner = MR.min(m_chunk - row_inner);

                            let inner_idx = &mut i;
                            if job_id < job_start || job_id >= job_end {
                                job_id += 1;
                                *inner_idx += 1;
                                continue;
                            }
                            job_id += 1;

                            let dst = dst.wrapping_offset(
                                (row_outer + row_inner) as isize * dst_rs
                                    + (col_outer + col_inner) as isize * dst_cs,
                            );

                            let func =
                                dispatcher[(m_chunk_inner + (N - 1)) / N - 1][n_chunk_inner - 1];

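                            // The microkernel accumulates into a local f32 tile `tmp`
                            // (its alpha is 0.0, so `tmp` is overwritten); the result is
                            // then merged into the f16 destination below according to
                            // `alpha_status`.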
                            let mut tmp = [[0.0f32; MR]; NR];

                            func(
                                m_chunk_inner,
                                n_chunk_inner,
                                k_chunk,
                                tmp.as_mut_ptr() as *mut f32,
                                if do_prepack_lhs {
                                    packed_lhs
                                        .wrapping_add((i + row_outer / MR) * packed_lhs_stride)
                                        .0
                                } else {
                                    packed_lhs.wrapping_add(i * packed_lhs_stride).0
                                },
                                packed_rhs.wrapping_add(j * packed_rhs_stride).0,
                                MR as isize,
                                1,
                                packed_lhs_cs,
                                packed_rhs_rs,
                                packed_rhs_cs,
                                0.0,
                                beta.into(),
                                0,
                                false,
                                false,
                                false,
                                packed_lhs.wrapping_add((i + 1) * packed_lhs_stride).0,
                            );

                            match alpha_status {
                                0 => {
                                    for j in 0..n_chunk_inner {
                                        for i in 0..m_chunk_inner {
                                            let dst = dst
                                                .wrapping_offset(j as isize * dst_cs)
                                                .wrapping_offset(i as isize * dst_rs)
                                                .0;
                                            *dst = simd.into_dst(tmp[j][i]);
                                        }
                                    }
                                }
                                1 => {
                                    for j in 0..n_chunk_inner {
                                        for i in 0..m_chunk_inner {
                                            let dst = dst
                                                .wrapping_offset(j as isize * dst_cs)
                                                .wrapping_offset(i as isize * dst_rs)
                                                .0;
                                            *dst = simd.into_dst(simd.from_dst(*dst) + tmp[j][i]);
                                        }
                                    }
                                }
                                _ => {
                                    for j in 0..n_chunk_inner {
                                        for i in 0..m_chunk_inner {
                                            let dst = dst
                                                .wrapping_offset(j as isize * dst_cs)
                                                .wrapping_offset(i as isize * dst_rs)
                                                .0;
                                            *dst = simd
                                                .into_dst(alpha * simd.from_dst(*dst) + tmp[j][i]);
                                        }
                                    }
                                }
                            }

                            i += 1;
                        }
                        j += 1;
                    }

                    row_outer += m_chunk;
                }
            };

            if do_prepack_lhs {
                match parallelism {
                    Parallelism::None => func(0, prepacked_lhs),
                    #[cfg(feature = "rayon")]
                    Parallelism::Rayon(_) => {
                        if n_threads == 1 {
                            func(0, prepacked_lhs);
                        } else {
                            par_for_each(n_threads, |tid| func(tid, prepacked_lhs));
                        }
                    }
                }
            } else {
                #[cfg(feature = "std")]
                let func = |tid: usize| {
                    L2_SLAB.with(|mem| {
                        let mut mem = mem.borrow_mut();
                        let stack = DynStack::new(&mut mem);
                        let (packed_lhs_storage, _) = stack
                            .make_aligned_uninit::<f32>(packed_lhs_stride * (mc / MR), simd_align);
                        let packed_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut f32);
                        func(tid, packed_lhs);
                    });
                };

                #[cfg(not(feature = "std"))]
                let mut func = |tid: usize| {
                    let stack = DynStack::new(&mut l2_slab);
                    let (mut packed_lhs_storage, _) =
                        stack.make_aligned_uninit::<f32>(packed_lhs_stride * (mc / MR), simd_align);
                    let packed_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut f32);
                    func(tid, packed_lhs);
                };

                match parallelism {
                    Parallelism::None => func(0),
                    #[cfg(feature = "rayon")]
                    Parallelism::Rayon(_) => {
                        if n_threads == 1 {
                            func(0);
                        } else {
                            par_for_each(n_threads, func);
                        }
                    }
                }
            }

            alpha = 1.0;
            depth_outer += k_chunk;
        }
        col_outer += n_chunk;
    }
}

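// Public f16 entry point: the best available kernel for the current CPU is selected at
// runtime and cached as a function pointer.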
pub mod f16 {
    use super::gemm_basic_generic;
    use gemm_common::Parallelism;

    type T = half::f16;
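    // Signature: (m, n, k, dst, dst_cs, dst_rs, read_dst, lhs, lhs_cs, lhs_rs,
    // rhs, rhs_cs, rhs_rs, alpha, beta, conj_dst, conj_lhs, conj_rhs, parallelism).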
    type GemmTy = unsafe fn(
        usize,
        usize,
        usize,
        *mut T,
        isize,
        isize,
        bool,
        *const T,
        isize,
        isize,
        *const T,
        isize,
        isize,
        T,
        T,
        bool,
        bool,
        bool,
        Parallelism,
    );

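    // Runtime feature detection: prefer avx512f (nightly), then fma on x86/x86_64; on
    // aarch64 prefer the AMX kernel (when enabled), then fp16, then plain neon.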
    fn init_gemm_fn() -> GemmTy {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "nightly")]
            if gemm_common::feature_detected!("avx512f") {
                return avx512f::gemm_basic;
            }
            if gemm_common::feature_detected!("fma") {
                fma::gemm_basic
            } else {
                scalar::gemm_basic
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            if gemm_common::feature_detected!("neon") {
                #[cfg(feature = "experimental-apple-amx")]
                if gemm_common::cache::HasAmx::get() {
                    return amx::gemm_basic;
                }
                if gemm_common::feature_detected!("fp16") {
                    neonfp16::gemm_basic
                } else {
                    neon::gemm_basic
                }
            } else {
                scalar::gemm_basic
            }
        }

        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
        {
            scalar::gemm_basic
        }
    }

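    // The selected function pointer is cached in a relaxed `AtomicPtr`; a null value
    // means it has not been initialized yet.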
    static GEMM_PTR: ::core::sync::atomic::AtomicPtr<()> =
        ::core::sync::atomic::AtomicPtr::new(::core::ptr::null_mut());

    #[inline(never)]
    fn init_gemm_ptr() -> GemmTy {
        let gemm_fn = init_gemm_fn();
        GEMM_PTR.store(gemm_fn as *mut (), ::core::sync::atomic::Ordering::Relaxed);
        gemm_fn
    }

    #[inline(always)]
    pub fn get_gemm_fn() -> GemmTy {
        let mut gemm_fn = GEMM_PTR.load(::core::sync::atomic::Ordering::Relaxed);
        if gemm_fn.is_null() {
            gemm_fn = init_gemm_ptr() as *mut ();
        }
        unsafe { ::core::mem::transmute(gemm_fn) }
    }

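    // Example use of `get_gemm_fn` (a hypothetical caller; `m`, `n`, `k`, `dst_ptr`,
    // `lhs_ptr` and `rhs_ptr` are assumed to describe column-major matrices with unit
    // row stride):
    //
    //     let gemm = get_gemm_fn();
    //     unsafe {
    //         gemm(
    //             m, n, k,
    //             dst_ptr, m as isize, 1, // dst, column stride, row stride
    //             false,                  // read_dst: dst is overwritten
    //             lhs_ptr, m as isize, 1,
    //             rhs_ptr, k as isize, 1,
    //             half::f16::ZERO,        // alpha (scales the existing dst)
    //             half::f16::ONE,         // beta (scales lhs * rhs)
    //             false, false, false,    // conjugation flags (ignored by the f16 kernels)
    //             gemm_common::Parallelism::None,
    //         );
    //     }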
    mod scalar {
        use super::*;
        use gemm_common::simd::Scalar;
        use gemm_f32::microkernel::scalar::f32::*;
        const N: usize = 1;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            gemm_basic_generic::<N, { MR_DIV_N * N }, NR, MR_DIV_N, _>(
                Scalar,
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                &UKR,
                parallelism,
            );
        }
    }

    #[cfg(target_arch = "aarch64")]
    mod neon {
        use super::*;
        use gemm_common::simd::MixedSimd;
        use gemm_f32::microkernel::neon::f32::*;
        const N: usize = 4;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            gemm_basic_generic::<N, { MR_DIV_N * N }, NR, MR_DIV_N, _>(
                gemm_common::simd::Neon::try_new().unwrap(),
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                &UKR,
                parallelism,
            );
        }
    }

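    // The modules below (neonfp16 and, optionally, amx) use native f16 microkernels from
    // this crate and call gemm_common's generic driver directly, whereas scalar, neon,
    // fma and avx512f reuse the f32 microkernels from gemm_f32 through the mixed-precision
    // driver above.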
    #[cfg(target_arch = "aarch64")]
    mod neonfp16 {
        use crate::microkernel::neonfp16::f16::*;
        use gemm_common::simd::{MixedSimd, NeonFp16};
        type T = half::f16;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            let simd = <NeonFp16 as MixedSimd<T, T, T, T>>::try_new().unwrap();

            gemm_common::gemm::gemm_basic_generic::<_, _, N, { MR_DIV_N * N }, NR, MR_DIV_N, 0, 0>(
                simd,
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                false,
                false,
                false,
                move |a, b, c| <NeonFp16 as MixedSimd<T, T, T, T>>::mult_add(simd, a, b, c),
                &UKR,
                &[],
                false,
                parallelism,
            );
        }
    }

    #[cfg(target_arch = "aarch64")]
    #[cfg(feature = "experimental-apple-amx")]
    mod amx {
        use crate::microkernel::amx::f16::*;
        use gemm_common::simd::{MixedSimd, NeonFp16};
        type T = half::f16;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            let simd = <NeonFp16 as MixedSimd<T, T, T, T>>::try_new().unwrap();

            gemm_common::gemm::gemm_basic_generic::<_, _, N, { MR_DIV_N * N }, NR, MR_DIV_N, 0, 0>(
                simd,
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                false,
                false,
                false,
                move |a, b, c| <NeonFp16 as MixedSimd<T, T, T, T>>::mult_add(simd, a, b, c),
                &UKR,
                &[],
                true,
                parallelism,
            );
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    mod fma {
        use super::*;
        use gemm_common::simd::V3;
        use gemm_f32::microkernel::fma::f32::*;
        const N: usize = 8;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            gemm_basic_generic::<N, { MR_DIV_N * N }, NR, MR_DIV_N, _>(
                V3::try_new().unwrap(),
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                &UKR,
                parallelism,
            );
        }
    }

    #[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
    mod avx512f {
        use super::*;
        use gemm_common::simd::V4;
        use gemm_f32::microkernel::avx512f::f32::*;
        const N: usize = 16;

        #[inline(never)]
        pub unsafe fn gemm_basic(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            dst_cs: isize,
            dst_rs: isize,
            read_dst: bool,
            lhs: *const T,
            lhs_cs: isize,
            lhs_rs: isize,
            rhs: *const T,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: T,
            beta: T,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            parallelism: gemm_common::Parallelism,
        ) {
            gemm_basic_generic::<N, { MR_DIV_N * N }, NR, MR_DIV_N, _>(
                V4::try_new().unwrap(),
                m,
                n,
                k,
                dst,
                dst_cs,
                dst_rs,
                read_dst,
                lhs,
                lhs_cs,
                lhs_rs,
                rhs,
                rhs_cs,
                rhs_rs,
                alpha,
                beta,
                &UKR,
                parallelism,
            );
        }
    }
}