Fix some TFWhisperModelIntegrationTests (#24428)
* fix
* fix
* fix
* fix
* fix
* fix
* fix
* fix
* fix
* Update src/transformers/models/whisper/modeling_tf_whisper.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/whisper/modeling_tf_whisper.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* fix
* fix
* fix
---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -634,6 +634,48 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
def test_generate_with_prompt_ids_and_task_and_language(self):
    """
    Generate with `prompt_ids`, `task` and `language` and check that every output row starts with the prompt ids,
    followed by the decoder start token id, the language id and the task id.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = TFWhisperForConditionalGeneration(config)
    input_features = input_dict["input_features"]
    prompt_ids = np.arange(5)

    # Register the language/task -> token id mappings that are passed to `generate` by name below.
    language, lang_id = "<|de|>", 6
    task, task_id = "translate", 7
    setattr(model.generation_config, "lang_to_id", {language: lang_id})
    setattr(model.generation_config, "task_to_id", {task: task_id})

    output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)

    # The generation config is read from the model only after `generate` has run.
    expected_prefix = prompt_ids.tolist() + [model.generation_config.decoder_start_token_id, lang_id, task_id]
    prefix_length = len(expected_prefix)
    for generated_row in output.numpy().tolist():
        self.assertListEqual(generated_row[:prefix_length], expected_prefix)
|
||||
|
||||
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
    """
    Generate with `prompt_ids` and `forced_decoder_ids` and check that every output row starts with the prompt
    ids, followed by the decoder start token id and then the forced tokens in the order they were listed.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = TFWhisperForConditionalGeneration(config)
    input_features = input_dict["input_features"]
    prompt_ids = np.arange(5)
    forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]

    output = model.generate(
        input_features,
        max_new_tokens=5,
        forced_decoder_ids=forced_decoder_ids,
        prompt_ids=prompt_ids,
    )

    # Only the token (second element) of each forced pair is expected in the output.
    forced_tokens = [pair[1] for pair in forced_decoder_ids]
    expected_prefix = prompt_ids.tolist() + [model.generation_config.decoder_start_token_id] + forced_tokens
    prefix_length = len(expected_prefix)
    for generated_row in output.numpy().tolist():
        self.assertListEqual(generated_row[:prefix_length], expected_prefix)
|
||||
|
||||
|
||||
def _load_datasamples(num_samples):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
@@ -779,24 +821,22 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
|
||||
generated_ids = np.concatenate([generated_ids_1, generated_ids_2])
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
[
|
||||
[50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
|
||||
[50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
|
||||
[50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
|
||||
[50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
|
||||
]
|
||||
)
|
||||
EXPECTED_IDS = [
|
||||
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
|
||||
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
|
||||
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
|
||||
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
unittest.TestCase().assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
|
||||
unittest.TestCase().assertEqual(generated_ids.tolist(), EXPECTED_IDS)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad",
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to",
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user