Generate tests: modality-agnostic input preparation (#33685)
This commit is contained in:
@@ -338,13 +338,11 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self.model_tester.check_global_attention(*config_and_inputs)
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(
|
||||
self, batch_size=batch_size
|
||||
)
|
||||
def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
|
||||
config, inputs_dict = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
|
||||
# LED computes attention scores based on mask indices if `is_global`
|
||||
inputs_dict.pop("global_attention_mask")
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
return config, inputs_dict
|
||||
|
||||
# LEDForSequenceClassification does not support inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
|
||||
Reference in New Issue
Block a user