Fix-ci-whisper (#21767)
* fix history * input_features instead of input ids for TFWhisport doctest * use translate intead of transcribe
This commit is contained in:
@@ -1283,7 +1283,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
|
|||||||
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
|
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
|
||||||
>>> input_features = inputs.input_features
|
>>> input_features = inputs.input_features
|
||||||
|
|
||||||
>>> generated_ids = model.generate(input_ids=input_features)
|
>>> generated_ids = model.generate(input_features=input_features)
|
||||||
|
|
||||||
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
>>> transcription
|
>>> transcription
|
||||||
|
|||||||
@@ -1187,7 +1187,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
input_speech = self._load_datasamples(4)
|
input_speech = self._load_datasamples(4)
|
||||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
|
||||||
generated_ids = model.generate(input_features, max_length=20)
|
generated_ids = model.generate(input_features, max_length=20, task="translate")
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_LOGITS = torch.tensor(
|
EXPECTED_LOGITS = torch.tensor(
|
||||||
|
|||||||
Reference in New Issue
Block a user