From a400fe8931cce276df74c7c7a5ee4dd28b5674ec Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 Jan 2021 12:29:03 +0100 Subject: [PATCH] [LED Test] fix common inputs pt for flaky pt-tf led test (#9459) * fix common inputs pt flakey led * fix other tests correspondingly --- tests/test_modeling_led.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/test_modeling_led.py b/tests/test_modeling_led.py index 55b63dd24b..927aab71b1 100644 --- a/tests/test_modeling_led.py +++ b/tests/test_modeling_led.py @@ -110,7 +110,8 @@ class LEDModelTester: # because its local attention only attends to `self.attention_window + 1` locations # (assuming no token with global attention, otherwise the last dimension of attentions # is x + self.attention_window + 1, where x is the number of tokens with global attention) - self.encoder_key_length = self.attention_window + 1 + # x is set to 1 + self.encoder_key_length = self.attention_window + 2 # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for # the `test_attention_outputs` and `test_hidden_states_output` tests @@ -149,6 +150,10 @@ class LEDModelTester: def prepare_config_and_inputs_for_common(self): config, inputs_dict = self.prepare_config_and_inputs() + global_attention_mask = torch.zeros_like(inputs_dict["input_ids"]) + global_attention_mask[:, -1] = 1 + inputs_dict["global_attention_mask"] = global_attention_mask + return config, inputs_dict def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): @@ -196,9 +201,11 @@ class LEDModelTester: encoder.save_pretrained(tmpdirname) encoder = LEDEncoder.from_pretrained(tmpdirname).to(torch_device) - encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[ - 0 - ] + encoder_last_hidden_state_2 = encoder( + inputs_dict["input_ids"], + attention_mask=inputs_dict["attention_mask"], + global_attention_mask=inputs_dict["global_attention_mask"], + )[0] self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3) @@ -390,7 +397,8 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) out_len = len(outputs) - correct_outlen = 5 + # global attention outputs are added as well => so +1 here + correct_outlen = 6 # loss is at first position if "labels" in inputs_dict: