[tests] remove tf/flax tests in /generation (#36235)
This commit is contained in:
@@ -524,127 +524,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
|
||||
)
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_features
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# generating multiple sequences when no beam search generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# generating more sequences than having beams leads is not possible
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(
|
||||
model.generate(
|
||||
input_features,
|
||||
do_sample=True,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
)
|
||||
# num_return_sequences > 1, greedy
|
||||
self._check_generated_ids(
|
||||
model.generate(input_features, do_sample=False, num_beams=2, num_return_sequences=2)
|
||||
)
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
def test_generate_with_prompt_ids_and_task_and_language(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = TFWhisperForConditionalGeneration(config)
|
||||
input_features = input_dict["input_features"]
|
||||
prompt_ids = np.arange(5)
|
||||
language = "<|de|>"
|
||||
task = "translate"
|
||||
lang_id = 6
|
||||
task_id = 7
|
||||
model.generation_config.__setattr__("lang_to_id", {language: lang_id})
|
||||
model.generation_config.__setattr__("task_to_id", {task: task_id})
|
||||
|
||||
output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)
|
||||
|
||||
expected_output_start = [
|
||||
*prompt_ids.tolist(),
|
||||
model.generation_config.decoder_start_token_id,
|
||||
lang_id,
|
||||
task_id,
|
||||
]
|
||||
for row in output.numpy().tolist():
|
||||
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
|
||||
|
||||
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = TFWhisperForConditionalGeneration(config)
|
||||
input_features = input_dict["input_features"]
|
||||
prompt_ids = np.asarray(range(5))
|
||||
forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]
|
||||
|
||||
output = model.generate(
|
||||
input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
|
||||
)
|
||||
|
||||
expected_output_start = [
|
||||
*prompt_ids.tolist(),
|
||||
model.generation_config.decoder_start_token_id,
|
||||
*[token for _rank, token in forced_decoder_ids],
|
||||
]
|
||||
for row in output.numpy().tolist():
|
||||
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
|
||||
|
||||
|
||||
def _load_datasamples(num_samples):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
Reference in New Issue
Block a user