From cbcb83f21d83d357f2a806bfa2bb191a3b787ed3 Mon Sep 17 00:00:00 2001 From: sshleifer Date: Mon, 3 Feb 2020 17:03:16 -0500 Subject: [PATCH] minor cleanup of test_attention_outputs --- tests/test_modeling_common.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a5d69fbd6c..06dcf3199c 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -117,23 +117,11 @@ class ModelTesterMixin: def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - decoder_seq_length = ( - self.model_tester.decoder_seq_length - if hasattr(self.model_tester, "decoder_seq_length") - else self.model_tester.seq_length - ) - encoder_seq_length = ( - self.model_tester.encoder_seq_length - if hasattr(self.model_tester, "encoder_seq_length") - else self.model_tester.seq_length - ) - decoder_key_length = ( - self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length - ) - encoder_key_length = ( - self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length - ) + seq_len = self.model_tester.seq_length + decoder_seq_length = getattr(self.model_tester, 'decoder_seq_length', seq_len) + encoder_seq_length = getattr(self.model_tester, 'encoder_seq_length', seq_len) + decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) for model_class in self.all_model_classes: config.output_attentions = True