Making TF BART-like models XLA and AMP compliant (#10191)

* Update BART

* Update Blenderbot

* Update BlenderbotSmall

* Update Marian

* Update MBart

* Update MBart

* Update Pegasus

* Update template

* Fix Marian and Pegasus

* Apply style

* Default initializer

* Default initializer

* Default initializer

* Remove int32 casts

* Fix template

* Remove more cast
This commit is contained in:
Julien Plu
2021-02-17 17:48:56 +01:00
committed by GitHub
parent 8d79e5ca49
commit 83d803ba02
13 changed files with 492 additions and 367 deletions

View File

@@ -279,14 +279,6 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make Blenderbot Small float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot Small XLA compliant
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""