From 087436c98e82237334956ddf26fc1abbc7a88f30 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 24 Feb 2023 11:39:25 +0100 Subject: [PATCH] Fix-ci-whisper (#21767) * fix history * input_features instead of input ids for TFWhisport doctest * use translate intead of transcribe --- src/transformers/models/whisper/modeling_tf_whisper.py | 2 +- tests/models/whisper/test_modeling_whisper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index dd6ebcb2c8..12f6e7db5e 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -1283,7 +1283,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf") >>> input_features = inputs.input_features - >>> generated_ids = model.generate(input_ids=input_features) + >>> generated_ids = model.generate(input_features=input_features) >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> transcription diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index b5d3b2e648..ac21370f98 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1187,7 +1187,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): input_speech = self._load_datasamples(4) input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features - generated_ids = model.generate(input_features, max_length=20) + generated_ids = model.generate(input_features, max_length=20, task="translate") # fmt: off EXPECTED_LOGITS = torch.tensor(