Skip to content

Commit b504008

Browse files
authored
llama : fix n_rot default (#8348)
ggml-ci
1 parent d39130a commit b504008

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

src/llama.cpp

+15-14
Original file line numberDiff line numberDiff line change
@@ -4625,16 +4625,6 @@ static void llm_load_hparams(
46254625

46264626
// non-transformer models do not have attention heads
46274627
if (hparams.n_head() > 0) {
4628-
// sanity check for n_rot (optional)
4629-
hparams.n_rot = hparams.n_embd / hparams.n_head();
4630-
4631-
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4632-
4633-
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4634-
if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
4635-
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
4636-
}
4637-
}
46384628
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
46394629
// gpt-j n_rot = rotary_dim
46404630

@@ -4643,6 +4633,17 @@ static void llm_load_hparams(
46434633

46444634
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
46454635
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
4636+
4637+
// sanity check for n_rot (optional)
4638+
hparams.n_rot = hparams.n_embd_head_k;
4639+
4640+
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4641+
4642+
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4643+
if (hparams.n_rot != hparams.n_embd_head_k) {
4644+
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
4645+
}
4646+
}
46464647
} else {
46474648
hparams.n_rot = 0;
46484649
hparams.n_embd_head_k = 0;
@@ -11490,7 +11491,7 @@ struct llm_build_context {
1149011491

1149111492
Qcur = ggml_rope_ext(
1149211493
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11493-
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
11494+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1149411495
ext_factor, attn_factor, beta_fast, beta_slow);
1149511496
cb(Qcur, "Qcur", il);
1149611497

@@ -11499,7 +11500,7 @@ struct llm_build_context {
1149911500

1150011501
Kcur = ggml_rope_ext(
1150111502
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11502-
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
11503+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1150311504
ext_factor, attn_factor, beta_fast, beta_slow);
1150411505
cb(Kcur, "Kcur", il);
1150511506

@@ -11603,7 +11604,7 @@ struct llm_build_context {
1160311604

1160411605
Qcur = ggml_rope_ext(
1160511606
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11606-
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
11607+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1160711608
ext_factor, attn_factor, beta_fast, beta_slow);
1160811609
cb(Qcur, "Qcur", il);
1160911610

@@ -11612,7 +11613,7 @@ struct llm_build_context {
1161211613

1161311614
Kcur = ggml_rope_ext(
1161411615
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11615-
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
11616+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1161611617
ext_factor, attn_factor, beta_fast, beta_slow);
1161711618
cb(Kcur, "Kcur", il);
1161811619

0 commit comments

Comments
 (0)