Add beam search #631

Open · wants to merge 2 commits into main
72 changes: 71 additions & 1 deletion llama_cpp/llama.py
@@ -17,6 +17,7 @@
Callable,
)
from collections import deque, OrderedDict
from dataclasses import dataclass

import diskcache
import ctypes
@@ -205,6 +206,44 @@ def __call__(
) -> bool:
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])

# Custom data that is accessible to the beam_search_callback() function.
@dataclass
class beam_search_callback_data:
ctx: llama_cpp.llama_context_p
response_tokens: List[int]

# Used for debugging to view beam states
def beam_view_to_string(ctx, beam_view):
string = f"p({beam_view.p}): "
for i in range(beam_view.n_tokens):
string += llama_cpp.llama_token_get_text(ctx, beam_view.tokens[i]).decode("utf-8")
return string

# One requirement of the callback is that it MUST determine end-of-beam.
def is_at_eob(ctx, tokens, n_tokens) :
return 0 < n_tokens and tokens[n_tokens-1] == llama_cpp.llama_token_eos(ctx);

# beam_search_callback requires a global dictionary to pass data via their object id.
beam_search_dictionary = {}

# beam_search_callback() must flag beams when they reach end-of-sentence.
# TODO: Use stop_sequences.
cebtenzzre (Contributor) commented on Oct 25, 2023:

Is this TODO the reason I'm seeing this in the debug output? Note the EOS and BOS. The prompt is not code-related, FWIW.

beams[0] p(0.493342787027359): <0x0A></s><s><0x0A>#include▁<iostream><0x0A>#include▁<cmath.h><0x0A>#include▁<c
beams[1] p(0.5066572427749634): <0x0A></s><s><0x0A>#include▁<iostream><0x0A>#include▁<cmath.h><0x0A>#include▁<vector

mattpulver (Author) replied:

I'm not sure what you mean exactly by "Note the EOS and BOS."

The TODO note relates to the is_at_eob() function above. Currently, EOB (end-of-beam) is determined by the token returned by llama_cpp.llama_token_eos(ctx). If EOB is to be generalized to user-defined EOB sequences, then this would be the function to add that logic to.

cebtenzzre (Contributor) replied:

What I mean is that </s> (EOS) is generated by the model, but the beam search keeps going (onto BOS, and then it starts making up something unrelated). I think this shouldn't happen.

mattpulver (Author) replied on Oct 25, 2023:

Thanks. To answer your original question, yes, that is exactly what the TODO is talking about.

A good follow-up item would be to add stop_sequences to the class beam_search_callback_data and set them to custom stop sequences (e.g. </s>) when the class is instantiated below. Then pass it to is_at_eob() when called from beam_search_callback().

It may require a bit more logic to accommodate the possibility of stop sequences being split across separate tokens.
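
A minimal sketch of that follow-up, assuming the callback data carries the user's stop sequences as raw bytes; the stop_sequences field, the changed is_at_eob() signature, and the 8-token look-back window are illustrative choices, not part of this PR:

```python
from dataclasses import dataclass, field
from typing import List

import llama_cpp


@dataclass
class beam_search_callback_data:
    ctx: llama_cpp.llama_context_p
    response_tokens: List[int]
    # Hypothetical field: user-defined stop sequences, e.g. [b"</s>"].
    stop_sequences: List[bytes] = field(default_factory=list)


def is_at_eob(callback_data: beam_search_callback_data, tokens, n_tokens) -> bool:
    if n_tokens == 0:
        return False
    # Existing EOS check.
    if tokens[n_tokens - 1] == llama_cpp.llama_token_eos(callback_data.ctx):
        return True
    # Decode a small window of trailing tokens so a stop sequence that is
    # split across token boundaries can still be matched as a suffix.
    tail = b"".join(
        llama_cpp.llama_token_get_text(callback_data.ctx, tokens[i])
        for i in range(max(0, n_tokens - 8), n_tokens)
    )
    return any(tail.endswith(stop) for stop in callback_data.stop_sequences)
```

beam_search_callback() would then pass callback_data rather than callback_data.ctx into is_at_eob(). Note that SentencePiece token text renders spaces as ▁ (as in the debug output above), so stop sequences may need the same normalization before comparison.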

cebtenzzre (Contributor) replied:

To be clear, this is not a custom stop sequence, this is just the regular EOS token (AFAIK), which is rendered this way in the output. You say EOB is determined by llama_token_eos, but that doesn't seem to work for me.

mattpulver (Author) replied:

One way to debug this is to modify the line above

 string += llama_cpp.llama_token_get_text(ctx, beam_view.tokens[i]).decode("utf-8")

to something that appends the numeric token id beam_view.tokens[i] along with the decoded substring. If you're really encountering the llama_token_eos() token, last I checked, it should have token id 2.
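
For instance, a debugging-only variant of beam_view_to_string() that prints each numeric token id next to its decoded text (the [id] formatting is just an illustration):

```python
def beam_view_to_string(ctx, beam_view):
    string = f"p({beam_view.p}): "
    for i in range(beam_view.n_tokens):
        token_id = beam_view.tokens[i]
        text = llama_cpp.llama_token_get_text(ctx, token_id).decode("utf-8")
        # An EOS token (id 2 for LLaMA models) would show up here as "[2]</s>".
        string += f"[{token_id}]{text}"
    return string
```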

cebtenzzre (Contributor) replied on Oct 26, 2023:

Oops! Sorry, this one is my fault. llama_token_eos takes a model now, not a ctx, and I missed that when I merged this PR into my local branch. Unfortunately, ctypes gives no indication of pointer type mismatches. Normally I would use mypy, but it seems llama-cpp-python is not tested against it - there are many type errors and other complaints.
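
Under the newer llama.cpp API the EOB check would therefore need the model handle rather than the context; a rough sketch, assuming the bindings in use expose llama_get_model() (verify against your installed llama_cpp version):

```python
def is_at_eob(ctx, tokens, n_tokens):
    if n_tokens == 0:
        return False
    # Newer llama.cpp: llama_token_eos() takes a llama_model pointer, not a context.
    model = llama_cpp.llama_get_model(ctx)
    return tokens[n_tokens - 1] == llama_cpp.llama_token_eos(model)
```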

def beam_search_callback(callback_data_id, beams_state):
callback_data = beam_search_dictionary[callback_data_id]
for i in range(beams_state.n_beams):
beam_view = beams_state.beam_views[i]
if not beam_view.eob and is_at_eob(callback_data.ctx, beam_view.tokens, beam_view.n_tokens):
beam_view.eob = True; # Flag beams as EOB as required.
# Collect tokens into callback_data.response_tokens
if 0 < beams_state.common_prefix_length:
assert(0 < beams_state.n_beams);
tokens = ctypes.cast(beams_state.beam_views[0].tokens, ctypes.POINTER(ctypes.c_int * beams_state.common_prefix_length)).contents
callback_data.response_tokens.extend(tokens)

# DEBUG print beams and their relative probabilities
#print(f"\n\nCurrent beams (last_call={beams_state.last_call}):\n")
#for i in range(beams_state.n_beams):
# print(f"beams[{i}]", beam_view_to_string(callback_data.ctx,beams_state.beam_views[i]))

class Llama:
"""High-level Python wrapper for a llama.cpp model."""
@@ -494,6 +533,7 @@ def eval(self, tokens: Sequence[int]):
tokens: The list of tokens to evaluate.
"""
assert self.ctx is not None

n_ctx = self._n_ctx
for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)]
@@ -734,6 +774,7 @@ def generate(
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
grammar: Optional[LlamaGrammar] = None,
beam_width: int = 0,
) -> Generator[int, Optional[Sequence[int]], None]:
"""Create a generator of tokens from a prompt.

@@ -775,6 +816,26 @@
if grammar is not None:
grammar.reset()

if 0 < beam_width:
self.eval(tokens)
callback_data = beam_search_callback_data(self.ctx, [])
beam_search_dictionary[id(callback_data)] = callback_data
callback = llama_cpp.llama_beam_search_callback_fn_t(beam_search_callback)
n_remain = llama_cpp.llama_n_ctx(self.ctx) - self.n_tokens
llama_cpp.llama_beam_search(self.ctx, callback, id(callback_data),
beam_width,
self.n_tokens,
n_remain,
self.n_threads)
beam_search_dictionary.pop(id(callback_data))
# It would be nicer if we could yield from within the callback, but that is impossible.
for token in callback_data.response_tokens:
np.append(self.input_ids, [token])
np.append(self.scores, [0.0])
self.n_tokens += 1
yield token
return

while True:
self.eval(tokens)
token = self.sample(
@@ -791,6 +852,7 @@
logits_processor=logits_processor,
grammar=grammar,
)

if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :]
):
@@ -893,6 +955,7 @@ def _create_completion(
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
beam_width: int = 0,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None

@@ -971,6 +1034,7 @@
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
grammar=grammar,
beam_width=beam_width,
):
if token == self._token_eos:
text = self.detokenize(completion_tokens)
@@ -1354,6 +1418,7 @@ def create_completion(
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
beam_width: int = 0,
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.

@@ -1369,6 +1434,7 @@
repeat_penalty: The penalty to apply to repeated tokens.
top_k: The top-k value to use for sampling.
stream: Whether to stream the results.
beam_width: Number of beams to use in beam search. 0 disables.

Raises:
ValueError: If the requested tokens exceed the context window.
@@ -1398,7 +1464,8 @@ def create_completion(
model=model,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
grammar=grammar
grammar=grammar,
beam_width=beam_width,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1429,6 +1496,7 @@ def __call__(
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
beam_width: int = 0,
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.

@@ -1444,6 +1512,7 @@
repeat_penalty: The penalty to apply to repeated tokens.
top_k: The top-k value to use for sampling.
stream: Whether to stream the results.
beam_width: Number of beams to use in beam search. 0 disables.

Raises:
ValueError: If the requested tokens exceed the context window.
Expand Down Expand Up @@ -1474,6 +1543,7 @@ def __call__(
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
grammar=grammar,
beam_width=beam_width,
)

def _convert_text_completion_to_chat(
4 changes: 3 additions & 1 deletion llama_cpp/llama_cpp.py
@@ -1456,7 +1456,7 @@ class llama_beams_state(ctypes.Structure):
# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
def llama_beam_search(
ctx: llama_context_p,
callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore
callback: llama_beam_search_callback_fn_t,
callback_data: c_void_p,
n_beams: Union[c_size_t, int],
n_past: Union[c_int, int],
Expand All @@ -1467,6 +1467,8 @@ def llama_beam_search(
ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
)

_lib.llama_beam_search.argtypes = [llama_context_p, llama_beam_search_callback_fn_t, c_void_p, c_size_t, c_int, c_int, c_int]
_lib.llama_beam_search.restype = None

# Performance information

3 changes: 2 additions & 1 deletion llama_cpp/server/app.py
@@ -70,7 +70,7 @@ class Settings(BaseSettings):
default=True, description="if true, use experimental mul_mat_q kernels"
)
f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.")
logits_all: bool = Field(default=True, description="Whether to return logits.")
logits_all: bool = Field(default=False, description="Whether to return logits.")
vocab_only: bool = Field(
default=False, description="Whether to only return the vocabulary."
)
@@ -549,6 +549,7 @@ class CreateCompletionRequest(BaseModel):
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
beam_width: int = 0

model_config = {
"json_schema_extra": {
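
For reviewers trying the branch, a minimal usage sketch of the proposed beam_width parameter (the model path and prompt are placeholders, and output quality depends on the model):

```python
from llama_cpp import Llama

# Placeholder GGUF model path.
llm = Llama(model_path="./models/llama-2-7b.Q4_K_M.gguf")

# beam_width=0 (the default) keeps the existing sampling path;
# beam_width > 0 routes generate()/create_completion() through llama_beam_search().
output = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    beam_width=4,
)
print(output["choices"][0]["text"])
```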