From b309fa4676eeb023d005d727d49a4c56f3d3f9ce Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 22 Sep 2023 10:53:36 -0400 Subject: [PATCH] [TextGeneration] Update pipeline inputs to support GenerationConfig (#1250) * add streaming functionality * set back default value * update pipeline.py * update tests * fix tests * update pipeline to use kwargs * add TODO statements * add streaming functionality * Update pipeline inputs to support GenerationConfig * add max_new_tokens * remove todo * update post local test runs * remove todo missed from rebase * refactor to use helpers, update reference to generation config variables * update helper functions to include all generation config handling and overriding * fix tests * update to work with new session commit * update to use config * cleanup --- .../transformers/pipelines/text_generation.py | 192 +++++++++++------- src/deepsparse/transformers/utils/helpers.py | 122 ++++++++++- .../pipelines/test_text_generation.py | 80 ++++---- 3 files changed, 283 insertions(+), 111 deletions(-) diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 9ee5ba7331..300979d1f0 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -15,6 +15,7 @@ import datetime import logging import os +import pathlib import warnings from enum import Enum from typing import ( @@ -33,6 +34,7 @@ import numpy import onnx from pydantic import BaseModel, Field +from transformers import GenerationConfig from deepsparse import Pipeline from deepsparse.pipeline import DEEPSPARSE_ENGINE @@ -40,10 +42,13 @@ from deepsparse.transformers.pipelines import TransformersPipeline from deepsparse.transformers.utils import DecoderKVCache from deepsparse.transformers.utils.helpers import ( + check_and_return_generation_config, create_causal_mask, initialize_kv_cache_state, + override_config, pad_to_fixed_length, prepends_bos_token, + process_generation_config, repeat_inputs, ) from deepsparse.transformers.utils.timings import TextGenerationTimings @@ -56,6 +61,16 @@ __all__ = ["TextGenerationPipeline"] +class GenerationDefaults: + num_return_sequences = 1 + max_length = 1024 + max_new_tokens = None + output_scores = False + top_k = 0 + top_p = 0.0 + repetition_penalty = 0.0 + + class FinishReason(Enum): STOP = "stop" LENGTH = "length" @@ -70,33 +85,14 @@ class Config: sequences: Union[str, List[str]] = Field( description="The input sequences to generate the text from.", ) - num_generated_predictions: int = Field( - default=1, - description="The number of text generations to create from a single prompt. If " - "the same sequence is given as an input multiple times, the number of generated" - "the number of generated predictins is equivalent to the number of times the " - "the sequence is repeated.", - ) - max_tokens: int = Field( - default=1024, - description="Maximum number of tokens to generate per output sequence. If no " - "value is provided, will default to 1024.", - ) - return_logits: bool = Field( - default=False, - description="A flag that indicates whether to return " - "the logits for the input text sequence and the " - "generated text sequence. ", - ) include_prompt_logits: bool = Field( default=False, description="A flag that indicates whether to return " "the logits for the prompt. If set, prompt_logits are " "`prepended` to the logits for the generated text sequence." 
- "Note: This flag is only applicable when return_logits " + "Note: This flag is only applicable when output_scores " "is `True`.", ) - fixed_sequences_length: bool = Field( default=False, description="A flag that indicates whether to modify " @@ -126,28 +122,27 @@ class Config: " tokens is generated). Set to `None` to ignore this parameter." " Default is `None`.", ) - top_p: Optional[float] = Field( - default=0.0, - description="Used for filtering generated tokens. Keep the" - " tokens where its cumulative probability is >= top_p" - " Default set to 0.0", - ) - top_k: Optional[int] = Field( - default=0, - description="Used for filtering generated tokens. Keep" - " top_k generated tokens. Default set to 0", - ) + presence_penalty: Optional[float] = Field( default=0.0, description="Penalty applied for generating new token. Any existing" " token results in the subtraction of its corresponding logit value." " Default set to 0.0", ) - frequency_penalty: Optional[float] = Field( - default=0.0, - description="Penalty applied for generating new token. Existing" - " token frequencies summed to subtraction the logit of its" - " corresponding logit value. Default set to 0.0.", + + generation_config: Union[None, str, pathlib.Path, Dict, GenerationConfig] = Field( + default=None, + description="GenerationConfig file consisting of parameters used to control " + "sequences generated for each prompt. The current supported parameters are: " + "max_length, max_new_tokens, num_return_sequences, output_scores, top_p, " + "top_k, repetition_penalty.", + ) + + kwargs: Optional[Dict] = Field( + default=None, + description="Any arguments to override generation_config arguments. Refer to " + "the generation_config argument for a full list of supported variables. Only " + "valid when generation_config is not None.", ) @@ -219,9 +214,10 @@ def __init__( deterministic: bool = True, sampling_temperature: float = 1.0, prompt_sequence_length: int = 64, - sequence_length: int = 512, + sequence_length: int = 1024, force_max_tokens: bool = False, internal_kv_cache: bool = True, + generation_config: Union[str, pathlib.Path, Dict, GenerationConfig] = None, **kwargs, ): kwargs_engine_type = kwargs.get("engine_type", DEEPSPARSE_ENGINE) @@ -271,6 +267,12 @@ def __init__( # auxiliary flag for devs to enable debug mode for the pipeline self._debug = False + self.generation_config = process_generation_config(generation_config) + if self.generation_config: + _LOGGER.info( + "Generation config provided for pipline. This will be used " + "for all inputs unless an input-specific config is provided. " + ) def initialize_engines( self, @@ -410,22 +412,29 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]: :param inputs: the input schema for the pipeline :return: the inputs for the engine """ - if not self.cache_support_enabled and inputs.max_tokens > 1: + generation_config = check_and_return_generation_config( + self.generation_config, inputs.generation_config, GenerationDefaults() + ) + + generation_config = override_config(inputs.kwargs, generation_config) + + self.streaming = inputs.streaming + if not self.cache_support_enabled and generation_config.max_length > 1: raise ValueError( "The model used for inference does not support kv cache. It is " "assumed that it maps from the token sequence to predicted logits." - "Set `max_tokens` to 1 to support that scenario." + "Set `max_length` to 1 to support that scenario." 
) - # If the num_generated_predictions > 1, repeat the prompt - # num_generated_predictions times. Also, update the engine so that deterministic + # If the num_return_sequences > 1, repeat the prompt + # num_return_sequences times. Also, update the engine so that deterministic # is set to False. original_inputs = inputs.sequences - if inputs.num_generated_predictions > 1: + if generation_config.num_return_sequences > 1: if isinstance(inputs.sequences, str): inputs.sequences = [inputs.sequences] inputs.sequences = repeat_inputs( - inputs.sequences, inputs.num_generated_predictions + inputs.sequences, generation_config.num_return_sequences ) if self.engine: self.engine.deterministic = False @@ -474,16 +483,14 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]: context = dict( prompts=original_inputs, streaming=inputs.streaming, - num_generated_predictions=inputs.num_generated_predictions, - return_logits=inputs.return_logits, + generation_config=generation_config, include_prompt_logits=inputs.include_prompt_logits, callback=inputs.callback, stop=inputs.stop, - top_p=inputs.top_p, - top_k=inputs.top_k, + top_p=generation_config.top_p, + top_k=generation_config.top_k, presence_penalty=inputs.presence_penalty, - frequency_penalty=inputs.frequency_penalty, - max_tokens=inputs.max_tokens, + frequency_penalty=generation_config.repetition_penalty, ) return engine_input, context @@ -532,22 +539,48 @@ def process_engine_outputs( :return: the output schema for the pipeline """ + def _create_generated_text_output( + sequence: str, + finish_reason: FinishReason = None, + logits: Optional[numpy.array] = None, + ): + if finish_reason: + return GeneratedText( + text=sequence, + score=logits, + finished=True, + finished_reason=finish_reason.value, + ) + return GeneratedText( + text=sequence, + score=logits, + finished=False, + ) + + generation_config = kwargs.get("generation_config") prompts = kwargs.get("prompts") streaming = kwargs.get("streaming") if streaming: return self._stream_engine_outputs(engine_outputs, prompts, kwargs) - generated_tokens, generated_logits, finished_reason, *debug = list( - *engine_outputs - ) + if self._debug: + ( + generated_tokens, + generated_logits, + finished_reason, + kv_cache_state, + total_num_processed_tokens, + ) = list(*engine_outputs) + else: + generated_tokens, generated_logits, finished_reason = list(*engine_outputs) sequences = self.tokenizer.batch_decode( generated_tokens, skip_special_tokens=True ) - logits = generated_logits if kwargs.get("return_logits") else None + logits = generated_logits if generation_config.output_scores else None - num_preds = kwargs.get("num_generated_predictions", 1) + num_preds = generation_config.num_return_sequences finished_reason = [f[0] for f in finished_reason] if logits is not None: @@ -566,7 +599,7 @@ def process_engine_outputs( ) ) - # If the num_generated_predictions > 1, group the generations and return + # If the num_return_sequences > 1, group the generations and return # them as a list of lists where each list consists of the generated # predictions for a given prompt, and all the lists are in the order matching # the order that the prompts were given as inputs. 
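The regrouping described in the comment above amounts to slicing the flat list of generations into per-prompt chunks. A minimal sketch, assuming generations is the flat decoded output list and num_preds is generation_config.num_return_sequences (the names are illustrative, not the pipeline's exact variables):

from typing import Any, List

def group_generations(generations: List[Any], num_preds: int) -> List[List[Any]]:
    # One sub-list per prompt, each holding num_preds generations,
    # preserving the order in which the prompts were supplied.
    return [
        generations[i : i + num_preds]
        for i in range(0, len(generations), num_preds)
    ]

# e.g. two prompts with num_return_sequences=2 -> [[g0, g1], [g2, g3]]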
@@ -581,8 +614,7 @@ def process_engine_outputs( created=datetime.datetime.now(), prompts=prompts, generations=generations ) - if debug: - kv_cache_state, total_num_processed_tokens = debug + if self._debug: debug_params = dict( kv_cache_state=kv_cache_state, total_num_processed_tokens=total_num_processed_tokens, @@ -614,6 +646,7 @@ def engine_forward( with self.timer_manager.new_timer_context(total_inference=False) as timer: finished_reason = [] streaming = context.get("streaming") + generation_config = context.get("generation_config") if not self.cache_support_enabled: prompt_logits = self.multitoken_engine(engine_inputs) @@ -640,10 +673,6 @@ def engine_forward( ) token_generator.generate(prompt_logits[-1][0, -1, :]) - # create the generated output - max_tokens = context.get("max_tokens", 0) - max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length) - # last prompt token is the first generated token # add it to generated tokens, and the logits generated_tokens = [token_generator.tokens[-1]] @@ -655,6 +684,15 @@ def engine_forward( callback = context.get("callback") stop = context.get("stop") + max_new_tokens = generation_config.max_new_tokens + if max_new_tokens: + max_tokens = max_new_tokens + len(generated_tokens) + else: + max_tokens = generation_config.max_length + max_tokens = ( + max_tokens if max_tokens > 0 else (100 * self.sequence_length) + ) + with timer.time(TextGenerationTimings.TOKEN_GENERATION): while len(generated_tokens) < max_tokens: with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE): @@ -702,14 +740,19 @@ def engine_forward( ) if not streaming: - returns = ( - numpy.array([generated_tokens]), - numpy.concatenate(generated_logits, axis=1), - finished_reason, - ) - - if self._debug is True: - yield *returns, session + if self._debug: + returns = ( + numpy.array([generated_tokens]), + numpy.concatenate(generated_logits, axis=1), + finished_reason, + [session], + ) + else: + returns = ( + numpy.array([generated_tokens]), + numpy.concatenate(generated_logits, axis=1), + finished_reason, + ) yield returns @@ -921,7 +964,12 @@ def join_engine_outputs( yield outputs else: batch_outputs = [list(*b) for b in batch_outputs] - tokens, logits, finish_reason, *debug = zip(*batch_outputs) + if self._debug: + tokens, logits, finish_reason, debug = zip(*batch_outputs) + else: + tokens, logits, finish_reason = zip(*batch_outputs) + debug = None + if self.cache_support_enabled: # if the model has kv cache, we need to account for # the fact that the predicted outputs may have @@ -969,8 +1017,8 @@ def join_engine_outputs( kv_cache_state, num_processed_tokens, ] - - yield [tokens, logits, finish_reason] + else: + yield [tokens, logits, finish_reason] @staticmethod def causal_mask_input_present(model_path: str) -> bool: diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py index 57d09e309e..0b769c254c 100644 --- a/src/deepsparse/transformers/utils/helpers.py +++ b/src/deepsparse/transformers/utils/helpers.py @@ -11,12 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
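The stopping budget introduced in engine_forward above can be summarized as follows; this is a simplified restatement of the hunk, with illustrative argument names rather than the pipeline's internals:

def max_tokens_budget(
    max_new_tokens, max_length, tokens_generated_so_far, sequence_length
):
    # max_new_tokens counts only newly generated tokens, so the budget is
    # offset by what has already been emitted (the first generated token).
    if max_new_tokens:
        return max_new_tokens + tokens_generated_so_far
    # Otherwise fall back to max_length; a non-positive value effectively
    # removes the limit by using a large cap tied to the sequence length.
    return max_length if max_length > 0 else 100 * sequence_length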
+import json
 import logging
+import os
+import pathlib
 import uuid
 from typing import Dict, List, Optional, Tuple, Union
 import numpy
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, GenerationConfig
 from deepsparse.utils.onnx import CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX
@@ -28,6 +31,9 @@
     "repeat_inputs",
     "initialize_kv_cache_state",
     "prepends_bos_token",
+    "check_and_return_generation_config",
+    "override_config",
+    "process_generation_config",
 ]
 _LOGGER = logging.getLogger(__name__)
@@ -108,6 +114,120 @@ def generate_session_id() -> str:
     return session_id
+
+def process_generation_config(
+    generation_config: Union[None, str, pathlib.Path, Dict, GenerationConfig]
+) -> Union[GenerationConfig, None]:
+    """
+    Process and return a GenerationConfig. The function can take in a filepath
+    pointing to a json consisting of the config values, a dictionary with the config
+    values, or a loaded GenerationConfig object. If None is given, the pipeline
+    GenerationConfig is used, if provided. If both are None, defaults are used
+    for generation.
+
+    :param generation_config: either a json filepath, dictionary or loaded
+        GenerationConfig object
+
+    :return: GenerationConfig object or None
+
+    """
+    if isinstance(generation_config, GenerationConfig):
+        return generation_config
+
+    if not generation_config:
+        return None
+
+    # TODO: move to tmp folder
+    if isinstance(generation_config, dict):
+        config_dir = os.getcwd()
+        config_name = "generation_config.json"
+        local_config_path = os.path.join(config_dir, config_name)
+        _LOGGER.info(
+            "Dictionary provided for the generation config. Creating temporary "
+            " generation_config.json"
+        )
+        with open(local_config_path, "w") as f:
+            json.dump(generation_config, f)
+
+    if isinstance(generation_config, (str, pathlib.Path)):
+        generation_config = pathlib.Path(generation_config)
+        config_dir = generation_config.parent.absolute()
+        config_name = generation_config.name
+
+    generation_config = GenerationConfig.from_pretrained(config_dir, config_name)
+    return generation_config
+
+
+def check_and_return_generation_config(
+    pipeline_generation_config: Union[None, str, pathlib.Path, Dict, GenerationConfig],
+    input_generation_config: Union[None, str, pathlib.Path, Dict, GenerationConfig],
+    defaults: "GenerationDefaults",  # noqa F821
+) -> Union[GenerationConfig, None]:
+    """
+    Check if an input generation config is provided. If not, check if a pipeline
+    generation config exists. If neither exists, use the default generation configs,
+    either deepsparse defaults or Hugging Face defaults. If a pipeline config exists
+    and an input config exists, use the input config.
+
+    :param pipeline_generation_config: either a json filepath, dictionary or loaded
+        GenerationConfig object provided by the user during pipeline creation
+    :param input_generation_config: either a json filepath, dictionary or loaded
+        GenerationConfig object provided by the user during inference
+    :param defaults: defaults to use for the GenerationConfig if a config is not
+        provided during inference or pipeline creation.
+
+    :return: GenerationConfig object or None
+
+    """
+    generation_config = process_generation_config(input_generation_config)
+    if generation_config is None:
+        if pipeline_generation_config:
+            generation_config = pipeline_generation_config
+    else:
+        _LOGGER.info(
+            "Input generation config detected. This will override any"
+            " config provided during pipeline creation."
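A brief usage sketch of process_generation_config showing the three accepted forms; the file path and values below are placeholders, not values from this patch:

import pathlib

from transformers import GenerationConfig

from deepsparse.transformers.utils.helpers import process_generation_config

# 1. An existing GenerationConfig is returned unchanged.
cfg = process_generation_config(GenerationConfig(max_length=256))

# 2. A dict is written to a temporary generation_config.json and reloaded.
cfg = process_generation_config({"max_length": 256, "output_scores": True})

# 3. A path to a generation_config.json on disk is loaded via from_pretrained.
cfg = process_generation_config(pathlib.Path("generation_config.json"))  # placeholder path

# None falls through so that check_and_return_generation_config can apply the
# pipeline-level config or the GenerationDefaults instead.
assert process_generation_config(None) is None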
+ ) + + if not generation_config: + _LOGGER.info(" No GenerationConfig detected. Using GenerationDefaults values") + generation_config = defaults + return generation_config + + +def override_config( + overrides: Optional[Dict], generation_config: GenerationConfig +) -> GenerationConfig: + """ + Override any generation config properties using the `kwargs` argument in + TextGenerationInput. If None, the generation config is returned unchanged. + If provided, update all attribute stored in the dictionary. An errror will be + raised if the dictionary contains an key which is not a GenerationConfig + attribute. + + :param overrides: dictionary containing GenerationConfig attributes to update + :param generation_config: GenerationConfig to update + + :return: GenerationConfig object + + + """ + if overrides is None: + return generation_config + + for k, v in overrides.items(): + try: + if getattr(generation_config, k): + setattr(generation_config, k, v) + _LOGGER.info(f"Overriding attribute {k} in the generation config") + except AttributeError as exception: + raise AttributeError( + "Argument provided for GenerationConfig is not " + "valid. Refer to the TextGenerationInput for supported attributes. " + ) from exception + + return generation_config + + def repeat_inputs( input_sequences: List[str], num_generated_predictions: int ) -> List[str]: diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 2f96d422b2..ad64dae3f3 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -15,6 +15,7 @@ from typing import List, Optional, Tuple import numpy +from transformers import GenerationConfig import pytest from deepsparse import Pipeline @@ -175,11 +176,13 @@ def test_ort_single_token_prefill(self, setup): engine_type="onnxruntime", ) pipeline._debug = True + + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate + ) + output = pipeline( - sequences=self.prompt, - return_logits=True, - include_prompt_logits=True, - max_tokens=self.num_tokens_generate, + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( @@ -207,11 +210,11 @@ def test_ort_multi_token_prefill(self, setup): engine_type="onnxruntime", ) pipeline._debug = True + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate + ) output = pipeline( - sequences=self.prompt, - return_logits=True, - include_prompt_logits=True, - max_tokens=self.num_tokens_generate, + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) assert output.total_num_processed_tokens[0] < self.sequence_length @@ -241,11 +244,12 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): engine_type="onnxruntime", ) pipeline._debug = True + + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate + ) output = pipeline( - sequences=self.prompt, - return_logits=True, - include_prompt_logits=True, - max_tokens=self.num_tokens_generate, + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( @@ -276,11 +280,11 @@ def test_deepsparse_single_token_prefill(self, setup): internal_kv_cache=self.internal_kv_cache, ) pipeline._debug = True + config = GenerationConfig( + 
output_scores=True, max_length=self.num_tokens_generate + ) output = pipeline( - sequences=self.prompt, - return_logits=True, - include_prompt_logits=True, - max_tokens=self.num_tokens_generate, + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) assert output.total_num_processed_tokens[0] < self.sequence_length @@ -307,11 +311,11 @@ def test_deepsparse_multi_token_prefill(self, setup): ) pipeline._debug = True + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate + ) output = pipeline( - sequences=self.prompt, - return_logits=True, - include_prompt_logits=True, - max_tokens=self.num_tokens_generate, + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) assert output.total_num_processed_tokens[0] < self.sequence_length @@ -337,11 +341,11 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): internal_kv_cache=self.internal_kv_cache, ) pipeline._debug = True + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate + ) output = pipeline( - sequences=self.prompt, - return_logits=True, - include_prompt_logits=True, - max_tokens=self.num_tokens_generate, + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( @@ -361,18 +365,16 @@ def test_run_same_prompt_multiple_times(self, setup): # Test the scenario, where the same prompt is run multiple times # Every run should produce the same output pipeline = self.get_pipeline() + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate + ) output_1 = pipeline( - sequences=self.prompt, - return_logits=True, - include_prompt_logits=True, - max_tokens=self.num_tokens_generate, + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) + output_2 = pipeline( - sequences=self.prompt, - return_logits=True, - include_prompt_logits=True, - max_tokens=self.num_tokens_generate, + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) assert output_1.generations[0].text == output_2.generations[0].text @@ -387,11 +389,13 @@ def test_run_multiple_prompts_in_parallel(self, setup): # Same two prompts should produce the same output pipeline = self.get_pipeline() + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate + ) output = pipeline( sequences=[self.prompt, self.prompt], - return_logits=True, + generation_config=config, include_prompt_logits=True, - max_tokens=self.num_tokens_generate, ) logits_0 = output.generations[0].score @@ -408,17 +412,17 @@ def test_num_generated_predictions(self, setup): # from the same prompt pipeline = self.get_pipeline() - output_sequences = pipeline( - sequences=[self.prompt], num_generated_predictions=2 + config = GenerationConfig( + num_return_sequences=2, max_length=self.num_tokens_generate ) + output_sequences = pipeline(sequences=[self.prompt], generation_config=config) assert len(output_sequences.generations) == 1 assert len(output_sequences.generations[0]) == 2 output_sequences = pipeline( - sequences=[self.prompt, self.prompt], num_generated_predictions=2 + sequences=[self.prompt, self.prompt], generation_config=config ) - assert len(output_sequences.generations) == 2 for generation in output_sequences.generations:
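Taken together, the updated tests exercise the new calling convention end to end. A hedged sketch of typical usage after this change (the model stub and prompt below are placeholders, not values from this patch):

from transformers import GenerationConfig

from deepsparse import Pipeline

pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:some/text_generation/model",  # placeholder model stub
    # Optional pipeline-level config, applied to every request unless an
    # input-specific generation_config is supplied.
    generation_config=GenerationConfig(max_length=64),
)

# A per-request config overrides the pipeline-level one.
config = GenerationConfig(output_scores=True, max_length=32, num_return_sequences=2)
output = pipeline(sequences=["Once upon a time"], generation_config=config)

# With num_return_sequences=2, generations[0] holds both outputs for the prompt.
for generation in output.generations[0]:
    print(generation.text)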