
Commit 05e6f5a

Vithulep and pvname authored
ggml: aarch64: implement SVE kernels for q2_k_q8_k vector dot (#12064)
* Added SVE support for Q2_K quantized models
* Use 4-space indentation in the switch cases
* Removed comment lines
* Remove the loop; retain the curly braces for better understanding of the code
* Remove the comment line added for the q3_k_q8_k kernel

---------

Co-authored-by: vithulep <[email protected]>
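For readers skimming the diff below: q2_K packs four 2-bit quants into each byte of `qs` (at bit positions 0, 2, 4, 6), and the kernel unpacks them with a `0x3` mask plus logical shifts before `svdot_s32` sums four int8 products into every int32 lane. The following is a minimal scalar sketch of the arithmetic the 128-bit path accumulates for one 32-byte chunk of `q2` against 128 `q8` values; `dot_q2_chunk_ref` is a hypothetical reference helper for illustration, not part of this patch:

#include <stdint.h>

// Hypothetical scalar reference (not in the patch): each 2-bit plane
// (shift 0/2/4/6) yields two 16-byte groups, consumed against successive
// 16-byte groups of q8, and each group is weighted by its own 4-bit scale
// from the low nibble of sc. This mirrors the group order of the SVE
// 128-bit path: q2[0..15] then q2[16..31] per plane.
static int dot_q2_chunk_ref(const uint8_t * q2, const int8_t * q8, const uint8_t * sc) {
    int sum = 0;
    int group = 0;
    for (int shift = 0; shift < 8; shift += 2) {        // four 2-bit planes per byte
        for (int half = 0; half < 2; ++half, ++group) { // q2[0..15], then q2[16..31]
            int partial = 0;
            for (int j = 0; j < 16; ++j) {
                partial += ((q2[16*half + j] >> shift) & 0x3) * q8[16*group + j];
            }
            sum += (sc[group] & 0xF) * partial;         // per-group scale, low nibble
        }
    }
    return sum;                                         // feeds sumi1 before scaling by d
}

The min-side correction (the high nibbles of `scales` multiplied by the per-group `bsums`, scaled by `dmin`) is accumulated separately in the kernel via `svmla_f32_m`.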
1 parent 673cfef commit 05e6f5a

File tree

1 file changed: +246 -1 lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

+246 -1
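Before reading the diff, it helps to know the two block types the kernel indexes through `x[i]` and `y[i]`. The sketch below is a paraphrase of the layouts from ggml's common headers (simplified: the real definitions wrap `d`/`dmin` in a union and use `ggml_half` for fp16), shown here only so the field accesses in the diff are easy to follow:

#include <stdint.h>

#define QK_K 256

// Paraphrased layout of a q2_K super-block (simplified from ggml-common.h).
typedef struct {
    uint8_t  scales[QK_K/16]; // per-16-quant 4-bit scale (low nibble) and min (high nibble)
    uint8_t  qs[QK_K/4];      // 2-bit quants, four per byte
    uint16_t d;               // fp16 super-block scale for the scales
    uint16_t dmin;            // fp16 super-block scale for the mins
} block_q2_K;

// Paraphrased layout of a q8_K super-block (simplified from ggml-common.h).
typedef struct {
    float    d;               // delta
    int8_t   qs[QK_K];        // 8-bit quants
    int16_t  bsums[QK_K/16];  // sums of groups of 16 quants
} block_q8_K;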
@@ -4587,7 +4587,252 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = svcntb()*8;
+    const svuint8_t m3s = svdup_n_u8(0x3);
+    const svuint32_t m4s = svdup_n_u32(0xF);
+    const svint32_t vzero_sv = svdup_n_s32(0);
+    svfloat32_t acc_sum = svdup_n_f32(0);
+    svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);
+
+    switch (vector_length) {
+        case 128:
+            for (int i = 0; i < nb; ++i) {
+                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
+                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
+
+                const uint8_t * restrict q2 = x[i].qs;
+                const int8_t  * restrict q8_sv = y[i].qs;
+                const uint8_t * restrict sc = x[i].scales;
+
+                svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
+                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
+                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
+                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);
+
+                const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
+                const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
+                const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
+                q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);
+
+                svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));
+
+                svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));
+
+                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);
+
+                svint32_t sumi1 = svdup_n_s32(0);
+
+                {
+                    const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
+                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
+                    svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                    const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));
+
+                    const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));
+
+                    const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));
+
+                    //-------------------------------
+
+                    q2 += 32;
+                    const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
+                    const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
+
+                    const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
+
+                    const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
+                }
+                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
+            }
+            *s = svaddv_f32(svptrue_b32(), acc_sum);
+            break;
+
+        case 256:
+        case 512:
+            for (int i = 0; i < nb; ++i) {
+                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
+                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
+
+                const uint8_t * restrict q2 = x[i].qs;
+                const int8_t  * restrict q8_sv = y[i].qs;
+                const uint8_t * restrict sc = x[i].scales;
+
+                const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
+                const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
+                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
+                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
+
+                const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
+                const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
+                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
+
+                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
+
+                svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
+
+                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
+
+                svint32_t sumi1 = svdup_n_s32(0);
+
+                {
+                    const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
+                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
+                    svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2 += 32;
+
+                    const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+                }
+                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
+            }
+            *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);
+            break;
+
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+    }
+
+#elif __ARM_NEON
     const uint8x16_t m3 = vdupq_n_u8(0x3);
     const uint8x16_t m4 = vdupq_n_u8(0xF);

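A note on the dispatch: `svcntb()` returns the SVE register width in bytes at run time, so `vector_length = svcntb()*8` is the width in bits. The 128-bit case works in 16-byte groups with one scale per `svdot_s32` result, while the 256/512-bit case loads 32 bytes at a time under a fixed `SV_VL32` predicate and uses `svsel` with `pred_s32` (`SV_VL4`) to pair two 4-bit scales across the low and high four int32 lanes of each dot product. A minimal standalone probe, assuming an SVE-capable compiler and CPU, shows which branch a given machine would take:

// Hypothetical standalone probe (not part of the patch): prints the value
// the kernel's switch dispatches on. Build with an SVE-enabled toolchain,
// e.g. -march=armv8-a+sve.
#include <arm_sve.h>
#include <stdio.h>

int main(void) {
    const int vector_length = (int)svcntb() * 8; // svcntb(): vector length in bytes
    printf("SVE vector length: %d bits\n", vector_length);
    return 0;
}

Vector lengths other than 128, 256, or 512 bits (SVE permits any multiple of 128 up to 2048) fall through to the `default:` branch and trip the assert.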