Soft error whisper. (#22475)
* Soft error whisper. * Fix format. --------- Co-authored-by: Ubuntu <ubuntu@ip-172-31-34-94.taildb5d.ts.net>
This commit is contained in:
@@ -877,9 +877,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||||||
|
|
||||||
if previous_tokens:
|
if previous_tokens:
|
||||||
if return_timestamps:
|
if return_timestamps:
|
||||||
# Last token should always be timestamps, so there shouldn't be
|
logger.warning(
|
||||||
# leftover
|
|
||||||
raise ValueError(
|
|
||||||
"There was an error while processing timestamps, we haven't found a timestamp as last token. Was"
|
"There was an error while processing timestamps, we haven't found a timestamp as last token. Was"
|
||||||
" WhisperTimeStampLogitsProcessor used?"
|
" WhisperTimeStampLogitsProcessor used?"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
MODEL_FOR_CTC_MAPPING,
|
MODEL_FOR_CTC_MAPPING,
|
||||||
@@ -39,6 +39,7 @@ from transformers.testing_utils import (
|
|||||||
require_pyctcdecode,
|
require_pyctcdecode,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
require_torchaudio,
|
require_torchaudio,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
@@ -1158,6 +1159,36 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
output = speech_recognizer({"raw": waveform, "stride": (1000, 8000), "sampling_rate": 16_000})
|
output = speech_recognizer({"raw": waveform, "stride": (1000, 8000), "sampling_rate": 16_000})
|
||||||
self.assertEqual(output, {"text": "XB"})
|
self.assertEqual(output, {"text": "XB"})
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_slow_unfinished_sequence(self):
|
||||||
|
from transformers import GenerationConfig
|
||||||
|
|
||||||
|
pipe = pipeline(
|
||||||
|
"automatic-speech-recognition",
|
||||||
|
model="vasista22/whisper-hindi-large-v2",
|
||||||
|
device="cuda:0",
|
||||||
|
)
|
||||||
|
# Original model wasn't trained with timestamps and has incorrect generation config
|
||||||
|
pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
|
||||||
|
|
||||||
|
audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
|
||||||
|
|
||||||
|
out = pipe(
|
||||||
|
audio,
|
||||||
|
return_timestamps=True,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
out,
|
||||||
|
{
|
||||||
|
"chunks": [
|
||||||
|
{"text": "", "timestamp": (18.94, 0.0)},
|
||||||
|
{"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
|
||||||
|
],
|
||||||
|
"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def require_ffmpeg(test_case):
|
def require_ffmpeg(test_case):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user