LED (#9278)
* create model
* add integration
* save current state
* make integration tests pass
* add one more test
* add explanation to tests
* remove from bart
* add padding
* remove unnecessary test
* make all tests pass
* re-add cookie cutter tests
* finish PyTorch
* fix attention test
* Update tests/test_modeling_common.py
* revert change
* remove unused file
* add string to doc
* save intermediate
* make tf integration tests pass
* finish tf
* fix doc
* fix docs again
* add led to doctree
* add to auto tokenizer
* added tips for led
* make style
* apply jplus statements
* correct tf longformer
* apply lysandres suggestions
* apply sylvains suggestions
* Apply suggestions from code review
This commit is contained in:
committed by
GitHub
parent
314cca2842
commit
189387e9b2
@@ -488,13 +488,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
is_index_global_attn = attention_mask > 0
|
||||
is_global_attn = is_index_global_attn.flatten().any().item()
|
||||
|
||||
output_hidden_states, _ = layer(
|
||||
output_hidden_states = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
)
|
||||
)[0]
|
||||
|
||||
self.assertTrue(output_hidden_states.shape, (1, 4, 8))
|
||||
self.assertTrue(
|
||||
@@ -526,13 +526,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
is_index_global_attn = attention_mask > 0
|
||||
is_global_attn = is_index_global_attn.flatten().any().item()
|
||||
|
||||
output_hidden_states, _, _ = layer(
|
||||
output_hidden_states = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
)
|
||||
)[0]
|
||||
|
||||
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
|
||||
|
||||
@@ -583,6 +583,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
output_attentions=True,
|
||||
)
|
||||
|
||||
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))
|
||||
|
||||
Reference in New Issue
Block a user