diff --git a/docs/api-reference.md b/docs/api-reference.md index 562410fe1..47269a6b5 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -28,16 +28,6 @@ High-level Python bindings for llama.cpp. - token_eos show_root_heading: true -::: llama_cpp.LlamaGrammar - options: - members: - - from_string - - from_json_schema - -::: llama_cpp.LlamaCache - options: - show_root_heading: true - ::: llama_cpp.LlamaState options: show_root_heading: true @@ -58,6 +48,13 @@ High-level Python bindings for llama.cpp. options: show_root_heading: true +::: llama_cpp.LlamaGrammar + options: + members: + - from_string + - from_json_schema + + ## Low Level API Low-level Python bindings for llama.cpp using Python's ctypes library. diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 208de8c2a..66b0b54e1 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -767,4 +767,4 @@ def sample( def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool): if apply_grammar and self.grammar is not None: ctx_main.grammar_accept_token(self.grammar, id) - self.prev.append(id) \ No newline at end of file + self.prev.append(id) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 25abf36cb..972014493 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -19,6 +19,9 @@ import ctypes +import numpy as np +import numpy.typing as npt + from .llama_types import * from .llama_grammar import LlamaGrammar from .llama_cache import ( @@ -30,16 +33,16 @@ import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_chat_format as llama_chat_format -import numpy as np -import numpy.typing as npt - -from ._utils import suppress_stdout_stderr from ._internals import ( _LlamaModel, # type: ignore _LlamaContext, # type: ignore _LlamaBatch, # type: ignore _LlamaTokenDataArray, # type: ignore + _LlamaSamplingParams, # type: ignore + _LlamaSamplingContext, # type: ignore + ) +from ._utils import suppress_stdout_stderr class Llama: @@ -468,77 +471,39 @@ def sample( """ assert self._ctx is not None assert self.n_tokens > 0 - last_n_tokens_data = [llama_cpp.llama_token(0)] * max( - 0, self.last_n_tokens_size - self.n_tokens - ) + self._input_ids[-self.last_n_tokens_size :].tolist() - last_n_tokens_size = len(last_n_tokens_data) - n_vocab = self._n_vocab - n_ctx = self._n_ctx - top_k = n_vocab if top_k <= 0 else top_k - last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size - last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)( - *last_n_tokens_data - ) + logits: npt.NDArray[np.single] = self._scores[-1, :] if logits_processor is not None: logits[:] = logits_processor(self._input_ids, logits) - nl_logit = logits[self._token_nl] - self._candidates.copy_logits(logits) - self._ctx.sample_repetition_penalties( - candidates=self._candidates, - last_tokens_data=last_n_tokens_data_c, - penalty_last_n=last_n_tokens_size, + sampling_params = _LlamaSamplingParams( + top_k=top_k, + top_p=top_p, + min_p=min_p, + tfs_z=tfs_z, + typical_p=typical_p, + temp=temp, + penalty_last_n=self.last_n_tokens_size, penalty_repeat=repeat_penalty, penalty_freq=frequency_penalty, penalty_present=presence_penalty, + mirostat=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + penalize_nl=penalize_nl, + ) + sampling_context = _LlamaSamplingContext( + params=sampling_params, + grammar=grammar, + ) + sampling_context.prev = list(self.eval_tokens) + id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits) + sampling_context.accept( + 
ctx_main=self._ctx, + id=id, + apply_grammar=grammar is not None, ) - if not penalize_nl: - self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float( - nl_logit - ) - - if grammar is not None: - self._ctx.sample_grammar( - candidates=self._candidates, - grammar=grammar, - ) - - if temp < 0.0: - self._ctx.sample_softmax(candidates=self._candidates) - id = self._candidates.candidates.data[0].id - elif temp == 0.0: - id = self._ctx.sample_token_greedy(candidates=self._candidates) - elif mirostat_mode == 1: - self._ctx.sample_temp(candidates=self._candidates, temp=temp) - id = self._ctx.sample_token_mirostat( - candidates=self._candidates, - tau=mirostat_tau, - eta=mirostat_eta, - mu=2.0 * mirostat_tau, - m=100, - ) - elif mirostat_mode == 2: - self._ctx.sample_temp(candidates=self._candidates, temp=temp) - id = self._ctx.sample_token_mirostat_v2( - candidates=self._candidates, - tau=mirostat_tau, - eta=mirostat_eta, - mu=2.0 * mirostat_tau, - ) - else: - self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1) - self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1) - self._ctx.sample_typical( - candidates=self._candidates, p=typical_p, min_keep=1 - ) - self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1) - self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1) - self._ctx.sample_temp(candidates=self._candidates, temp=temp) - id = self._ctx.sample_token(candidates=self._candidates) - if grammar is not None: - self._ctx.grammar_accept_token(grammar=grammar, token=id) return id def generate( @@ -826,6 +791,51 @@ def logit_bias_processor( if seed is not None: self._ctx.set_rng_seed(seed) + def _completion_stream_response( + text: str, + logprobs_or_none: Optional[CompletionLogprobs] = None, + finish_reason: Optional[Literal["stop", "length"]] = None, + ) -> CreateCompletionStreamResponse: + return { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": text, + "index": 0, + "logprobs": logprobs_or_none, + "finish_reason": finish_reason, + } + ], + } + + def _completion_response( + text: str, + finish_reason: Literal["stop", "length"], + logprobs_or_none: Optional[CompletionLogprobs] = None, + ) -> CreateCompletionResponse: + return { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": text, + "index": 0, + "logprobs": logprobs_or_none, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": len(prompt_tokens), + "completion_tokens": len(completion_tokens), + "total_tokens": len(prompt_tokens) + len(completion_tokens), + }, + } + finish_reason = "length" multibyte_fix = 0 for token in self.generate( @@ -876,7 +886,7 @@ def logit_bias_processor( break if stream: - remaining_tokens = completion_tokens[returned_tokens:] + remaining_tokens = completion_tokens[returned_tokens:-1] remaining_text = self.detokenize(remaining_tokens) remaining_length = len(remaining_text) @@ -915,10 +925,10 @@ def logit_bias_processor( ) token_offset = len(prompt_tokens) + returned_tokens logits = self._scores[token_offset - 1, :] - current_logprobs = Llama.logits_to_logprobs(logits).tolist() + token_logprob = Llama.logits_to_logprobs(logits).tolist() sorted_logprobs = list( sorted( - zip(current_logprobs, range(len(current_logprobs))), + zip(token_logprob, range(len(token_logprob))), reverse=True, ) ) @@ -928,7 +938,7 @@ def logit_bias_processor( ): logprob for 
logprob, i in sorted_logprobs[:logprobs] } - top_logprob.update({token_str: current_logprobs[int(token)]}) + top_logprob.update({token_str: token_logprob[int(token)]}) logprobs_or_none = { "tokens": [ self.detokenize([token]).decode( @@ -936,26 +946,14 @@ def logit_bias_processor( ) ], "text_offset": [text_offset], - "token_logprobs": [current_logprobs[int(token)]], + "token_logprobs": [token_logprob[int(token)]], "top_logprobs": [top_logprob], } returned_tokens += 1 - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": self.detokenize([token]).decode( - "utf-8", errors="ignore" - ), - "index": 0, - "logprobs": logprobs_or_none, - "finish_reason": None, - } - ], - } + yield _completion_stream_response( + self.detokenize([token]).decode("utf-8", errors="ignore"), + logprobs_or_none, + ) else: while len(remaining_tokens) > 0: decode_success = False @@ -980,20 +978,7 @@ def logit_bias_processor( remaining_tokens = remaining_tokens[i:] returned_tokens += i - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": ts, - "index": 0, - "logprobs": None, - "finish_reason": None, - } - ], - } + yield _completion_stream_response(text=ts) if len(completion_tokens) >= max_tokens: text = self.detokenize(completion_tokens) @@ -1034,10 +1019,10 @@ def logit_bias_processor( ) token_offset = len(prompt_tokens) + returned_tokens - 1 logits = self._scores[token_offset, :] - current_logprobs = Llama.logits_to_logprobs(logits).tolist() + token_logprob = Llama.logits_to_logprobs(logits).tolist() sorted_logprobs = list( sorted( - zip(current_logprobs, range(len(current_logprobs))), + zip(token_logprob, range(len(token_logprob))), reverse=True, ) ) @@ -1045,13 +1030,11 @@ def logit_bias_processor( self.detokenize([i]).decode("utf-8", errors="ignore"): logprob for logprob, i in sorted_logprobs[:logprobs] } - top_logprob.update({token_str: current_logprobs[int(token)]}) + top_logprob.update({token_str: token_logprob[int(token)]}) logprobs_or_none = { - "tokens": [ - self.detokenize([token]).decode("utf-8", errors="ignore") - ], + "tokens": [token_str], "text_offset": [text_offset], - "token_logprobs": [current_logprobs[int(token)]], + "token_logprobs": [token_logprob[int(token)]], "top_logprobs": [top_logprob], } @@ -1060,54 +1043,34 @@ def logit_bias_processor( if token_end_position == end - 1: break returned_tokens += 1 - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": last_text[ - : len(last_text) - (token_end_position - end) - ].decode("utf-8", errors="ignore"), - "index": 0, - "logprobs": logprobs_or_none, - "finish_reason": None, - } - ], - } - break + yield _completion_stream_response( + text=last_text[ + : len(last_text) - (token_end_position - end) + ].decode("utf-8", errors="ignore"), + logprobs_or_none=logprobs_or_none, + finish_reason=finish_reason, + ) + if self.cache: + if self.verbose: + print( + "Llama._create_completion: cache save", file=sys.stderr + ) + self.cache[ + prompt_tokens + completion_tokens + ] = self.save_state() + print("Llama._create_completion: cache saved", file=sys.stderr) + return returned_tokens += 1 - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": self.detokenize([token]).decode( - "utf-8", errors="ignore" - ), - "index": 0, - "logprobs": 
logprobs_or_none, - "finish_reason": None, - } - ], - } - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": "", - "index": 0, - "logprobs": None, - "finish_reason": finish_reason, - } - ], - } + yield _completion_stream_response( + text=self.detokenize([token]).decode("utf-8", errors="ignore"), + logprobs_or_none=logprobs_or_none, + ) + yield _completion_stream_response( + text=self.detokenize(completion_tokens[returned_tokens:]).decode( + "utf-8", errors="ignore" + ), + finish_reason=finish_reason, + ) if self.cache: if self.verbose: print("Llama._create_completion: cache save", file=sys.stderr) @@ -1123,6 +1086,7 @@ def logit_bias_processor( text_str = text.decode("utf-8", errors="ignore") if echo: + assert isinstance(prompt, str) text_str = prompt + text_str if suffix is not None: @@ -1130,19 +1094,16 @@ def logit_bias_processor( logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: + # Remove leading BOS token + all_tokens = ( + prompt_tokens[1:] + completion_tokens if echo else completion_tokens + ) text_offset = 0 if echo else len(prompt) token_offset = 0 if echo else len(prompt_tokens[1:]) text_offsets: List[int] = [] token_logprobs: List[Optional[float]] = [] tokens: List[str] = [] top_logprobs: List[Optional[Dict[str, float]]] = [] - - if echo: - # Remove leading BOS token - all_tokens = prompt_tokens[1:] + completion_tokens - else: - all_tokens = completion_tokens - all_token_strs = [ self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens @@ -1179,34 +1140,10 @@ def logit_bias_processor( # token_logprobs and top_logprobs are null for # the first token. if echo and len(all_tokens) > 0: - token_logprobs[0] = None - top_logprobs[0] = None - logprobs_or_none = { - "tokens": tokens, - "text_offset": text_offsets, - "token_logprobs": token_logprobs, - "top_logprobs": top_logprobs, - } + logprobs_or_none["token_logprobs"][0] = None + logprobs_or_none["top_logprobs"][0] = None - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": text_str, - "index": 0, - "logprobs": logprobs_or_none, - "finish_reason": finish_reason, - } - ], - "usage": { - "prompt_tokens": len(prompt_tokens), - "completion_tokens": len(completion_tokens), - "total_tokens": len(prompt_tokens) + len(completion_tokens), - }, - } + yield _completion_response(text_str, finish_reason, logprobs_or_none) def create_completion( self, diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 0ef7bd4a8..b5d490950 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -260,13 +260,25 @@ def _convert_text_completion_chunks_to_chat( "index": 0, "delta": { "content": chunk["choices"][0]["text"], - } - if chunk["choices"][0]["finish_reason"] is None - else {}, - "finish_reason": chunk["choices"][0]["finish_reason"], + }, + "finish_reason": None, } ], } + if chunk["choices"][0]["finish_reason"] is not None: + yield { + "id": "chat" + chunk["id"], + "model": chunk["model"], + "created": chunk["created"], + "object": "chat.completion.chunk", + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": chunk["choices"][0]["finish_reason"], + } + ], + } def _convert_completion_to_chat( diff --git a/pyproject.toml b/pyproject.toml index 806127d89..413097201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,6 @@ license = { text = "MIT" } authors = [ { 
name = "Andrei Betlen", email = "abetlen@gmail.com" }, ] -# mkdocs-martiral requires "jinja2~=3.0" -# transformers requires "jinja2>=2.11.3" dependencies = [ "typing-extensions>=4.5.0", "numpy>=1.20.0",