[LED Test] fix common inputs pt for flaky pt-tf led test (#9459)
* fix common inputs pt flakey led * fix other tests correspondingly
This commit is contained in:
committed by
GitHub
parent
ae5a32bb0d
commit
a400fe8931
@@ -110,7 +110,8 @@ class LEDModelTester:
|
||||
# because its local attention only attends to `self.attention_window + 1` locations
|
||||
# (assuming no token with global attention, otherwise the last dimension of attentions
|
||||
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
|
||||
self.encoder_key_length = self.attention_window + 1
|
||||
# x is set to 1
|
||||
self.encoder_key_length = self.attention_window + 2
|
||||
|
||||
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
||||
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
||||
@@ -149,6 +150,10 @@ class LEDModelTester:
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
global_attention_mask = torch.zeros_like(inputs_dict["input_ids"])
|
||||
global_attention_mask[:, -1] = 1
|
||||
inputs_dict["global_attention_mask"] = global_attention_mask
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
||||
@@ -196,9 +201,11 @@ class LEDModelTester:
|
||||
encoder.save_pretrained(tmpdirname)
|
||||
encoder = LEDEncoder.from_pretrained(tmpdirname).to(torch_device)
|
||||
|
||||
encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
|
||||
0
|
||||
]
|
||||
encoder_last_hidden_state_2 = encoder(
|
||||
inputs_dict["input_ids"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
global_attention_mask=inputs_dict["global_attention_mask"],
|
||||
)[0]
|
||||
|
||||
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
|
||||
|
||||
@@ -390,7 +397,8 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
correct_outlen = 5
|
||||
# global attention outputs are added as well => so +1 here
|
||||
correct_outlen = 6
|
||||
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
|
||||
Reference in New Issue
Block a user