@@ -4587,7 +4587,252 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
4587
4587
4588
4588
const int nb = n / QK_K;
4589
4589
4590
- #ifdef __ARM_NEON
4590
+ #ifdef __ARM_FEATURE_SVE
4591
+ const int vector_length = svcntb()*8;
4592
+ const svuint8_t m3s = svdup_n_u8(0x3);
4593
+ const svuint32_t m4s = svdup_n_u32(0xF);
4594
+ const svint32_t vzero_sv = svdup_n_s32(0);
4595
+ svfloat32_t acc_sum = svdup_n_f32(0);
4596
+ svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);
4597
+
4598
+ switch (vector_length) {
4599
+ case 128:
4600
+ for (int i = 0; i < nb; ++i) {
4601
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4602
+ svfloat32_t d_broad = svdup_n_f32((float32_t)d);
4603
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
4604
+ svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
4605
+
4606
+ const uint8_t * restrict q2 = x[i].qs;
4607
+ const int8_t * restrict q8_sv = y[i].qs;
4608
+ const uint8_t * restrict sc = x[i].scales;
4609
+
4610
+ svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
4611
+ const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
4612
+
4613
+ mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
4614
+ const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
4615
+
4616
+ svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
4617
+ svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);
4618
+
4619
+ const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));
4620
+
4621
+ mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
4622
+ const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
4623
+
4624
+ mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
4625
+ const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
4626
+
4627
+ q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
4628
+ q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);
4629
+
4630
+ svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));
4631
+
4632
+ svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));
4633
+
4634
+ acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);
4635
+
4636
+ svint32_t sumi1 = svdup_n_s32(0);
4637
+
4638
+ {
4639
+ const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
4640
+ svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
4641
+ svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4642
+ const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));
4643
+
4644
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));
4645
+
4646
+ const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
4647
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
4648
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4649
+
4650
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));
4651
+
4652
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
4653
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4654
+
4655
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));
4656
+
4657
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
4658
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4659
+
4660
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));
4661
+
4662
+
4663
+ const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));
4664
+
4665
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
4666
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4667
+
4668
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));
4669
+
4670
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
4671
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4672
+
4673
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));
4674
+
4675
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
4676
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4677
+
4678
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));
4679
+
4680
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
4681
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4682
+
4683
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));
4684
+
4685
+ //-------------------------------
4686
+
4687
+ q2 += 32;
4688
+ const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
4689
+ const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
4690
+
4691
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
4692
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4693
+
4694
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
4695
+
4696
+ const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
4697
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
4698
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4699
+
4700
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
4701
+
4702
+
4703
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
4704
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4705
+
4706
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
4707
+
4708
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
4709
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4710
+
4711
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
4712
+
4713
+
4714
+ const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
4715
+
4716
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
4717
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4718
+
4719
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
4720
+
4721
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
4722
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4723
+
4724
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
4725
+
4726
+
4727
+
4728
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
4729
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4730
+
4731
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
4732
+
4733
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
4734
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
4735
+
4736
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
4737
+ }
4738
+ acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
4739
+ }
4740
+ *s = svaddv_f32(svptrue_b32(), acc_sum);
4741
+ break;
4742
+
4743
+ case 256:
4744
+ case 512:
4745
+ for (int i = 0; i < nb; ++i) {
4746
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4747
+ svfloat32_t d_broad = svdup_n_f32((float32_t)d);
4748
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
4749
+ svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
4750
+
4751
+ const uint8_t * restrict q2 = x[i].qs;
4752
+ const int8_t * restrict q8_sv = y[i].qs;
4753
+ const uint8_t * restrict sc = x[i].scales;
4754
+
4755
+ const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
4756
+ const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
4757
+ const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
4758
+ svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
4759
+
4760
+ const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
4761
+ const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
4762
+ const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
4763
+
4764
+ svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
4765
+
4766
+ svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
4767
+
4768
+ acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
4769
+
4770
+ svint32_t sumi1 = svdup_n_s32(0);
4771
+
4772
+ {
4773
+ const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
4774
+ svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
4775
+ svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
4776
+
4777
+ svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
4778
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
4779
+
4780
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
4781
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
4782
+
4783
+ svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
4784
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);
4785
+
4786
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
4787
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
4788
+
4789
+ scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
4790
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
4791
+
4792
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
4793
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
4794
+
4795
+ scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
4796
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
4797
+
4798
+ q2 += 32;
4799
+
4800
+ const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
4801
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
4802
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
4803
+
4804
+ scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
4805
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
4806
+
4807
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
4808
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
4809
+
4810
+ scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
4811
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
4812
+
4813
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
4814
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
4815
+
4816
+ scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
4817
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
4818
+
4819
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
4820
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
4821
+
4822
+ scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
4823
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
4824
+ }
4825
+ acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
4826
+ }
4827
+ *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);
4828
+ break;
4829
+
4830
+ default:
4831
+ assert(false && "Unsupported vector length");
4832
+ break;
4833
+ }
4834
+
4835
+ #elif __ARM_NEON
4591
4836
const uint8x16_t m3 = vdupq_n_u8(0x3);
4592
4837
const uint8x16_t m4 = vdupq_n_u8(0xF);
4593
4838
0 commit comments