Commit af0a5b6

server: fix incorrectly reported token probabilities (#7125)
* server: normalize token probabilities
* fix temperature == 0.0f
1 parent b6aa670 · commit af0a5b6
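The underlying problem: the server reported per-token probabilities from a plain softmax over the logits, even when the sampler chain had already filtered and renormalized the candidate list. Below is a minimal sketch of the reporting rule the patch introduces; `candidate`, `token_prob`, and `report_probs` are hypothetical names used only for illustration, while the real code operates on `llama_token_data_array` and the new `slot.ctx_sampling->n_considered` field.

```cpp
// Sketch only, not the patched server code. Assumes `cur` is sorted
// best-first, as it is after the llama.cpp sampler chain has run.
#include <algorithm>
#include <cstddef>
#include <vector>

struct candidate  { int id; float p; };  // token id and its post-sampling probability
struct token_prob { int id; float p; };  // what the server reports to the client

std::vector<token_prob> report_probs(const std::vector<candidate> & cur,
                                     size_t n_considered, // candidates that survived the samplers
                                     size_t n_probs,      // how many the client asked for
                                     bool   greedy) {     // temperature == 0.0f
    std::vector<token_prob> out;
    const size_t n = std::min(cur.size(), n_probs);
    for (size_t i = 0; i < n; ++i) {
        float p;
        if (greedy) {
            p = i == 0 ? 1.0f : 0.0f; // greedy sampling picks the top token with certainty
        } else {
            p = i >= n_considered ? 0.0f : cur[i].p; // filtered-out tokens get probability 0
        }
        out.push_back({ cur[i].id, p });
    }
    return out;
}
```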

File tree

4 files changed: +31 -11 lines changed


common/sampling.cpp

+5

```diff
@@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
+    result->n_considered = 0;
+
     llama_sampling_set_rng_seed(result, params.seed);
 
     return result;
@@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
+    ctx->n_considered = 0;
 }
 
 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
         }
     }
 
+    ctx_sampling->n_considered = cur_p.size;
+
     return id;
 }
 
```

common/sampling.h

+1

```diff
@@ -81,6 +81,7 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
+    size_t n_considered;
 
     std::mt19937 rng;
 };
```
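The new field's lifecycle is simple: it is zeroed in `llama_sampling_init` and `llama_sampling_reset`, and set to `cur_p.size` at the end of `llama_sampling_sample_impl`. A toy, self-contained illustration of why that final size is the right value follows; the sampler here is a stand-in, not llama.cpp API:

```cpp
#include <cstddef>
#include <vector>

struct candidate { int id; float p; };

// Stand-in for one sampler in the chain; real samplers such as top_k or
// min_p shrink the candidate array in place the same way.
static void apply_top_k(std::vector<candidate> & cands, size_t k) {
    if (cands.size() > k) {
        cands.resize(k); // assumes cands is already sorted best-first
    }
}

int main() {
    std::vector<candidate> cur = { { 7, 0.5f }, { 3, 0.3f }, { 9, 0.2f } };
    apply_top_k(cur, 2);                    // the chain filters candidates ...
    const size_t n_considered = cur.size(); // ... and what is left (here 2) is recorded
    (void) n_considered;
    return 0;
}
```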

examples/server/README.md

+1 -1

```diff
@@ -272,7 +272,7 @@ node index.js
 
     `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`
 
-    `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token. Default: `0`
+    `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`
 
     `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0`
```
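For reference, the "simple softmax of the logits" that the new README note refers to can be sketched as below. This mirrors what `llama_sample_softmax` computes but is written against a stand-in `token_data` struct, so it is an illustration rather than the library code:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

struct token_data { int id; float logit; float p; }; // same fields as llama_token_data

void softmax(std::vector<token_data> & cands) {
    if (cands.empty()) {
        return;
    }
    // Sort best-first so cands[0] is the token greedy sampling would pick.
    std::sort(cands.begin(), cands.end(),
              [](const token_data & a, const token_data & b) { return a.logit > b.logit; });
    const float max_logit = cands[0].logit;
    float sum = 0.0f;
    for (auto & c : cands) {
        c.p = std::exp(c.logit - max_logit); // subtract the max for numerical stability
        sum += c.p;
    }
    for (auto & c : cands) {
        c.p /= sum;
    }
}
```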

examples/server/server.cpp

+24 -10

```diff
@@ -2266,17 +2266,31 @@ struct server_context {
             llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
             result.tok = id;
 
-            const int32_t n_probs = slot.sparams.n_probs;
-            if (slot.sparams.temp <= 0 && n_probs > 0) {
-                // for llama_sample_token_greedy we need to sort candidates
-                llama_sample_softmax(ctx, &cur_p);
-            }
+            const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
+            if (n_probs > 0) {
+                const size_t n_considered = slot.ctx_sampling->n_considered;
 
-            for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
-                result.probs.push_back({
-                    cur_p.data[i].id,
-                    cur_p.data[i].p
-                });
+                // Make sure at least n_probs top tokens are at the front of the vector:
+                if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
+                    llama_sample_top_k(ctx, &cur_p, n_probs, 0);
+                }
+
+                if (slot.sparams.temp == 0.0f) {
+                    // With greedy sampling the probabilities have possibly not been calculated.
+                    for (size_t i = 0; i < n_probs; ++i) {
+                        result.probs.push_back({
+                            cur_p.data[i].id,
+                            i == 0 ? 1.0f : 0.0f
+                        });
+                    }
+                } else {
+                    for (size_t i = 0; i < n_probs; ++i) {
+                        result.probs.push_back({
+                            cur_p.data[i].id,
+                            i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
+                        });
+                    }
+                }
             }
 
             if (!process_token(result, slot)) {
```
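A worked toy example of why the old path was wrong: a plain softmax over all logits reports probabilities that no longer match the distribution the sampler chain actually sampled from once a filter such as `top_k` has removed candidates and the survivors have been renormalized. The numbers below are made up purely for illustration:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> logits = { 2.0f, 1.0f, 0.5f };

    // Full softmax over every logit: roughly 0.63, 0.23, 0.14.
    float sum_full = 0.0f;
    std::vector<float> p_full(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) { p_full[i] = std::exp(logits[i]); sum_full += p_full[i]; }
    for (float & p : p_full) { p /= sum_full; }

    // With top_k = 2 the chain keeps only the two best candidates and
    // renormalizes them: roughly 0.73 and 0.27. The third token can never
    // be sampled, so it must be reported with probability 0.
    const size_t n_considered = 2;
    float sum_k = 0.0f;
    for (size_t i = 0; i < n_considered; ++i) { sum_k += std::exp(logits[i]); }

    for (size_t i = 0; i < logits.size(); ++i) {
        const float p = i < n_considered ? std::exp(logits[i]) / sum_k : 0.0f;
        std::printf("token %zu: naive softmax %.3f, correctly reported %.3f\n", i, p_full[i], p);
    }
    return 0;
}
```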
