Generate tests: modality-agnostic input preparation (#33685)

This commit is contained in:
Joao Gante
2024-10-03 14:01:24 +01:00
committed by GitHub
parent f2bf4fcf3d
commit d29738f5b4
34 changed files with 241 additions and 906 deletions

View File

@@ -338,13 +338,11 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_global_attention(*config_and_inputs)
def _get_input_ids_and_config(self, batch_size=2):
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(
self, batch_size=batch_size
)
def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
config, inputs_dict = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
# LED computes attention scores based on mask indices if `is_global`
inputs_dict.pop("global_attention_mask")
return config, input_ids, attention_mask, inputs_dict
return config, inputs_dict
# LEDForSequenceClassification does not support inputs_embeds
def test_inputs_embeds(self):