fix tests (#19670)
This commit is contained in:
@@ -763,7 +763,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5)
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = (
|
||||
@@ -781,7 +781,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5)
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||
|
||||
EXPECTED_TRANSCRIPT = (
|
||||
@@ -801,8 +801,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5)
|
||||
generated_ids_xla = xla_generate(input_features, num_beams=5)
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||
generated_ids_xla = xla_generate(input_features, num_beams=5, max_length=20)
|
||||
|
||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||
transcript_xla = processor.tokenizer.decode(generated_ids_xla[0])
|
||||
@@ -824,10 +824,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
)
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
|
||||
@@ -845,7 +842,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
||||
generated_ids = model.generate(input_features, do_sample=False)
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
||||
@@ -855,6 +852,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
@@ -862,7 +860,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
||||
generated_ids = model.generate(input_features, do_sample=False)
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
||||
@@ -876,7 +874,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
generated_ids = model.generate(input_features)
|
||||
generated_ids = model.generate(input_features, max_length=20)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
@@ -893,7 +891,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
' Mr. Quilter is the apostle of the middle classes, and we are glad to',
|
||||
' Mr. Quilter is the apostle of the middle classes and we are glad to',
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
|
||||
@@ -911,7 +909,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
generated_ids = model.generate(input_features)
|
||||
generated_ids = model.generate(input_features, max_length=20)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
@@ -950,8 +948,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
|
||||
generated_ids = model.generate(input_features)
|
||||
generated_ids_xla = xla_generate(input_features)
|
||||
generated_ids = model.generate(input_features, max_length=20)
|
||||
generated_ids_xla = xla_generate(input_features, max_length=20)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
|
||||
@@ -895,7 +895,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
torch_device
|
||||
)
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5)
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = (
|
||||
@@ -918,7 +918,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
torch_device
|
||||
)
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5)
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||
|
||||
EXPECTED_TRANSCRIPT = (
|
||||
@@ -944,6 +944,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
@@ -966,7 +967,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
||||
generated_ids = model.generate(input_features, do_sample=False)
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
||||
@@ -976,6 +977,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
@@ -983,7 +985,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
||||
generated_ids = model.generate(input_features, do_sample=False)
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
||||
@@ -997,7 +999,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
|
||||
generated_ids = model.generate(input_features)
|
||||
generated_ids = model.generate(input_features, max_length=20)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor(
|
||||
@@ -1036,7 +1038,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
generated_ids = model.generate(input_features).to("cpu")
|
||||
generated_ids = model.generate(input_features, max_length=20).to("cpu")
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor(
|
||||
|
||||
Reference in New Issue
Block a user