Fix torchscript tests (#13701)
This commit is contained in:
@@ -273,7 +273,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
|
||||
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
|
||||
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
|
||||
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user