Skip to content

Commit 9b49d85

Browse files
authored
Update Default Models in HF Prompt Schemas (#1222)
# Update Default Models in HF Prompt Schemas Update the prompt schemas for ASR and TTS remote inference to match the defaults set in #1221 <img width="1479" alt="Screenshot 2024-02-13 at 11 25 38 AM" src="https://github.com/lastmile-ai/aiconfig/assets/5060851/28e48030-0fc3-49d2-8c6b-ad0ea142d4d8"> <img width="1465" alt="Screenshot 2024-02-13 at 11 26 42 AM" src="https://github.com/lastmile-ai/aiconfig/assets/5060851/b5d132da-8777-47e9-95f1-73cf4f25a906"> --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/1222). * __->__ #1222 * #1221
2 parents 1b6b28e + 864b758 commit 9b49d85

File tree

4 files changed

+52
-37
lines changed

4 files changed

+52
-37
lines changed

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/automatic_speech_recognition.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ def refine_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
4343
if key.lower() in supported_keys:
4444
completion_data[key.lower()] = model_settings[key]
4545

46+
# The default model is openai/whisper-large-v3, which does not work as of
47+
# 02/13/2024. Instead, default to a free model (which supports remote
48+
# inference) with the next most "likes" in HF
49+
# https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&sort=likes
50+
if completion_data.get("model") is None:
51+
completion_data["model"] = "openai/whisper-large-v2"
52+
4653
return completion_data
4754

4855

@@ -299,7 +306,7 @@ def get_output_text(
299306
output_data = output.data
300307
if isinstance(output_data, str):
301308
return output_data
302-
309+
303310
else:
304311
raise ValueError(
305312
f"Invalid output data type {type(output_data)} for prompt '{prompt.name}'. Expected string."
@@ -347,7 +354,7 @@ def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> str:
347354
raise ValueError(
348355
"Multiple audio inputs are not supported for the HF Automatic Speech Recognition Inference api. Please specify a single audio input attachment for Prompt: {prompt.name}."
349356
)
350-
357+
351358
attachment = prompt.input.attachments[0]
352359

353360
validate_attachment_type_is_audio(attachment)

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_2_speech.py

+7
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ def refine_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
4747
if key.lower() in supported_keys:
4848
completion_data[key.lower()] = model_settings[key]
4949

50+
# The default model is suno/bark, which requires HF Pro subscription
51+
# Instead, default to a free model (which supports remote inference) with
52+
# the next most "likes" in HF
53+
# https://huggingface.co/models?pipeline_tag=text-to-speech&sort=likes
54+
if completion_data.get("model") is None:
55+
completion_data["model"] = "facebook/fastspeech2-en-ljspeech"
56+
5057
return completion_data
5158

5259

Original file line numberDiff line numberDiff line change
@@ -1,43 +1,44 @@
11
import { PromptSchema } from "../../utils/promptUtils";
22

3-
export const HuggingFaceAutomaticSpeechRecognitionRemoteInferencePromptSchema: PromptSchema = {
4-
// See https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/inference/_client.py#L302for supported params.
5-
// The settings below are supported settings specified in the HuggingFaceAutomaticSpeechRecognitionRemoteInference refine_completion_params implementation.
6-
input: {
7-
type: "object",
8-
required: ["attachments"],
9-
properties: {
10-
attachments: {
11-
type: "array",
12-
items: {
13-
type: "attachment",
14-
required: ["data"],
15-
mime_types: [
16-
"audio/mpeg",
17-
"audio/wav",
18-
"audio/webm",
19-
"audio/flac",
20-
"audio/ogg",
21-
"audio/ogg",
22-
],
23-
properties: {
24-
data: {
25-
type: "string",
3+
export const HuggingFaceAutomaticSpeechRecognitionRemoteInferencePromptSchema: PromptSchema =
4+
{
5+
// See https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/inference/_client.py#L302for supported params.
6+
// The settings below are supported settings specified in the HuggingFaceAutomaticSpeechRecognitionRemoteInference refine_completion_params implementation.
7+
input: {
8+
type: "object",
9+
required: ["attachments"],
10+
properties: {
11+
attachments: {
12+
type: "array",
13+
items: {
14+
type: "attachment",
15+
required: ["data"],
16+
mime_types: [
17+
"audio/mpeg",
18+
"audio/wav",
19+
"audio/webm",
20+
"audio/flac",
21+
"audio/ogg",
22+
"audio/ogg",
23+
],
24+
properties: {
25+
data: {
26+
type: "string",
27+
},
2628
},
2729
},
30+
max_items: 1,
2831
},
29-
max_items: 1,
3032
},
3133
},
32-
},
33-
model_settings: {
34-
type: "object",
35-
properties: {
36-
model: {
37-
type: "string",
38-
description: `Hugging Face model to use. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint`,
39-
default: "openai/whisper-large-v3"
34+
model_settings: {
35+
type: "object",
36+
properties: {
37+
model: {
38+
type: "string",
39+
description: `Hugging Face model to use. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint`,
40+
default: "openai/whisper-large-v2",
41+
},
4042
},
4143
},
42-
},
43-
};
44+
};

python/src/aiconfig/editor/client/src/shared/prompt_schemas/HuggingFaceText2SpeechRemoteInferencePromptSchema.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export const HuggingFaceText2SpeechRemoteInferencePromptSchema: PromptSchema = {
1313
type: "string",
1414
description: `Hugging Face model to use. Can be a model ID hosted on the Hugging Face Hub or a URL
1515
to a deployed Inference Endpoint`,
16-
default: "suno/bark",
16+
default: "facebook/fastspeech2-en-ljspeech",
1717
},
1818
},
1919
},

0 commit comments

Comments
 (0)