From 73a038128227be14b1b6397667b4ff227918f81e Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Wed, 1 Sep 2021 10:41:46 +0200 Subject: [PATCH] Torchscript test (#13350) * Torchscript test * Remove print statement --- tests/test_modeling_bert.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index c60d198978..a029d9d47e 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +import os +import tempfile import unittest from transformers import BertConfig, is_torch_available from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device from .test_configuration_common import ConfigTester from .test_generation_utils import GenerationTesterMixin @@ -556,6 +556,29 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): model = BertModel.from_pretrained(model_name) self.assertIsNotNone(model) + @slow + @require_torch_gpu + def test_torchscript_device_change(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + + # BertForMultipleChoice behaves incorrectly in JIT environments. + if model_class == BertForMultipleChoice: + return + + config.torchscript = True + model = model_class(config=config) + + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + traced_model = torch.jit.trace( + model, (inputs_dict["input_ids"].to("cpu"), inputs_dict["attention_mask"].to("cpu")) + ) + + with tempfile.TemporaryDirectory() as tmp: + torch.jit.save(traced_model, os.path.join(tmp, "bert.pt")) + loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) + loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) + @require_torch class BertModelIntegrationTest(unittest.TestCase):