Fix torchscript tests (#13701)
This commit is contained in:
@@ -436,7 +436,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
|
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
|
||||||
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
|
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
|
||||||
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -273,7 +273,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
|
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
|
||||||
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
|
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
|
||||||
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -325,7 +325,12 @@ class FlaubertModelTester(object):
|
|||||||
choice_labels,
|
choice_labels,
|
||||||
input_mask,
|
input_mask,
|
||||||
) = config_and_inputs
|
) = config_and_inputs
|
||||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
"lengths": input_lengths,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -422,7 +427,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
|
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
|
||||||
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
|
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
|
||||||
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user