diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 0c0d48fa7..ca83901f1 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -17,6 +17,7 @@
     Callable,
 )
 from collections import deque, OrderedDict
+from dataclasses import dataclass
 
 import diskcache
 import ctypes
@@ -205,6 +206,44 @@ def __call__(
     ) -> bool:
         return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
 
+# Custom data that is accessible to the beam_search_callback() function.
+@dataclass
+class beam_search_callback_data:
+    ctx: llama_cpp.llama_context_p
+    response_tokens: List[int]
+
+# Used for debugging to view beam states.
+def beam_view_to_string(ctx, beam_view):
+    string = f"p({beam_view.p}): "
+    for i in range(beam_view.n_tokens):
+        string += llama_cpp.llama_token_get_text(ctx, beam_view.tokens[i]).decode("utf-8")
+    return string
+
+# One requirement of the callback is that it MUST determine end-of-beam.
+def is_at_eob(ctx, tokens, n_tokens):
+    return 0 < n_tokens and tokens[n_tokens - 1] == llama_cpp.llama_token_eos(ctx)
+
+# beam_search_callback() requires a global dictionary to pass data via its object id.
+beam_search_dictionary = {}
+
+# beam_search_callback() must flag beams when they reach end-of-sentence.
+# TODO: Use stop_sequences.
+def beam_search_callback(callback_data_id, beams_state):
+    callback_data = beam_search_dictionary[callback_data_id]
+    for i in range(beams_state.n_beams):
+        beam_view = beams_state.beam_views[i]
+        if not beam_view.eob and is_at_eob(callback_data.ctx, beam_view.tokens, beam_view.n_tokens):
+            beam_view.eob = True  # Flag beams as EOB as required.
+    # Collect the tokens shared by all beams into callback_data.response_tokens.
+    if 0 < beams_state.common_prefix_length:
+        assert 0 < beams_state.n_beams
+        tokens = ctypes.cast(beams_state.beam_views[0].tokens, ctypes.POINTER(ctypes.c_int * beams_state.common_prefix_length)).contents
+        callback_data.response_tokens.extend(tokens)
+
+    # DEBUG: print beams and their relative probabilities.
+    # print(f"\n\nCurrent beams (last_call={beams_state.last_call}):\n")
+    # for i in range(beams_state.n_beams):
+    #     print(f"beams[{i}]", beam_view_to_string(callback_data.ctx, beams_state.beam_views[i]))
 
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -494,6 +533,7 @@ def eval(self, tokens: Sequence[int]):
             tokens: The list of tokens to evaluate.
         """
         assert self.ctx is not None
+        n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
@@ -734,6 +774,7 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -775,6 +816,26 @@
         if grammar is not None:
             grammar.reset()
 
+        if 0 < beam_width:
+            self.eval(tokens)
+            callback_data = beam_search_callback_data(self.ctx, [])
+            beam_search_dictionary[id(callback_data)] = callback_data
+            callback = llama_cpp.llama_beam_search_callback_fn_t(beam_search_callback)
+            n_remain = llama_cpp.llama_n_ctx(self.ctx) - self.n_tokens
+            llama_cpp.llama_beam_search(self.ctx, callback, id(callback_data),
+                                        beam_width,
+                                        self.n_tokens,
+                                        n_remain,
+                                        self.n_threads)
+            beam_search_dictionary.pop(id(callback_data))
+            # It would be nicer if we could yield from within the callback, but that is impossible.
+            for token in callback_data.response_tokens:
+                self.input_ids[self.n_tokens] = token  # np.append would return a copy, not update the preallocated buffer
+                self.scores[self.n_tokens, :] = 0.0
+                self.n_tokens += 1
+                yield token
+            return
+
         while True:
             self.eval(tokens)
             token = self.sample(
@@ -791,6 +852,7 @@
                 logits_processor=logits_processor,
                 grammar=grammar,
             )
+
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids, self._scores[-1, :]
             ):
@@ -893,6 +955,7 @@ def _create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 
@@ -971,6 +1034,7 @@ def _create_completion(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            beam_width=beam_width,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -1354,6 +1418,7 @@ def create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1369,6 +1434,7 @@ def create_completion(
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -1398,7 +1464,8 @@ def create_completion(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
-            grammar=grammar
+            grammar=grammar,
+            beam_width=beam_width,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1429,6 +1496,7 @@ def __call__(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1444,6 +1512,7 @@ def __call__(
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -1474,6 +1543,7 @@ def __call__(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            beam_width=beam_width,
         )
 
     def _convert_text_completion_to_chat(
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 53298df15..21c8da759 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -1456,7 +1456,7 @@ class llama_beams_state(ctypes.Structure):
 # LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
 def llama_beam_search(
     ctx: llama_context_p,
-    callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]",  # type: ignore
+    callback: llama_beam_search_callback_fn_t,
     callback_data: c_void_p,
     n_beams: Union[c_size_t, int],
     n_past: Union[c_int, int],
@@ -1467,6 +1467,8 @@ def llama_beam_search(
         ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
     )
 
+_lib.llama_beam_search.argtypes = [llama_context_p, llama_beam_search_callback_fn_t, c_void_p, c_size_t, c_int, c_int, c_int]
+_lib.llama_beam_search.restype = None
 
 
 # Performance information
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 053c3081b..27ece16aa 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -70,7 +70,7 @@ class Settings(BaseSettings):
         default=True, description="if true, use experimental mul_mat_q kernels"
     )
     f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.")
-    logits_all: bool = Field(default=True, description="Whether to return logits.")
+    logits_all: bool = Field(default=False, description="Whether to return logits.")
     vocab_only: bool = Field(
         default=False, description="Whether to only return the vocabulary."
    )
@@ -549,6 +549,7 @@ class CreateCompletionRequest(BaseModel):
     top_k: int = top_k_field
     repeat_penalty: float = repeat_penalty_field
     logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
+    beam_width: int = 0
 
     model_config = {
         "json_schema_extra": {
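Usage note for reviewers: a minimal sketch of how the new `beam_width` parameter would be exercised through the high-level API once this patch is applied. The model path and prompt below are placeholders, not part of the patch.

```python
from llama_cpp import Llama

# Placeholder path; any local GGUF model works here.
llm = Llama(model_path="./models/model.gguf")

# beam_width > 0 routes generate() through llama.cpp's beam search and yields
# the tokens shared by all beams; beam_width=0 keeps the existing sampling path.
output = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    beam_width=4,
)
print(output["choices"][0]["text"])
```

Since `CreateCompletionRequest` also gains a `beam_width` field, the same option should be reachable from the server by adding `"beam_width": 4` to the JSON body of a completion request.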