Skip to content

Commit 50ac544

Browse files
[hf inference] ASR remote inference model parser impl (#1020)
[hf inference] ASR remote inference model parser impl Implementation of the HuggingFaceAutomaticSpeechRecognition Model parser using the inference endpoint to run inference. Python API takes in bytes as well as path, skip binary for now. Very similar to #1018 ## Testplan <img width="1000" alt="Screenshot 2024-01-24 at 10 37 05 PM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/808956ce-e3be-4528-9f34-c8d31d704ddb"> 1. Temporarily add model parser to Gradio Cookbook model parser registry. ``` asr = HuggingFaceAutomaticSpeechRecognitionRemoteInference() AIConfigRuntime.register_model_parser( asr, asr.id() ) ``` 2. run AIConfig Edit on Gradio example `python3 -m 'aiconfig.scripts.aiconfig_cli' edit --aiconfig-path=cookbooks/Gradio/huggingface.aiconfig.json --parsers-module-path=cookbooks/Gradio/hf_model_parsers.py --server-mode=debug_servers`
2 parents c6bd7c5 + 02e43fd commit 50ac544

File tree

2 files changed

+367
-0
lines changed

2 files changed

+367
-0
lines changed

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
from .remote_inference_client.text_translation import (
3333
HuggingFaceTextTranslationRemoteInference,
3434
)
35+
from .remote_inference_client.automatic_speech_recognition import (
36+
HuggingFaceAutomaticSpeechRecognitionRemoteInference,
37+
)
3538

3639
UTILS = [get_hf_model]
3740

@@ -44,12 +47,14 @@
4447
"HuggingFaceTextSummarizationTransformer",
4548
"HuggingFaceTextTranslationTransformer",
4649
]
50+
4751
REMOTE_INFERENCE_CLASSES = [
4852
"HuggingFaceImage2TextRemoteInference",
4953
"HuggingFaceText2ImageRemoteInference",
5054
"HuggingFaceText2SpeechRemoteInference",
5155
"HuggingFaceTextGenerationRemoteInference",
5256
"HuggingFaceTextSummarizationRemoteInference",
5357
"HuggingFaceTextTranslationRemoteInference",
58+
"HuggingFaceAutomaticSpeechRecognitionRemoteInference"
5459
]
5560
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES + UTILS
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
import copy
2+
from pathlib import Path
3+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, BinaryIO
4+
from aiconfig.util.config_utils import get_api_key_from_environment
5+
6+
from aiconfig_extension_hugging_face.local_inference.util import get_hf_model
7+
from aiconfig.callback import CallbackEvent
8+
9+
from aiconfig import InferenceOptions, ModelParser
10+
from aiconfig.schema import (
11+
Attachment,
12+
AttachmentDataWithStringValue,
13+
ExecuteResult,
14+
Output,
15+
Prompt,
16+
PromptInput,
17+
PromptMetadata,
18+
)
19+
20+
# HuggingFace API imports
21+
from huggingface_hub import InferenceClient
22+
23+
if TYPE_CHECKING:
24+
from aiconfig import AIConfigRuntime
25+
26+
27+
# Step 1: define Helpers
def refine_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
    """
    Refines the completion params for the HF Automatic Speech Recognition api. Removes any unsupported params.
    See https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/inference/_client.py#L302
    for supported params.

    Note: The inference endpoint does not support all the same params as transformers' pipelines()
    """
    supported_keys = {
        "model",
    }

    # Keep only the supported settings, normalizing keys to lowercase.
    return {
        key.lower(): value
        for key, value in model_settings.items()
        if key.lower() in supported_keys
    }
47+
48+
49+
def construct_output(response: str) -> Output:
    """
    Constructs an output from the response of the HF Inference Endpoint.

    Response only contains output in the form of a text.
    See https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/inference/_client.py#L302
    """
    # The ASR endpoint returns the transcription as plain text; wrap it
    # directly in an ExecuteResult with no extra metadata.
    return ExecuteResult(
        output_type="execute_result",
        data=response,
        execution_count=0,
        metadata={},
    )
63+
64+
65+
class HuggingFaceAutomaticSpeechRecognitionRemoteInference(ModelParser):
    """
    Model Parser for HuggingFace ASR (Automatic Speech Recognition) models.

    Uses the Inference Endpoint for inference.
    """

    def __init__(self, model_id: Optional[str] = None, use_api_token: bool = False):
        """
        Args:
            model_id (str, optional): Model id or inference endpoint URL forwarded
                to the InferenceClient. Defaults to None.
            use_api_token (bool): If True, read HUGGING_FACE_API_TOKEN from the
                environment and authenticate requests with it. Defaults to False.

        Returns:
            HuggingFaceAutomaticSpeechRecognitionRemoteInference

        Usage:
        1. Create a new model parser object with the model ID of the model to use.
                parser = HuggingFaceAutomaticSpeechRecognitionRemoteInference()
        2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()

        token = None

        if use_api_token:
            # You are allowed to use Hugging Face for a bit before you get
            # rate limited, in which case you will receive a clear error
            token = get_api_key_from_environment(
                "HUGGING_FACE_API_TOKEN", required=False
            ).unwrap()

        self.client = InferenceClient(model_id, token=token)

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser
        """
        return "HuggingFaceAutomaticSpeechRecognitionRemoteInference"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[dict[Any, Any]] = None,
        **kwargs,
    ) -> list[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.
        Assume input in the form of input(s) being passed into an already constructed pipeline.

        Args:
            prompt_name (str): The name of the prompt to be serialized.
            data (Any): Model-specific inference settings to be serialized.
            ai_config (AIConfigRuntime): The AIConfig Runtime.
            parameters (Dict[str, Any], optional): Model-specific parameters. Defaults to None.

        Returns:
            list[Prompt]: Serialized representation of the prompt and inference settings.
        """
        await ai_config.callback_manager.run_callbacks(
            CallbackEvent(
                "on_serialize_start",
                __name__,
                {
                    "prompt_name": prompt_name,
                    "data": data,
                    "parameters": parameters,
                    "kwargs": kwargs,
                },
            )
        )

        # assume data is completion params for HF automatic speech recognition inference api
        data = copy.deepcopy(data)

        # For now, support Path as path to local audio file or str as uri only
        # TODO: Support bytes and BinaryIO
        audio: Union[Path, str, bytes, BinaryIO] = data["audio"]

        # In some cases (e.g. for uri), we can't determine the mimetype subtype
        # without loading it, so just use the discrete type by default
        mime_type = "audio"

        if isinstance(audio, Path):
            data["audio"] = str(audio.as_uri())
            # Assume the audio is saved with extension matching mimetype
            file_extension = audio.suffix.lower()[1:]
            mime_type = f"audio/{file_extension}"
        elif isinstance(audio, str):
            # Assume it's a uri
            pass
        else:
            raise ValueError(
                f"Invalid audio type. Expected Path or str, got {type(audio)}"
            )

        attachment_data = AttachmentDataWithStringValue(
            kind="file_uri", value=data["audio"]
        )
        attachments: List[Attachment] = [
            Attachment(data=attachment_data, mime_type=mime_type)
        ]
        prompt_input = PromptInput(attachments=attachments)

        # audio is handled, remove from data
        data.pop("audio", None)

        prompts: list[Prompt] = []

        model_metadata = ai_config.get_model_metadata(data, self.id())
        prompt = Prompt(
            name=prompt_name,
            input=prompt_input,
            metadata=PromptMetadata(
                model=model_metadata, parameters=parameters, **kwargs
            ),
        )

        prompts.append(prompt)

        await ai_config.callback_manager.run_callbacks(
            CallbackEvent(
                "on_serialize_complete", __name__, {"result": prompts}
            )
        )

        return prompts

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        # NOTE: default was a mutable `{}`; use None sentinel to avoid the
        # shared-mutable-default pitfall, normalized below.
        params: Optional[dict[Any, Any]] = None,
    ) -> dict[Any, Any]:
        """
        Builds the completion data dict for the HF ASR inference call from the
        prompt's model settings and its single audio attachment.
        """
        if params is None:
            params = {}

        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_deserialize_start",
                __name__,
                {"prompt": prompt, "params": params},
            )
        )

        # Build Completion data
        model_settings = self.get_model_settings(prompt, aiconfig)

        completion_data = refine_completion_params(model_settings)

        # HF Python API supports input types of bytes, file path, and a dict containing raw sampled audio. Supports only one input.
        # See https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/inference/_client.py#L302
        # For now, support multiple or single uri's as input
        audio_input = validate_and_retrieve_audio_from_attachments(prompt)

        completion_data["audio"] = audio_input

        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_deserialize_complete",
                __name__,
                {"output": completion_data},
            )
        )
        return completion_data

    async def run(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        options: InferenceOptions,
        parameters: Dict[str, Any],
        **kwargs,
    ) -> list[Output]:
        """
        Invoked to run a prompt in the .aiconfig. This method should perform
        the actual model inference based on the provided prompt and inference settings.

        Args:
            prompt (Prompt): The input prompt.
            aiconfig (AIConfigRuntime): The AIConfig Runtime.
            options (InferenceOptions): Inference options; may carry an
                `api_token` override for this run.
            parameters (Dict[str, Any]): Model-specific parameters.

        Returns:
            list[Output]: The outputs from the model.
        """
        sanitized_options = copy.deepcopy(options)
        run_override_api_token = getattr(sanitized_options, "api_token", None)
        # Redact api token from logs if it exists
        if run_override_api_token:
            setattr(sanitized_options, "api_token", "hf_********")
        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_run_start",
                __name__,
                {
                    "prompt": prompt,
                    "options": options,
                    "parameters": parameters,
                },
            )
        )

        completion_data = await self.deserialize(prompt, aiconfig, parameters)

        # If api token is provided in the options, use it for the client
        client = self.client
        if run_override_api_token:
            client = InferenceClient(
                self.client.model, token=run_override_api_token
            )

        response = client.automatic_speech_recognition(**completion_data)

        # HF Automatic Speech Recognition api doesn't support multiple outputs. Expect only one output.
        # Output spec: response is str
        outputs = [construct_output(response)]
        prompt.outputs = outputs

        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_run_complete", __name__, {"result": prompt.outputs}
            )
        )
        return prompt.outputs

    def get_output_text(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        output: Optional[Output] = None,
    ) -> str:
        """
        Extracts the transcribed text from an output; falls back to the
        prompt's latest output when none is given. Returns "" if no output.
        """
        if output is None:
            output = aiconfig.get_latest_output(prompt)

        if output is None:
            return ""

        if output.output_type == "execute_result":
            output_data = output.data
            if isinstance(output_data, str):
                return output_data

            else:
                raise ValueError(
                    f"Invalid output data type {type(output_data)} for prompt '{prompt.name}'. Expected string."
                )
        return ""
308+
309+
310+
def validate_attachment_type_is_audio(attachment: "Attachment"):
    """
    Validates that the given attachment carries an audio mime type.

    Raises:
        ValueError: if the attachment has no `mime_type`, or if the mime type
            does not start with "audio".
    """
    if not hasattr(attachment, "mime_type"):
        # fix: this literal has no placeholders, so the f-prefix was spurious (F541)
        raise ValueError(
            "Attachment has no mime type. Specify the audio mimetype in the aiconfig"
        )

    if not attachment.mime_type.startswith("audio"):
        raise ValueError(
            f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype."
        )
320+
321+
322+
def validate_and_retrieve_audio_from_attachments(prompt: "Prompt") -> str:
    """
    Retrieves the audio uri or base64 from each attachment in the prompt input.

    Throws an exception if
    - attachment is not audio
    - attachment data is not a uri
    - no attachments are found
    - operation fails for any reason
    - more than one audio attachment is found. Inference api only supports one audio input.
    """

    if not isinstance(prompt.input, PromptInput):
        raise ValueError(
            f"Prompt input is of type {type(prompt.input)}. Please specify a PromptInput with attachments for prompt {prompt.name}."
        )

    if prompt.input.attachments is None or len(prompt.input.attachments) == 0:
        raise ValueError(
            f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input."
        )

    if len(prompt.input.attachments) > 1:
        # fix: this message interpolates the prompt name but was missing the
        # f-prefix, so callers saw the literal text "{prompt.name}"
        raise ValueError(
            f"Multiple audio inputs are not supported for the HF Automatic Speech Recognition Inference api. Please specify a single audio input attachment for Prompt: {prompt.name}."
        )

    attachment = prompt.input.attachments[0]

    validate_attachment_type_is_audio(attachment)

    if not isinstance(attachment.data, AttachmentDataWithStringValue):
        raise ValueError(
            f"""Attachment data must be of type `AttachmentDataWithStringValue` with a `kind` and `value` field.
                Please specify a uri for the audio attachment in prompt {prompt.name}."""
        )

    # Single validated audio attachment: return its string payload (uri or base64)
    return attachment.data.value

0 commit comments

Comments
 (0)