Generate tests: modality-agnostic input preparation (#33685)

This commit is contained in:
Joao Gante
2024-10-03 14:01:24 +01:00
committed by GitHub
parent f2bf4fcf3d
commit d29738f5b4
34 changed files with 241 additions and 906 deletions

View File

@@ -283,28 +283,6 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
return False
# overwrite from GenerationTesterMixin to solve problem
# with conflicting random seeds
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.attention_type = "original_full"
input_ids = inputs_dict.pop(self.input_name)
_ = inputs_dict.pop("attention_mask", None)
_ = inputs_dict.pop("decoder_input_ids", None)
_ = inputs_dict.pop("decoder_attention_mask", None)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
# cut to half length & take max batch_size 3
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:batch_size, :sequence_length]
attention_mask = attention_mask[:batch_size, :sequence_length]
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, inputs_dict
def setUp(self):
self.model_tester = BigBirdPegasusModelTester(self)
self.config_tester = ConfigTester(self, config_class=BigBirdPegasusConfig)
@@ -485,6 +463,13 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
def test_load_save_without_tied_weights(self):
pass
def test_generate_with_head_masking(self):
# overwritten to temporarily switch the attention type to `original_full`
original_self_attention_type = self.model_tester.attention_type
self.model_tester.attention_type = "original_full"
super().test_generate_with_head_masking()
self.model_tester.attention_type = original_self_attention_type
@require_torch
@require_sentencepiece