Fix torchscript tests (#13701)
This commit is contained in:
@@ -325,7 +325,12 @@ class FlaubertModelTester(object):
|
||||
choice_labels,
|
||||
input_mask,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"lengths": input_lengths,
|
||||
"attention_mask": input_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@@ -422,7 +427,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
|
||||
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
|
||||
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
|
||||
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user