From d7c31abf38f379c5645e1ee322ff6a65be12eacc Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Fri, 22 Jan 2021 14:50:46 +0100 Subject: [PATCH] Fix some TF slow tests (#9728) * Fix saved model tests + fix a graph issue in longformer * Apply style --- .../models/longformer/modeling_tf_longformer.py | 10 ++++++++-- tests/test_modeling_tf_common.py | 10 ++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 640ade3db5..c9f2838fb2 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -2438,10 +2438,16 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque logger.info("Initializing global attention on CLS token...") # global attention on cls token inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"]) + updates = tf.ones(shape_list(inputs["input_ids"])[0], dtype=tf.int32) + indices = tf.pad( + tensor=tf.expand_dims(tf.range(shape_list(inputs["input_ids"])[0]), axis=1), + paddings=[[0, 0], [0, 1]], + constant_values=0, + ) inputs["global_attention_mask"] = tf.tensor_scatter_nd_update( inputs["global_attention_mask"], - [[i, 0] for i in range(shape_list(inputs["input_ids"])[0])], - [1 for _ in range(shape_list(inputs["input_ids"])[0])], + indices, + updates, ) outputs = self.longformer( diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 8056d6cfd5..65715f98dc 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -184,7 +184,7 @@ class TFModelTesterMixin: with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname, saved_model=True) - saved_model_dir = os.path.join(tmpdirname, "saved_model") + saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") self.assertTrue(os.path.exists(saved_model_dir)) @slow @@ -204,7 +204,7 @@ class TFModelTesterMixin: with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname, saved_model=True) - saved_model_dir = os.path.join(tmpdirname, "saved_model") + saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") self.assertTrue(os.path.exists(saved_model_dir)) @slow @@ -223,7 +223,8 @@ class TFModelTesterMixin: with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname, saved_model=True) - model = tf.keras.models.load_model(os.path.join(tmpdirname, "saved_model", "1")) + saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") + model = tf.keras.models.load_model(saved_model_dir) outputs = model(class_inputs_dict) if self.is_encoder_decoder: @@ -262,7 +263,8 @@ class TFModelTesterMixin: with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname, saved_model=True) - model = tf.keras.models.load_model(os.path.join(tmpdirname, "saved_model", "1")) + saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") + model = tf.keras.models.load_model(saved_model_dir) outputs = model(class_inputs_dict) if self.is_encoder_decoder: