Generate tests: modality-agnostic input preparation (#33685)

This commit is contained in:
Joao Gante
2024-10-03 14:01:24 +01:00
committed by GitHub
parent f2bf4fcf3d
commit d29738f5b4
34 changed files with 241 additions and 906 deletions

View File

@@ -684,20 +684,15 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
def test_left_padding_compatibility(self):
pass
def _get_input_ids_and_config(self, batch_size=2):
def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
# override because overwise we hit max possible seq length for model (4*8=32)
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.pop(self.input_name)
_ = inputs_dict.pop("attention_mask", None)
_ = inputs_dict.pop("decoder_input_ids", None)
_ = inputs_dict.pop("decoder_attention_mask", None)
input_ids = input_ids[:batch_size, :16]
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :16]
config.eos_token_id = None
config.forced_eos_token_id = None
return config, input_ids, attention_mask, inputs_dict
original_sequence_length = self.model_tester.seq_length
self.model_tester.seq_length = 16
test_inputs = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
self.model_tester.seq_length = original_sequence_length
return test_inputs
@require_torch