Generate tests: modality-agnostic input preparation (#33685)
This commit is contained in:
@@ -283,28 +283,6 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
|
||||
return False
|
||||
|
||||
# overwrite from GenerationTesterMixin to solve problem
|
||||
# with conflicting random seeds
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.attention_type = "original_full"
|
||||
|
||||
input_ids = inputs_dict.pop(self.input_name)
|
||||
_ = inputs_dict.pop("attention_mask", None)
|
||||
_ = inputs_dict.pop("decoder_input_ids", None)
|
||||
_ = inputs_dict.pop("decoder_attention_mask", None)
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
sequence_length = input_ids.shape[-1] // 2
|
||||
input_ids = input_ids[:batch_size, :sequence_length]
|
||||
attention_mask = attention_mask[:batch_size, :sequence_length]
|
||||
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BigBirdPegasusModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BigBirdPegasusConfig)
|
||||
@@ -485,6 +463,13 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
def test_load_save_without_tied_weights(self):
|
||||
pass
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
# overwritten to temporarily switch the attention type to `original_full`
|
||||
original_self_attention_type = self.model_tester.attention_type
|
||||
self.model_tester.attention_type = "original_full"
|
||||
super().test_generate_with_head_masking()
|
||||
self.model_tester.attention_type = original_self_attention_type
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
||||
Reference in New Issue
Block a user