[LED Test] fix common inputs pt for flaky pt-tf led test (#9459)
* fix common inputs pt flakey led * fix other tests correspondingly
This commit is contained in:
committed by
GitHub
parent
ae5a32bb0d
commit
a400fe8931
@@ -110,7 +110,8 @@ class LEDModelTester:
|
|||||||
# because its local attention only attends to `self.attention_window + 1` locations
|
# because its local attention only attends to `self.attention_window + 1` locations
|
||||||
# (assuming no token with global attention, otherwise the last dimension of attentions
|
# (assuming no token with global attention, otherwise the last dimension of attentions
|
||||||
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
|
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
|
||||||
self.encoder_key_length = self.attention_window + 1
|
# x is set to 1
|
||||||
|
self.encoder_key_length = self.attention_window + 2
|
||||||
|
|
||||||
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
||||||
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
||||||
@@ -149,6 +150,10 @@ class LEDModelTester:
|
|||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config, inputs_dict = self.prepare_config_and_inputs()
|
config, inputs_dict = self.prepare_config_and_inputs()
|
||||||
|
global_attention_mask = torch.zeros_like(inputs_dict["input_ids"])
|
||||||
|
global_attention_mask[:, -1] = 1
|
||||||
|
inputs_dict["global_attention_mask"] = global_attention_mask
|
||||||
|
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
||||||
@@ -196,9 +201,11 @@ class LEDModelTester:
|
|||||||
encoder.save_pretrained(tmpdirname)
|
encoder.save_pretrained(tmpdirname)
|
||||||
encoder = LEDEncoder.from_pretrained(tmpdirname).to(torch_device)
|
encoder = LEDEncoder.from_pretrained(tmpdirname).to(torch_device)
|
||||||
|
|
||||||
encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
|
encoder_last_hidden_state_2 = encoder(
|
||||||
0
|
inputs_dict["input_ids"],
|
||||||
]
|
attention_mask=inputs_dict["attention_mask"],
|
||||||
|
global_attention_mask=inputs_dict["global_attention_mask"],
|
||||||
|
)[0]
|
||||||
|
|
||||||
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
|
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
|
||||||
|
|
||||||
@@ -390,7 +397,8 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
out_len = len(outputs)
|
out_len = len(outputs)
|
||||||
|
|
||||||
correct_outlen = 5
|
# global attention outputs are added as well => so +1 here
|
||||||
|
correct_outlen = 6
|
||||||
|
|
||||||
# loss is at first position
|
# loss is at first position
|
||||||
if "labels" in inputs_dict:
|
if "labels" in inputs_dict:
|
||||||
|
|||||||
Reference in New Issue
Block a user