
Commit c43a3e7

ngxson and nisparks authored

llama : add Phi-4-mini support (supersede #12099) (#12108)

* Added Phi-4-mini-instruct support
* Update regex per ngxson
* Change the vocab base to Xenova/gpt-4o
* fix conversion update script
* no need to check longrope
* minor style fix
* fix python style

---------

Co-authored-by: Nicholas Sparks <[email protected]>
1 parent 84d5f4b commit c43a3e7

7 files changed: +191, -8 lines

convert_hf_to_gguf.py (+8, -3)
@@ -699,6 +699,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5":
             # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
             res = "deepseek-r1-qwen"
+        if chkhsh == "ccc2ef013c104be7bae2965776d611e1d7a8a2a9c547dd93a682c9a9fc80352e":
+            # ref: https://huggingface.co/Xenova/gpt-4o
+            res = "gpt-4o"
 
         if res is None:
             logger.warning("\n")
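For context: `chkhsh` is the hash this same function computes by encoding a fixed probe string with the HF tokenizer and hashing the resulting token ids, so two tokenizers map to the same `res` only if they pre-tokenize identically. A minimal sketch of the idea (the function name is illustrative, not the script's exact code):

from hashlib import sha256

def vocab_pre_hash(tokenizer, chktxt: str) -> str:
    # hash the token ids produced for a fixed probe string;
    # any change in pre-tokenization changes the ids, hence the hash
    chktok = tokenizer.encode(chktxt)
    return sha256(str(chktok).encode()).hexdigest()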
@@ -2512,7 +2515,8 @@ def set_gguf_parameters(self):
25122515
rms_eps = self.find_hparam(["rms_norm_eps"])
25132516
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
25142517
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
2515-
rope_dims = n_embd // n_head
2518+
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
2519+
rope_dims = int(rot_pct * n_embd) // n_head
25162520

25172521
self.gguf_writer.add_context_length(max_pos_embds)
25182522
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
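A quick worked example of what `partial_rotary_factor` changes, assuming Phi-4-mini's reported hyperparameters (hidden_size 3072, 24 attention heads, factor 0.75; these values come from the model's config.json, not from this diff):

n_embd, n_head = 3072, 24
rot_pct = 0.75                                   # partial_rotary_factor

old_rope_dims = n_embd // n_head                 # 128: rotate the full head dim
new_rope_dims = int(rot_pct * n_embd) // n_head  # 96: rotate only 3/4 of it

# the long/short rope factor arrays must then hold new_rope_dims / 2 = 48
# entries, which is what the length check further down validates
print(old_rope_dims, new_rope_dims, new_rope_dims // 2)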
@@ -2536,7 +2540,8 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         n_head = self.find_hparam(["num_attention_heads", "n_head"])
         max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
         orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
-        rope_dims = n_embd // n_head
+        rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
+        rope_dims = int(rot_pct * n_embd) // n_head
 
         # write rope scaling for long context (128k) model
         rope_scaling = self.find_hparam(['rope_scaling'], True)
@@ -2565,7 +2570,7 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
             raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
 
         if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
-            raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
+            raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}. long_factors = {len(long_factors)}, short_factors = {len(short_factors)}.')
 
         yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
         yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))

convert_hf_to_gguf_update.py (+5)
@@ -109,6 +109,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
     {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
     {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
+    {"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", },
 ]

@@ -131,6 +132,10 @@ def download_model(model):
 
     files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
 
+    if name == "gpt-4o":
+        # Xenova/gpt-4o is tokenizer-only, it does not contain config.json
+        files = ["tokenizer.json", "tokenizer_config.json"]
+
     if tokt == TOKENIZER_TYPE.SPM:
         files.append("tokenizer.model")
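For reference, after adding an entry to this list the hash table is regenerated by running the script with a Hugging Face read token, per its usage string (the token value is a placeholder):

python3 convert_hf_to_gguf_update.py <huggingface_token>

This downloads the listed tokenizer files and rewrites the chkhsh checks in convert_hf_to_gguf.py, which is where the "ccc2ef01..." hash above came from.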

include/llama.h (+1)
@@ -105,6 +105,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
         LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
         LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
+        LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
     };
 
     enum llama_rope_type {

models/ggml-vocab-gpt-4o.gguf.inp (+112, new file)
@@ -0,0 +1,112 @@
+ied 4 ½ months
+__ggml_vocab_test__
+Führer
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+ 
+__ggml_vocab_test__
+  
+__ggml_vocab_test__
+   
+__ggml_vocab_test__
+	
+__ggml_vocab_test__
+
+
+__ggml_vocab_test__
+
+
+
+__ggml_vocab_test__
+
+
+
+
+__ggml_vocab_test__
+	
+
+__ggml_vocab_test__
+Hello world
+__ggml_vocab_test__
+ Hello world
+__ggml_vocab_test__
+Hello World
+__ggml_vocab_test__
+ Hello World
+__ggml_vocab_test__
+ Hello World!
+__ggml_vocab_test__
+Hello, world!
+__ggml_vocab_test__
+ Hello, world!
+__ggml_vocab_test__
+ this is 🦙.cpp
+__ggml_vocab_test__
+w048 7tuijk dsdfhu
+__ggml_vocab_test__
+нещо на Български
+__ggml_vocab_test__
+កាន់តែពិសេសអាចខលចេញ
+__ggml_vocab_test__
+🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
+__ggml_vocab_test__
+Hello
+__ggml_vocab_test__
+ Hello
+__ggml_vocab_test__
+  Hello
+__ggml_vocab_test__
+   Hello
+__ggml_vocab_test__
+    Hello
+__ggml_vocab_test__
+    Hello
+    Hello
+__ggml_vocab_test__
+ (
+__ggml_vocab_test__
+
+ =
+__ggml_vocab_test__
+' era
+__ggml_vocab_test__
+Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
+__ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
+3
+__ggml_vocab_test__
+33
+__ggml_vocab_test__
+333
+__ggml_vocab_test__
+3333
+__ggml_vocab_test__
+33333
+__ggml_vocab_test__
+333333
+__ggml_vocab_test__
+3333333
+__ggml_vocab_test__
+33333333
+__ggml_vocab_test__
+333333333
+__ggml_vocab_test__
+Cửa Việt
+__ggml_vocab_test__
+ discards
+__ggml_vocab_test__
+
+
+
+
+
+
+
+
+
+
+
+🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
+__ggml_vocab_test__

models/ggml-vocab-gpt-4o.gguf.out (+46, new file)
@@ -0,0 +1,46 @@
+ 1165 220 19 220 27124 5503
+ 37 19194 259
+
+ 220
+ 256
+ 271
+ 197
+ 198
+ 279
+ 2499
+ 2775
+ 13225 2375
+ 32949 2375
+ 13225 5922
+ 32949 5922
+ 32949 5922 0
+ 13225 11 2375 0
+ 32949 11 2375 0
+ 495 382 9552 99 247 13 17159
+ 86 45404 220 22 10191 2852 22924 4750 6916
+ 3907 53641 1235 185386 8118
+ 11400 107516 15867 20804 22851 134178 77431 32010 104312 37984 16329 27751 89335
+ 112927 222 350 14559 8 22861 114 2524 64364 104 15148 350 76466 166700 121942 780 8 91349 350 7393 74471 484 853 1617 2316 6602 8
+ 13225
+ 32949
+ 220 32949
+ 256 32949
+ 271 32949
+ 271 32949 198 271 32949
+ 350
+ 198 314
+ 6 6837
+ 13225 11 342 70653 0 3253 553 481 22861 223 1423 7522 18165 2178 34058 22369 16412 32999 16 867 8208
+ 147475
+ 18
+ 2546
+ 15517
+ 15517 18
+ 15517 2546
+ 15517 15517
+ 15517 15517 18
+ 15517 15517 2546
+ 15517 15517 15517
+ 34 60213 53904
+ 2960 3098
+ 126470 25980 160432 16609 2775 4066 172261 19432 112927 222 350 14559 8 22861 114 2524 64364 104 15148 350 76466 166700 121942 780 8 91349 9552 99 247 4103 99 247 220 18 220 2546 220 15517 220 15517 18 220 15517 2546 220 15517 15517 220 15517 15517 18 220 15517 15517 2546 220 18 13 18 220 18 485 18 220 18 1008 18 44735 107516 15867 20804 22851 134178 77431 32010 104312 156437 1423 7522 18165 2178 34058 22369 16412 32999 16 867 8208 105024 106657 1967 53641 1235 185386 8118 22434 39336 26178 26178 168394 194663 27271 147475 25883 6961 9790 1339 461 83 1280 19016 1354 11 461 1099 481 3239 30 461 44 625 3239 17291 1520 480 11 461 35 481 1299 1236 17966 30 1416 6 27493 261 54602 43
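The two fixtures pair line-for-line: each `__ggml_vocab_test__`-delimited input in the .inp file corresponds to one line of expected token ids in the .out file. A minimal sketch of reading them side by side (the pairing is inferred from the files themselves; how the tokenizer tests consume them is not shown in this diff):

# pair each __ggml_vocab_test__-delimited input with its expected token ids
with open("models/ggml-vocab-gpt-4o.gguf.inp", encoding="utf-8") as f:
    cases = f.read().split("\n__ggml_vocab_test__\n")
if cases and cases[-1] == "":
    cases.pop()  # drop the empty tail after the final separator
with open("models/ggml-vocab-gpt-4o.gguf.out", encoding="utf-8") as f:
    expected = [[int(t) for t in line.split()] for line in f]

for text, ids in zip(cases, expected):
    print(repr(text), "->", ids)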

src/llama-model.cpp (+8, -5)
@@ -2202,13 +2202,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_PHI3:
                 {
-                    const int64_t n_embd_head = n_embd / n_head;
-
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
-                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -2223,8 +2226,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
 
-                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
-                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
                     }
                 } break;
             case LLM_ARCH_PHIMOE:
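The switch from `n_embd_head` to `n_rot` matters precisely for partial-rotary models; with full rotary the two are equal. A quick sanity check, again using the assumed Phi-4-mini values (3072 embed dim, 24 heads, 0.75 rotary factor, not stated in this diff):

n_embd, n_head, rot_pct = 3072, 24, 0.75

n_embd_head = n_embd // n_head            # 128: full per-head dimension
n_rot = int(rot_pct * n_embd) // n_head   # 96: rotated portion only

# the rope factor tensors written by the converter have n_rot/2 = 48 entries,
# so loading them with shape n_embd_head/2 = 64 would fail for Phi-4-mini
print(n_embd_head // 2, n_rot // 2)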

src/llama-vocab.cpp (+11)
@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_GPT4O:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                    "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -1592,6 +1599,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             } else if (
                 tokenizer_pre == "megrez") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+            } else if (
+                tokenizer_pre == "gpt-4o") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
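As a sanity check on the pre-tokenizer behavior, here is a minimal sketch (not part of the commit) that splits text with the original GPT-4o pattern quoted in the comment above, using Python's third-party `regex` module, which, unlike `re`, supports `\p{...}` Unicode properties and scoped `(?i:...)` flags:

import regex  # pip install regex

# the "original regex from tokenizer.json" quoted in the diff above
PAT = (
    r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?"
    r"|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?"
    r"|\p{N}{1,3}"
    r"| ?[^\s\p{L}\p{N}]+[\r\n/]*"
    r"|\s*[\r\n]+"
    r"|\s+(?!\S)"
    r"|\s+"
)

# pre-tokenization splits the text into chunks; BPE then runs inside each chunk
print(regex.findall(PAT, "Hello, world!"))  # ['Hello', ',', ' world', '!']
print(regex.findall(PAT, "1314151"))        # ['131', '415', '1']: digits in runs of up to 3

The second expression in the diff appears to be a functional translation of this pattern for llama.cpp's own regex support, which lacks scoped case-insensitive groups and some property classes; the `clean_spaces = false` in the mapping above likewise preserves the byte-level BPE spacing this tokenizer expects.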
