* create model

* add integration

* save current state

* make integration tests pass

* add one more test

* add explanation to tests

* remove from bart

* add padding

* remove unnecessary test

* make all tests pass

* re-add cookie cutter tests

* finish PyTorch

* fix attention test

* Update tests/test_modeling_common.py

* revert change

* remove unused file

* add string to doc

* save intermediate

* make tf integration tests pass

* finish tf

* fix doc

* fix docs again

* add led to doctree

* add to auto tokenizer

* added tips for led

* make style

* apply jplus statements

* correct tf longformer

* apply lysandres suggestions

* apply sylvains suggestions

* Apply suggestions from code review
This commit is contained in:
Patrick von Platen
2021-01-05 13:14:30 +01:00
committed by GitHub
parent 314cca2842
commit 189387e9b2
24 changed files with 6242 additions and 40 deletions

View File

@@ -488,13 +488,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
output_hidden_states, _ = layer(
output_hidden_states = layer(
hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
)
)[0]
self.assertTrue(output_hidden_states.shape, (1, 4, 8))
self.assertTrue(
@@ -526,13 +526,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
output_hidden_states, _, _ = layer(
output_hidden_states = layer(
hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
)
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
@@ -583,6 +583,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=True,
)
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))