diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py index 7ab769c4fe..5637e5addd 100644 --- a/tests/test_modeling_tf_bart.py +++ b/tests/test_modeling_tf_bart.py @@ -281,6 +281,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make BART float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make BART XLA compliant + pass + def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" diff --git a/tests/test_modeling_tf_blenderbot.py b/tests/test_modeling_tf_blenderbot.py index ef58fa05ee..f87de7f7d0 100644 --- a/tests/test_modeling_tf_blenderbot.py +++ b/tests/test_modeling_tf_blenderbot.py @@ -217,6 +217,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make Blenderbot float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make Blenderbot XLA compliant + pass + def test_resize_token_embeddings(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_blenderbot_small.py b/tests/test_modeling_tf_blenderbot_small.py index e136f541d2..582dfc373f 100644 --- a/tests/test_modeling_tf_blenderbot_small.py +++ b/tests/test_modeling_tf_blenderbot_small.py @@ -282,6 +282,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make Blenderbot Small float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make Blenderbot Small XLA compliant + pass + def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 4be9b955a9..1bab898c8f 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -141,6 +141,19 @@ class TFModelTesterMixin: outputs = run_in_graph_mode() self.assertIsNotNone(outputs) + def test_xla_mode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + inputs = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @tf.function(experimental_compile=True) + def run_in_graph_mode(): + return model(inputs) + + outputs = run_in_graph_mode() + self.assertIsNotNone(outputs) + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_convbert.py b/tests/test_modeling_tf_convbert.py index 5f61b757b3..9c2a1b8104 100644 --- a/tests/test_modeling_tf_convbert.py +++ b/tests/test_modeling_tf_convbert.py @@ -301,6 +301,10 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase): [self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length], ) + def test_xla_mode(self): + # TODO JP: Make ConvBert XLA compliant + pass + @slow def test_model_from_pretrained(self): model = TFConvBertModel.from_pretrained("YituTech/conv-bert-base") diff --git a/tests/test_modeling_tf_ctrl.py b/tests/test_modeling_tf_ctrl.py index f870edfdc5..09d6cf9fe6 100644 --- a/tests/test_modeling_tf_ctrl.py +++ b/tests/test_modeling_tf_ctrl.py @@ -225,6 +225,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make CTRL float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make CTRL XLA compliant + pass + @slow def test_model_from_pretrained(self): for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_tf_flaubert.py b/tests/test_modeling_tf_flaubert.py index 27be521b98..53a899e0bf 100644 --- a/tests/test_modeling_tf_flaubert.py +++ b/tests/test_modeling_tf_flaubert.py @@ -334,6 +334,10 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make Flaubert float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make Flaubert XLA compliant + pass + @require_tf @require_sentencepiece diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index 56ed81643d..4d9a12384b 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -391,6 +391,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make GPT2 float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make GPT2 XLA compliant + pass + @slow def test_model_from_pretrained(self): for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index 6447bd0f13..620bc430ba 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -361,6 +361,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make LED float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make LED XLA compliant + pass + def test_saved_model_with_attentions_output(self): # This test don't pass because of the error: # condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py index 43b32e9524..374b165f3c 100644 --- a/tests/test_modeling_tf_longformer.py +++ b/tests/test_modeling_tf_longformer.py @@ -359,6 +359,10 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make Longformer float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make Blenderbot XLA compliant + pass + @require_tf @require_sentencepiece diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py index ce54cd1c63..292f489313 100644 --- a/tests/test_modeling_tf_marian.py +++ b/tests/test_modeling_tf_marian.py @@ -250,6 +250,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make Marian float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make Marian XLA compliant + pass + def test_resize_token_embeddings(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_mbart.py b/tests/test_modeling_tf_mbart.py index b22c54d5bf..eb0cb553ca 100644 --- a/tests/test_modeling_tf_mbart.py +++ b/tests/test_modeling_tf_mbart.py @@ -221,6 +221,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make MBart float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make MBart XLA compliant + pass + def test_resize_token_embeddings(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_mpnet.py b/tests/test_modeling_tf_mpnet.py index 5aa66b5279..da14679ba6 100644 --- a/tests/test_modeling_tf_mpnet.py +++ b/tests/test_modeling_tf_mpnet.py @@ -231,6 +231,10 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs) + def test_xla_mode(self): + # TODO JP: Make MPNet XLA compliant + pass + @slow def test_model_from_pretrained(self): for model_name in ["microsoft/mpnet-base"]: diff --git a/tests/test_modeling_tf_openai.py b/tests/test_modeling_tf_openai.py index 8d8c21835a..7da10235fb 100644 --- a/tests/test_modeling_tf_openai.py +++ b/tests/test_modeling_tf_openai.py @@ -249,6 +249,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make OpenAIGPT float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make OpenAIGPT XLA compliant + pass + @slow def test_model_from_pretrained(self): for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_tf_pegasus.py b/tests/test_modeling_tf_pegasus.py index 7473e0e1cf..a469aff7fb 100644 --- a/tests/test_modeling_tf_pegasus.py +++ b/tests/test_modeling_tf_pegasus.py @@ -248,6 +248,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make Pegasus float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make Pegasus XLA compliant + pass + def test_resize_token_embeddings(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index a01273fc22..2d0638f0e4 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make T5 float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make T5 XLA compliant + pass + @slow def test_model_from_pretrained(self): model = TFT5Model.from_pretrained("t5-small") @@ -443,6 +447,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make T5 float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make T5 XLA compliant + pass + @require_tf @require_sentencepiece diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py index 86d9468cec..a903831208 100644 --- a/tests/test_modeling_tf_transfo_xl.py +++ b/tests/test_modeling_tf_transfo_xl.py @@ -208,6 +208,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make TransfoXL float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make TransfoXL XLA compliant + pass + @slow def test_model_from_pretrained(self): for model_name in TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_tf_xlm.py b/tests/test_modeling_tf_xlm.py index 466f5640db..e3eb1bdbc1 100644 --- a/tests/test_modeling_tf_xlm.py +++ b/tests/test_modeling_tf_xlm.py @@ -330,6 +330,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): # TODO JP: Make XLM float16 compliant pass + def test_xla_mode(self): + # TODO JP: Make XLM XLA compliant + pass + @slow def test_model_from_pretrained(self): for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: