Add TF<>PT and Flax<>PT everywhere (#14047)
* up * up * up * up * up * up * up * add clip * fix clip PyTorch * fix clip PyTorch * up * up * up * up * up * up * up
This commit is contained in:
committed by
GitHub
parent
8560b55b5e
commit
0c3174c758
@@ -72,7 +72,7 @@ class LongformerModelTester:
|
||||
# because its local attention only attends to `self.attention_window + 1` locations
|
||||
# (assuming no token with global attention, otherwise the last dimension of attentions
|
||||
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
|
||||
self.key_length = self.attention_window + 1
|
||||
self.key_length = self.attention_window + 2
|
||||
|
||||
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
||||
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
||||
@@ -243,6 +243,8 @@ class LongformerModelTester:
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
global_attention_mask = torch.zeros_like(input_ids)
|
||||
global_attention_mask[:, -1] = 1
|
||||
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
|
||||
Reference in New Issue
Block a user