Generation tests: don't rely on main input name (#34228)
* don't rely on main input name * update
This commit is contained in:
committed by
GitHub
parent
816f442496
commit
ca541bd4f4
@@ -53,6 +53,7 @@ class ReformerModelTester:
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=32,
|
||||
text_seq_length=None,
|
||||
is_training=True,
|
||||
is_decoder=True,
|
||||
use_input_mask=True,
|
||||
@@ -128,6 +129,7 @@ class ReformerModelTester:
|
||||
self.attn_layers = attn_layers
|
||||
self.pad_token_id = pad_token_id
|
||||
self.hash_seed = hash_seed
|
||||
self.text_seq_length = text_seq_length or seq_length
|
||||
|
||||
attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length
|
||||
num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after
|
||||
@@ -608,7 +610,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ReformerModelTester(self)
|
||||
self.model_tester = ReformerModelTester(self, text_seq_length=16)
|
||||
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
|
||||
|
||||
@slow
|
||||
@@ -689,7 +691,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
||||
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
|
||||
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
|
||||
original_sequence_length = self.model_tester.seq_length
|
||||
self.model_tester.seq_length = 16
|
||||
self.model_tester.seq_length = self.model_tester.text_seq_length
|
||||
test_inputs = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
|
||||
self.model_tester.seq_length = original_sequence_length
|
||||
return test_inputs
|
||||
|
||||
Reference in New Issue
Block a user