From 680733a7c4b64d41b5571138edfbb72f1ef0d3d5 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Wed, 1 Sep 2021 10:42:21 +0200 Subject: [PATCH] Torchscript test for DistilBERT (#13351) * Torchscript test for DistilBERT * Update tests/test_modeling_distilbert.py --- tests/test_modeling_distilbert.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index 64a6a9cce9..62d176dce5 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +import os +import tempfile import unittest from transformers import DistilBertConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -252,6 +252,29 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): model = DistilBertModel.from_pretrained(model_name) self.assertIsNotNone(model) + @slow + @require_torch_gpu + def test_torchscript_device_change(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + + # BertForMultipleChoice behaves incorrectly in JIT environments. + if model_class == DistilBertForMultipleChoice: + return + + config.torchscript = True + model = model_class(config=config) + + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + traced_model = torch.jit.trace( + model, (inputs_dict["input_ids"].to("cpu"), inputs_dict["attention_mask"].to("cpu")) + ) + + with tempfile.TemporaryDirectory() as tmp: + torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt")) + loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) + loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) + @require_torch class DistilBertModelIntergrationTest(unittest.TestCase):