@@ -4625,16 +4625,6 @@ static void llm_load_hparams(
4625
4625
4626
4626
// non-transformer models do not have attention heads
4627
4627
if (hparams.n_head() > 0) {
4628
- // sanity check for n_rot (optional)
4629
- hparams.n_rot = hparams.n_embd / hparams.n_head();
4630
-
4631
- ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4632
-
4633
- if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4634
- if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
4635
- throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
4636
- }
4637
- }
4638
4628
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
4639
4629
// gpt-j n_rot = rotary_dim
4640
4630
@@ -4643,6 +4633,17 @@ static void llm_load_hparams(
4643
4633
4644
4634
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
4645
4635
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
4636
+
4637
+ // sanity check for n_rot (optional)
4638
+ hparams.n_rot = hparams.n_embd_head_k;
4639
+
4640
+ ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4641
+
4642
+ if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4643
+ if (hparams.n_rot != hparams.n_embd_head_k) {
4644
+ throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
4645
+ }
4646
+ }
4646
4647
} else {
4647
4648
hparams.n_rot = 0;
4648
4649
hparams.n_embd_head_k = 0;
@@ -11490,7 +11491,7 @@ struct llm_build_context {
11490
11491
11491
11492
Qcur = ggml_rope_ext(
11492
11493
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11493
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11494
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11494
11495
ext_factor, attn_factor, beta_fast, beta_slow);
11495
11496
cb(Qcur, "Qcur", il);
11496
11497
@@ -11499,7 +11500,7 @@ struct llm_build_context {
11499
11500
11500
11501
Kcur = ggml_rope_ext(
11501
11502
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11502
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11503
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11503
11504
ext_factor, attn_factor, beta_fast, beta_slow);
11504
11505
cb(Kcur, "Kcur", il);
11505
11506
@@ -11603,7 +11604,7 @@ struct llm_build_context {
11603
11604
11604
11605
Qcur = ggml_rope_ext(
11605
11606
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11606
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11607
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11607
11608
ext_factor, attn_factor, beta_fast, beta_slow);
11608
11609
cb(Qcur, "Qcur", il);
11609
11610
@@ -11612,7 +11613,7 @@ struct llm_build_context {
11612
11613
11613
11614
Kcur = ggml_rope_ext(
11614
11615
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11615
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11616
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11616
11617
ext_factor, attn_factor, beta_fast, beta_slow);
11617
11618
cb(Kcur, "Kcur", il);
11618
11619
0 commit comments