From d07c771dd9cc03429bb64ba1dacd1c9d266a4ec9 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Wed, 1 Sep 2021 10:43:09 +0200 Subject: [PATCH] Torchscript test for ConvBERT (#13352) * Torchscript test for ConvBERT * Apply suggestions from code review --- tests/test_modeling_convbert.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_convbert.py b/tests/test_modeling_convbert.py index 21013f83b5..cbff98f88c 100644 --- a/tests/test_modeling_convbert.py +++ b/tests/test_modeling_convbert.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Testing suite for the PyTorch ConvBERT model. """ - - +import os +import tempfile import unittest from tests.test_modeling_common import floats_tensor from transformers import ConvBertConfig, is_torch_available from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -416,6 +416,29 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): [self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length], ) + @slow + @require_torch_gpu + def test_torchscript_device_change(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + + # ConvBertForMultipleChoice behaves incorrectly in JIT environments. + if model_class == ConvBertForMultipleChoice: + return + + config.torchscript = True + model = model_class(config=config) + + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + traced_model = torch.jit.trace( + model, (inputs_dict["input_ids"].to("cpu"), inputs_dict["attention_mask"].to("cpu")) + ) + + with tempfile.TemporaryDirectory() as tmp: + torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt")) + loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) + loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) + @require_torch class ConvBertModelIntegrationTest(unittest.TestCase):