Fix TF LED/Longformer attentions computation (#10007)

* Fix test

* Remove commented test

* Fix name

* Apply style

* Fix check copies

* Remove prints

* Restore boolean

* Fix reshape
This commit is contained in:
Julien Plu
2021-02-10 16:58:37 +01:00
committed by GitHub
parent 0d8e554d42
commit 22a32cf485
4 changed files with 87 additions and 62 deletions

View File

@@ -339,15 +339,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
@slow
def test_saved_model_with_attentions_output(self):
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
# This occurs line 323 in modeling_tf_led.py because the condition line 255
# returns a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
# if is_global_attn is True and a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
# This is due to the tf.concat call line 703 that adds one dimension
# Need to check with PVP how to properly fix this
# Temporarily disable this test in order to find
# how to better handle it without timing out the CI
pass
@slow
@@ -371,7 +364,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
# TODO JP: Make Longformer XLA compliant
pass