Add XLA test (#9848)
This commit is contained in:
@@ -281,6 +281,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make BART float16 compliant
|
# TODO JP: Make BART float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make BART XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||||
|
|||||||
@@ -217,6 +217,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make Blenderbot float16 compliant
|
# TODO JP: Make Blenderbot float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make Blenderbot XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -282,6 +282,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make Blenderbot Small float16 compliant
|
# TODO JP: Make Blenderbot Small float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make Blenderbot Small XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||||
|
|||||||
@@ -141,6 +141,19 @@ class TFModelTesterMixin:
|
|||||||
outputs = run_in_graph_mode()
|
outputs = run_in_graph_mode()
|
||||||
self.assertIsNotNone(outputs)
|
self.assertIsNotNone(outputs)
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
@tf.function(experimental_compile=True)
|
||||||
|
def run_in_graph_mode():
|
||||||
|
return model(inputs)
|
||||||
|
|
||||||
|
outputs = run_in_graph_mode()
|
||||||
|
self.assertIsNotNone(outputs)
|
||||||
|
|
||||||
def test_forward_signature(self):
|
def test_forward_signature(self):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -301,6 +301,10 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
[self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length],
|
[self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make ConvBert XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
|
model = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
|
||||||
|
|||||||
@@ -225,6 +225,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make CTRL float16 compliant
|
# TODO JP: Make CTRL float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make CTRL XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -334,6 +334,10 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make Flaubert float16 compliant
|
# TODO JP: Make Flaubert float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make Flaubert XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
@@ -391,6 +391,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make GPT2 float16 compliant
|
# TODO JP: Make GPT2 float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make GPT2 XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -361,6 +361,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make LED float16 compliant
|
# TODO JP: Make LED float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make LED XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
def test_saved_model_with_attentions_output(self):
|
def test_saved_model_with_attentions_output(self):
|
||||||
# This test don't pass because of the error:
|
# This test don't pass because of the error:
|
||||||
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
|
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
|
||||||
|
|||||||
@@ -359,6 +359,10 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make Longformer float16 compliant
|
# TODO JP: Make Longformer float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make Blenderbot XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
@@ -250,6 +250,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make Marian float16 compliant
|
# TODO JP: Make Marian float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make Marian XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -221,6 +221,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make MBart float16 compliant
|
# TODO JP: Make MBart float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make MBart XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -231,6 +231,10 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs)
|
self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make MPNet XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in ["microsoft/mpnet-base"]:
|
for model_name in ["microsoft/mpnet-base"]:
|
||||||
|
|||||||
@@ -249,6 +249,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make OpenAIGPT float16 compliant
|
# TODO JP: Make OpenAIGPT float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make OpenAIGPT XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -248,6 +248,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make Pegasus float16 compliant
|
# TODO JP: Make Pegasus float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make Pegasus XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make T5 float16 compliant
|
# TODO JP: Make T5 float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make T5 XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model = TFT5Model.from_pretrained("t5-small")
|
model = TFT5Model.from_pretrained("t5-small")
|
||||||
@@ -443,6 +447,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make T5 float16 compliant
|
# TODO JP: Make T5 float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make T5 XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
@@ -208,6 +208,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make TransfoXL float16 compliant
|
# TODO JP: Make TransfoXL float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make TransfoXL XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -330,6 +330,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO JP: Make XLM float16 compliant
|
# TODO JP: Make XLM float16 compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_xla_mode(self):
|
||||||
|
# TODO JP: Make XLM XLA compliant
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
Reference in New Issue
Block a user