[ASR Pipeline] Fix init with timestamps (#25438)

* [ASR Pipeline] Fix init

* refactor test

* change default kwarg setting

* only perform checks if we have to

* override init

* move pre/forward/post checks to sanitize
This commit is contained in:
Sanchit Gandhi
2023-08-16 18:04:19 +01:00
committed by GitHub
parent 6bca43bb90
commit 36f183ebab
2 changed files with 139 additions and 23 deletions

View File

@@ -343,6 +343,58 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
# fmt: on
@require_torch
def test_return_timestamps_in_init(self):
# segment-level timestamps are accepted
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
dummy_speech = np.ones(100)
pipe = pipeline(
task="automatic-speech-recognition",
model=model,
feature_extractor=feature_extractor,
tokenizer=tokenizer,
chunk_length_s=8,
stride_length_s=1,
return_timestamps=True,
)
_ = pipe(dummy_speech)
# word-level timestamps are accepted
pipe = pipeline(
task="automatic-speech-recognition",
model=model,
feature_extractor=feature_extractor,
tokenizer=tokenizer,
chunk_length_s=8,
stride_length_s=1,
return_timestamps="word",
)
_ = pipe(dummy_speech)
# char-level timestamps are not accepted
with self.assertRaisesRegex(
ValueError,
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
):
pipe = pipeline(
task="automatic-speech-recognition",
model=model,
feature_extractor=feature_extractor,
tokenizer=tokenizer,
chunk_length_s=8,
stride_length_s=1,
return_timestamps="char",
)
_ = pipe(dummy_speech)
@require_torch
@slow
def test_torch_whisper(self):