From ca257a06cc42e3345a1500391e1b1d7742e18ae0 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Wed, 22 Sep 2021 19:02:54 -0400 Subject: [PATCH] Fix torchscript tests (#13701) --- tests/test_modeling_convbert.py | 2 +- tests/test_modeling_distilbert.py | 2 +- tests/test_modeling_flaubert.py | 9 +++++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_convbert.py b/tests/test_modeling_convbert.py index cbff98f88c..d4cde34be8 100644 --- a/tests/test_modeling_convbert.py +++ b/tests/test_modeling_convbert.py @@ -436,7 +436,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): with tempfile.TemporaryDirectory() as tmp: torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt")) - loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) + loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device) loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index 1aa01eb566..87ebaa22ee 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -273,7 +273,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): with tempfile.TemporaryDirectory() as tmp: torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt")) - loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) + loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device) loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) diff --git a/tests/test_modeling_flaubert.py b/tests/test_modeling_flaubert.py index 8f4b882821..cf81970f0a 100644 --- a/tests/test_modeling_flaubert.py +++ b/tests/test_modeling_flaubert.py @@ -325,7 +325,12 @@ class FlaubertModelTester(object): choice_labels, input_mask, ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths} + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "lengths": input_lengths, + "attention_mask": input_mask, + } return config, inputs_dict @@ -422,7 +427,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): with tempfile.TemporaryDirectory() as tmp: torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt")) - loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) + loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device) loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))