Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move audio top_k tests to the right file and add slow decorator #36072

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/pipelines/test_pipelines_audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,60 @@ def test_large_model_pt(self):
@unittest.skip(reason="Audio classification is not implemented for TF")
def test_small_model_tf(self):
pass

@require_torch
@slow
def test_top_k_none_returns_all_labels(self):
model_name = "superb/wav2vec2-base-superb-ks" # model with more than 5 labels
classification_pipeline = pipeline(
"audio-classification",
model=model_name,
top_k=None,
)

# Create dummy input
sampling_rate = 16000
signal = np.zeros((sampling_rate,), dtype=np.float32)

result = classification_pipeline(signal)
num_labels = classification_pipeline.model.config.num_labels

self.assertEqual(len(result), num_labels, "Should return all labels when top_k is None")

@require_torch
@slow
def test_top_k_none_with_few_labels(self):
model_name = "superb/hubert-base-superb-er" # model with fewer labels
classification_pipeline = pipeline(
"audio-classification",
model=model_name,
top_k=None,
)

# Create dummy input
sampling_rate = 16000
signal = np.zeros((sampling_rate,), dtype=np.float32)

result = classification_pipeline(signal)
num_labels = classification_pipeline.model.config.num_labels

self.assertEqual(len(result), num_labels, "Should handle models with fewer labels correctly")

@require_torch
@slow
def test_top_k_greater_than_labels(self):
model_name = "superb/hubert-base-superb-er"
classification_pipeline = pipeline(
"audio-classification",
model=model_name,
top_k=100, # intentionally large number
)

# Create dummy input
sampling_rate = 16000
signal = np.zeros((sampling_rate,), dtype=np.float32)

result = classification_pipeline(signal)
num_labels = classification_pipeline.model.config.num_labels

self.assertEqual(len(result), num_labels, "Should cap top_k to number of labels")
60 changes: 0 additions & 60 deletions tests/test_audio_classification_top_k.py

This file was deleted.