From d51ca3240489fcd9b5541857aba993c207b7048a Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 18 Oct 2022 06:45:48 +0200 Subject: [PATCH] fix tests (#19670) --- .../whisper/test_modeling_tf_whisper.py | 28 +++++++++---------- tests/models/whisper/test_modeling_whisper.py | 14 ++++++---- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 62aeeb1367..a5503a0b2a 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -763,7 +763,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): input_speech = self._load_datasamples(1) input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features - generated_ids = model.generate(input_features, num_beams=5) + generated_ids = model.generate(input_features, num_beams=5, max_length=20) transcript = processor.tokenizer.batch_decode(generated_ids)[0] EXPECTED_TRANSCRIPT = ( @@ -781,7 +781,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): input_speech = self._load_datasamples(1) input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features - generated_ids = model.generate(input_features, num_beams=5) + generated_ids = model.generate(input_features, num_beams=5, max_length=20) transcript = processor.tokenizer.decode(generated_ids[0]) EXPECTED_TRANSCRIPT = ( @@ -801,8 +801,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): xla_generate = tf.function(model.generate, jit_compile=True) - generated_ids = model.generate(input_features, num_beams=5) - generated_ids_xla = xla_generate(input_features, num_beams=5) + generated_ids = model.generate(input_features, num_beams=5, max_length=20) + generated_ids_xla = xla_generate(input_features, num_beams=5, max_length=20) transcript = processor.tokenizer.decode(generated_ids[0]) transcript_xla = processor.tokenizer.decode(generated_ids_xla[0]) @@ -824,10 +824,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") - generated_ids = model.generate( - input_features, - do_sample=False, - ) + generated_ids = model.generate(input_features, do_sample=False, max_length=20) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad" @@ -845,7 +842,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") - generated_ids = model.generate(input_features, do_sample=False) + generated_ids = model.generate(input_features, do_sample=False, max_length=20) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" @@ -855,6 +852,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): generated_ids = model.generate( input_features, do_sample=False, + max_length=20, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -862,7 +860,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): self.assertEqual(transcript, EXPECTED_TRANSCRIPT) model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") - generated_ids = model.generate(input_features, do_sample=False) + generated_ids = model.generate(input_features, do_sample=False, max_length=20) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san" @@ -876,7 +874,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): input_speech = self._load_datasamples(4) input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features - generated_ids = model.generate(input_features) + generated_ids = model.generate(input_features, max_length=20) # fmt: off EXPECTED_LOGITS = tf.convert_to_tensor( @@ -893,7 +891,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): # fmt: off EXPECTED_TRANSCRIPT = [ - ' Mr. Quilter is the apostle of the middle classes, and we are glad to', + ' Mr. Quilter is the apostle of the middle classes and we are glad to', " Nor is Mr. Quilter's manner less interesting than his matter.", " He tells us that at this festive season of the year, with Christmas and roast beef", " He has grave doubts whether Sir Frederick Layton's work is really Greek after all," @@ -911,7 +909,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): input_speech = self._load_datasamples(4) input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features - generated_ids = model.generate(input_features) + generated_ids = model.generate(input_features, max_length=20) # fmt: off EXPECTED_LOGITS = tf.convert_to_tensor( @@ -950,8 +948,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): xla_generate = tf.function(model.generate, jit_compile=True) - generated_ids = model.generate(input_features) - generated_ids_xla = xla_generate(input_features) + generated_ids = model.generate(input_features, max_length=20) + generated_ids_xla = xla_generate(input_features, max_length=20) # fmt: off EXPECTED_LOGITS = tf.convert_to_tensor( diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 7907aaa1eb..d03d3cbb54 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -895,7 +895,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): torch_device ) - generated_ids = model.generate(input_features, num_beams=5) + generated_ids = model.generate(input_features, num_beams=5, max_length=20) transcript = processor.tokenizer.batch_decode(generated_ids)[0] EXPECTED_TRANSCRIPT = ( @@ -918,7 +918,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): torch_device ) - generated_ids = model.generate(input_features, num_beams=5) + generated_ids = model.generate(input_features, num_beams=5, max_length=20) transcript = processor.tokenizer.decode(generated_ids[0]) EXPECTED_TRANSCRIPT = ( @@ -944,6 +944,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): generated_ids = model.generate( input_features, do_sample=False, + max_length=20, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -966,7 +967,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ) model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") - generated_ids = model.generate(input_features, do_sample=False) + generated_ids = model.generate(input_features, do_sample=False, max_length=20) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" @@ -976,6 +977,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): generated_ids = model.generate( input_features, do_sample=False, + max_length=20, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -983,7 +985,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): self.assertEqual(transcript, EXPECTED_TRANSCRIPT) model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") - generated_ids = model.generate(input_features, do_sample=False) + generated_ids = model.generate(input_features, do_sample=False, max_length=20) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san" @@ -997,7 +999,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): input_speech = self._load_datasamples(4) input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features - generated_ids = model.generate(input_features) + generated_ids = model.generate(input_features, max_length=20) # fmt: off EXPECTED_LOGITS = torch.tensor( @@ -1036,7 +1038,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to( torch_device ) - generated_ids = model.generate(input_features).to("cpu") + generated_ids = model.generate(input_features, max_length=20).to("cpu") # fmt: off EXPECTED_LOGITS = torch.tensor(