Making TF BART-like models XLA and AMP compliant (#10191)

* Update BART

* Update Blenderbot

* Update BlenderbotSmall

* Update Marian

* Update MBart

* Update MBart

* Update Pegasus

* Update template

* Fix Marian and Pegasus

* Apply style

* Default initializer

* Default initializer

* Default initializer

* Remove int32 casts

* Fix template

* Remove more cast
This commit is contained in:
Julien Plu
2021-02-17 17:48:56 +01:00
committed by GitHub
parent 8d79e5ca49
commit 83d803ba02
13 changed files with 492 additions and 367 deletions

View File

@@ -214,18 +214,6 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make MBart float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make MBart XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -289,6 +277,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
models_equal = False
self.assertTrue(models_equal)
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""