fix tests (#19670)
This commit is contained in:
@@ -763,7 +763,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
input_speech = self._load_datasamples(1)
|
input_speech = self._load_datasamples(1)
|
||||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||||
|
|
||||||
generated_ids = model.generate(input_features, num_beams=5)
|
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||||
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
|
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPT = (
|
EXPECTED_TRANSCRIPT = (
|
||||||
@@ -781,7 +781,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
input_speech = self._load_datasamples(1)
|
input_speech = self._load_datasamples(1)
|
||||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||||
|
|
||||||
generated_ids = model.generate(input_features, num_beams=5)
|
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPT = (
|
EXPECTED_TRANSCRIPT = (
|
||||||
@@ -801,8 +801,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||||
|
|
||||||
generated_ids = model.generate(input_features, num_beams=5)
|
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||||
generated_ids_xla = xla_generate(input_features, num_beams=5)
|
generated_ids_xla = xla_generate(input_features, num_beams=5, max_length=20)
|
||||||
|
|
||||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||||
transcript_xla = processor.tokenizer.decode(generated_ids_xla[0])
|
transcript_xla = processor.tokenizer.decode(generated_ids_xla[0])
|
||||||
@@ -824,10 +824,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||||
|
|
||||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||||
generated_ids = model.generate(
|
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||||
input_features,
|
|
||||||
do_sample=False,
|
|
||||||
)
|
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
|
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
|
||||||
@@ -845,7 +842,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||||
|
|
||||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
||||||
generated_ids = model.generate(input_features, do_sample=False)
|
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
||||||
@@ -855,6 +852,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
generated_ids = model.generate(
|
generated_ids = model.generate(
|
||||||
input_features,
|
input_features,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
|
max_length=20,
|
||||||
)
|
)
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
@@ -862,7 +860,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
||||||
generated_ids = model.generate(input_features, do_sample=False)
|
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
||||||
@@ -876,7 +874,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
input_speech = self._load_datasamples(4)
|
input_speech = self._load_datasamples(4)
|
||||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||||
generated_ids = model.generate(input_features)
|
generated_ids = model.generate(input_features, max_length=20)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||||
@@ -893,7 +891,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_TRANSCRIPT = [
|
EXPECTED_TRANSCRIPT = [
|
||||||
' Mr. Quilter is the apostle of the middle classes, and we are glad to',
|
' Mr. Quilter is the apostle of the middle classes and we are glad to',
|
||||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||||
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
||||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
|
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
|
||||||
@@ -911,7 +909,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
input_speech = self._load_datasamples(4)
|
input_speech = self._load_datasamples(4)
|
||||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||||
generated_ids = model.generate(input_features)
|
generated_ids = model.generate(input_features, max_length=20)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||||
@@ -950,8 +948,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||||
|
|
||||||
generated_ids = model.generate(input_features)
|
generated_ids = model.generate(input_features, max_length=20)
|
||||||
generated_ids_xla = xla_generate(input_features)
|
generated_ids_xla = xla_generate(input_features, max_length=20)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||||
|
|||||||
@@ -895,7 +895,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
torch_device
|
torch_device
|
||||||
)
|
)
|
||||||
|
|
||||||
generated_ids = model.generate(input_features, num_beams=5)
|
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||||
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
|
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPT = (
|
EXPECTED_TRANSCRIPT = (
|
||||||
@@ -918,7 +918,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
torch_device
|
torch_device
|
||||||
)
|
)
|
||||||
|
|
||||||
generated_ids = model.generate(input_features, num_beams=5)
|
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPT = (
|
EXPECTED_TRANSCRIPT = (
|
||||||
@@ -944,6 +944,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
generated_ids = model.generate(
|
generated_ids = model.generate(
|
||||||
input_features,
|
input_features,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
|
max_length=20,
|
||||||
)
|
)
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
@@ -966,7 +967,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
||||||
generated_ids = model.generate(input_features, do_sample=False)
|
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
||||||
@@ -976,6 +977,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
generated_ids = model.generate(
|
generated_ids = model.generate(
|
||||||
input_features,
|
input_features,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
|
max_length=20,
|
||||||
)
|
)
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
@@ -983,7 +985,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
||||||
generated_ids = model.generate(input_features, do_sample=False)
|
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
||||||
@@ -997,7 +999,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
input_speech = self._load_datasamples(4)
|
input_speech = self._load_datasamples(4)
|
||||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
|
||||||
generated_ids = model.generate(input_features)
|
generated_ids = model.generate(input_features, max_length=20)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_LOGITS = torch.tensor(
|
EXPECTED_LOGITS = torch.tensor(
|
||||||
@@ -1036,7 +1038,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||||
torch_device
|
torch_device
|
||||||
)
|
)
|
||||||
generated_ids = model.generate(input_features).to("cpu")
|
generated_ids = model.generate(input_features, max_length=20).to("cpu")
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_LOGITS = torch.tensor(
|
EXPECTED_LOGITS = torch.tensor(
|
||||||
|
|||||||
Reference in New Issue
Block a user