|
| 1 | +import copy |
| 2 | +from pathlib import Path |
| 3 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, BinaryIO |
| 4 | +from aiconfig.util.config_utils import get_api_key_from_environment |
| 5 | + |
| 6 | +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model |
| 7 | +from aiconfig.callback import CallbackEvent |
| 8 | + |
| 9 | +from aiconfig import InferenceOptions, ModelParser |
| 10 | +from aiconfig.schema import ( |
| 11 | + Attachment, |
| 12 | + AttachmentDataWithStringValue, |
| 13 | + ExecuteResult, |
| 14 | + Output, |
| 15 | + Prompt, |
| 16 | + PromptInput, |
| 17 | + PromptMetadata, |
| 18 | +) |
| 19 | + |
| 20 | +# HuggingFace API imports |
| 21 | +from huggingface_hub import InferenceClient |
| 22 | + |
| 23 | +if TYPE_CHECKING: |
| 24 | + from aiconfig import AIConfigRuntime |
| 25 | + |
| 26 | + |
| 27 | +# Step 1: define Helpers |
def refine_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
    """
    Filter model settings down to the params supported by the HF Automatic
    Speech Recognition API, lower-casing each retained key.

    See https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/inference/_client.py#L302
    for supported params.

    Note: The inference endpoint does not support all the same params as transformers' pipelines()
    """
    supported_keys = {
        "model",
    }

    # Keep only supported settings; keys are normalized to lowercase.
    return {
        key.lower(): value
        for key, value in model_settings.items()
        if key.lower() in supported_keys
    }
| 47 | + |
| 48 | + |
def construct_output(response: str) -> Output:
    """
    Wrap the raw response from the HF Inference Endpoint in an ExecuteResult.

    Response only contains output in the form of a text.
    See https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/inference/_client.py#L302
    """
    return ExecuteResult(
        output_type="execute_result",
        data=response,
        execution_count=0,
        metadata={},
    )
| 63 | + |
| 64 | + |
class HuggingFaceAutomaticSpeechRecognitionRemoteInference(ModelParser):
    """
    Model Parser for HuggingFace ASR (Automatic Speech Recognition) models.

    Uses the Inference Endpoint (via `huggingface_hub.InferenceClient`) for
    inference rather than running the model locally.
    """

    def __init__(self, model_id: Optional[str] = None, use_api_token: bool = False):
        """
        Args:
            model_id (str, optional): HuggingFace model ID or inference endpoint
                URL to use. If None, the InferenceClient selects a recommended
                model for the task at call time.
            use_api_token (bool): If True, read the API token from the
                HUGGING_FACE_API_TOKEN environment variable.

        Returns:
            HuggingFaceAutomaticSpeechRecognitionRemoteInference

        Usage:
            1. Create a new model parser object with the model ID of the model to use.
                parser = HuggingFaceAutomaticSpeechRecognitionRemoteInference()
            2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()

        token = None

        if use_api_token:
            # You are allowed to use Hugging Face for a bit before you get
            # rate limited, in which case you will receive a clear error.
            # required=False keeps anonymous (token-less) usage working.
            token = get_api_key_from_environment(
                "HUGGING_FACE_API_TOKEN", required=False
            ).unwrap()

        self.client = InferenceClient(model_id, token=token)

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser.
        """
        return "HuggingFaceAutomaticSpeechRecognitionRemoteInference"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[dict[Any, Any]] = None,
        **kwargs,
    ) -> list[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.
        Assume input in the form of input(s) being passed into an already constructed pipeline.

        Args:
            prompt_name (str): The name of the prompt to be serialized.
            data (Any): Model-specific inference settings to be serialized.
                Must contain an "audio" entry (Path to a local file, or uri str).
            ai_config (AIConfigRuntime): The AIConfig Runtime.
            parameters (Dict[str, Any], optional): Model-specific parameters. Defaults to None.

        Returns:
            list[Prompt]: Serialized representation of the prompt and inference settings.
        """
        await ai_config.callback_manager.run_callbacks(
            CallbackEvent(
                "on_serialize_start",
                __name__,
                {
                    "prompt_name": prompt_name,
                    "data": data,
                    "parameters": parameters,
                    "kwargs": kwargs,
                },
            )
        )

        # Assume data is completion params for the HF automatic speech
        # recognition inference api. Deep-copy so we never mutate the caller's dict.
        data = copy.deepcopy(data)

        # For now, support Path as path to local audio file or str as uri only
        # TODO: Support bytes and BinaryIO
        audio: Union[Path, str, bytes, BinaryIO] = data["audio"]

        # In some cases (e.g. for uri), we can't determine the mimetype subtype
        # without loading it, so just use the discrete type by default
        mime_type = "audio"

        if isinstance(audio, Path):
            # Path.as_uri() already returns a str (file:// uri)
            data["audio"] = audio.as_uri()
            # Assume the audio is saved with extension matching mimetype
            file_extension = audio.suffix.lower()[1:]
            mime_type = f"audio/{file_extension}"
        elif isinstance(audio, str):
            # Assume it's a uri
            pass
        else:
            raise ValueError(
                f"Invalid audio type. Expected Path or str, got {type(audio)}"
            )

        attachment_data = AttachmentDataWithStringValue(
            kind="file_uri", value=data["audio"]
        )
        attachments: List[Attachment] = [
            Attachment(data=attachment_data, mime_type=mime_type)
        ]
        prompt_input = PromptInput(attachments=attachments)

        # audio is handled (stored as an attachment), remove from data
        data.pop("audio", None)

        prompts: list[Prompt] = []

        model_metadata = ai_config.get_model_metadata(data, self.id())
        prompt = Prompt(
            name=prompt_name,
            input=prompt_input,
            metadata=PromptMetadata(
                model=model_metadata, parameters=parameters, **kwargs
            ),
        )

        prompts.append(prompt)

        await ai_config.callback_manager.run_callbacks(
            CallbackEvent(
                "on_serialize_complete", __name__, {"result": prompts}
            )
        )

        return prompts

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        params: Optional[dict[Any, Any]] = None,
    ) -> dict[Any, Any]:
        """
        Deserialize a prompt into completion params for the HF ASR inference api.

        Args:
            prompt (Prompt): The prompt holding the audio attachment.
            aiconfig (AIConfigRuntime): The AIConfig Runtime.
            params (dict, optional): Parameters to resolve the prompt with.

        Returns:
            dict: Completion data ready to pass to the inference client.
        """
        # NOTE: default was previously a mutable `{}`, which is shared across
        # calls; normalize a None default instead.
        if params is None:
            params = {}

        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_deserialize_start",
                __name__,
                {"prompt": prompt, "params": params},
            )
        )

        # Build Completion data
        model_settings = self.get_model_settings(prompt, aiconfig)

        completion_data = refine_completion_params(model_settings)

        # HF Python API supports input types of bytes, file path, and a dict containing raw sampled audio. Supports only one input.
        # See https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/inference/_client.py#L302
        # For now, support a single uri as input
        audio_input = validate_and_retrieve_audio_from_attachments(prompt)

        completion_data["audio"] = audio_input

        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_deserialize_complete",
                __name__,
                {"output": completion_data},
            )
        )
        return completion_data

    async def run(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        options: InferenceOptions,
        parameters: Dict[str, Any],
        **kwargs,
    ) -> list[Output]:
        """
        Invoked to run a prompt in the .aiconfig. This method performs the
        actual model inference based on the provided prompt and inference settings.

        Args:
            prompt (Prompt): The input prompt.
            aiconfig (AIConfigRuntime): The AIConfig Runtime.
            options (InferenceOptions): Run options; may carry an `api_token`
                override for this single run.
            parameters (Dict[str, Any]): Parameters to resolve the prompt with.

        Returns:
            list[Output]: The outputs from the model (ASR returns exactly one).
        """
        sanitized_options = copy.deepcopy(options)
        run_override_api_token = getattr(sanitized_options, "api_token", None)
        # Redact api token from logs if it exists
        if run_override_api_token:
            setattr(sanitized_options, "api_token", "hf_********")
        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_run_start",
                __name__,
                {
                    "prompt": prompt,
                    "options": options,
                    "parameters": parameters,
                },
            )
        )

        completion_data = await self.deserialize(prompt, aiconfig, parameters)

        # If api token is provided in the options, use it for the client
        client = self.client
        if run_override_api_token:
            client = InferenceClient(
                self.client.model, token=run_override_api_token
            )

        response = client.automatic_speech_recognition(**completion_data)

        # HF Automatic Speech Recognition api doesn't support multiple outputs. Expect only one output.
        # Output spec: response is str
        outputs = [construct_output(response)]
        prompt.outputs = outputs

        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_run_complete", __name__, {"result": prompt.outputs}
            )
        )
        return prompt.outputs

    def get_output_text(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        output: Optional[Output] = None,
    ) -> str:
        """
        Extract the text from an output of this prompt.

        Args:
            prompt (Prompt): The prompt the output belongs to.
            aiconfig (AIConfigRuntime): The AIConfig Runtime.
            output (Output, optional): The output to read; defaults to the
                prompt's latest output.

        Returns:
            str: The output text, or "" if there is no usable output.

        Raises:
            ValueError: If an execute_result output carries non-string data.
        """
        if output is None:
            output = aiconfig.get_latest_output(prompt)

        if output is None:
            return ""

        if output.output_type == "execute_result":
            output_data = output.data
            if isinstance(output_data, str):
                return output_data

            else:
                raise ValueError(
                    f"Invalid output data type {type(output_data)} for prompt '{prompt.name}'. Expected string."
                )
        return ""
| 308 | + |
| 309 | + |
def validate_attachment_type_is_audio(attachment: "Attachment"):
    """
    Validate that the attachment carries an audio mime type.

    Args:
        attachment: The prompt attachment to validate.

    Raises:
        ValueError: If the attachment has no (or an empty/None) mime type, or
            if its mime type is not an audio type.
    """
    # A missing attribute and a None/empty mime_type both mean we cannot tell
    # what kind of data this is; previously a None mime_type crashed with
    # AttributeError on .startswith instead of this clear error.
    if not getattr(attachment, "mime_type", None):
        raise ValueError(
            "Attachment has no mime type. Specify the audio mimetype in the aiconfig"
        )

    if not attachment.mime_type.startswith("audio"):
        raise ValueError(
            f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype."
        )
| 320 | + |
| 321 | + |
def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> str:
    """
    Retrieves the audio uri or base64 from the single attachment in the prompt input.

    Throws a ValueError if
    - prompt input is not a PromptInput with attachments
    - no attachments are found
    - more than one audio attachment is found. Inference api only supports one audio input.
    - attachment is not audio
    - attachment data is not a string value (uri/base64)
    """

    if not isinstance(prompt.input, PromptInput):
        raise ValueError(
            f"Prompt input is of type {type(prompt.input)}. Please specify a PromptInput with attachments for prompt {prompt.name}."
        )

    if prompt.input.attachments is None or len(prompt.input.attachments) == 0:
        raise ValueError(
            f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input."
        )

    if len(prompt.input.attachments) > 1:
        # Bug fix: this message was missing its f-prefix, so {prompt.name}
        # appeared literally in the error text.
        raise ValueError(
            f"Multiple audio inputs are not supported for the HF Automatic Speech Recognition Inference api. Please specify a single audio input attachment for Prompt: {prompt.name}."
        )

    attachment = prompt.input.attachments[0]

    validate_attachment_type_is_audio(attachment)

    if not isinstance(attachment.data, AttachmentDataWithStringValue):
        raise ValueError(
            f"""Attachment data must be of type `AttachmentDataWithStringValue` with a `kind` and `value` field.
                 Please specify a uri for the audio attachment in prompt {prompt.name}."""
        )

    return attachment.data.value
0 commit comments