minor cleanup of test_attention_outputs
This commit is contained in:
committed by
Lysandre Debut
parent
3bf5417258
commit
cbcb83f21d
@@ -117,23 +117,11 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
seq_len = self.model_tester.seq_length
|
||||||
decoder_seq_length = (
|
decoder_seq_length = getattr(self.model_tester, 'decoder_seq_length', seq_len)
|
||||||
self.model_tester.decoder_seq_length
|
encoder_seq_length = getattr(self.model_tester, 'encoder_seq_length', seq_len)
|
||||||
if hasattr(self.model_tester, "decoder_seq_length")
|
decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
|
||||||
else self.model_tester.seq_length
|
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||||
)
|
|
||||||
encoder_seq_length = (
|
|
||||||
self.model_tester.encoder_seq_length
|
|
||||||
if hasattr(self.model_tester, "encoder_seq_length")
|
|
||||||
else self.model_tester.seq_length
|
|
||||||
)
|
|
||||||
decoder_key_length = (
|
|
||||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length
|
|
||||||
)
|
|
||||||
encoder_key_length = (
|
|
||||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
config.output_attentions = True
|
config.output_attentions = True
|
||||||
|
|||||||
Reference in New Issue
Block a user