Making TF BART-like models XLA and AMP compliant (#10191)

* Update BART

* Update Blenderbot

* Update BlenderbotSmall

* Update Marian

* Update MBart

* Update MBart

* Update Pegasus

* Update template

* Fix Marian and Pegasus

* Apply style

* Default initializer

* Default initializer

* Default initializer

* Remove int32 casts

* Fix template

* Remove more cast
This commit is contained in:
Julien Plu
2021-02-17 17:48:56 +01:00
committed by GitHub
parent 8d79e5ca49
commit 83d803ba02
13 changed files with 492 additions and 367 deletions

View File

@@ -279,14 +279,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make BART float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make BART XLA compliant
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""

View File

@@ -214,14 +214,6 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make Blenderbot float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -279,14 +279,6 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make Blenderbot Small float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot Small XLA compliant
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""

View File

@@ -247,14 +247,6 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make Marian float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Marian XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -214,18 +214,6 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make MBart float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make MBart XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -289,6 +277,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
models_equal = False
self.assertTrue(models_equal)
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""

View File

@@ -245,14 +245,6 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make Pegasus float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Pegasus XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()