Fix serialization and TorchScripting of traced models (#17206)

* Fix torch.jit.script and pickling issues

* Fix get_attr issues

* Fix import in function

* Fix GPT-J and T5 tracing for torch 1.11

* Gate graph surgery on torch version

* Minor modeling changes to enable TorchScripting

* Model serialization / deserialization test

* Remove _assert_is_none users
This commit is contained in:
Michael Benayoun
2022-05-23 17:50:40 +02:00
committed by GitHub
parent 1cd01b0af3
commit 2e7e4280aa
10 changed files with 277 additions and 64 deletions

View File

@@ -19,6 +19,7 @@ import inspect
import json
import os
import os.path
import pickle
import random
import sys
import tempfile
@@ -758,8 +759,8 @@ class ModelTesterMixin:
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
except RuntimeError:
self.fail("Couldn't trace module.")
except RuntimeError as e:
self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
flatten = []
@@ -782,6 +783,40 @@ class ModelTesterMixin:
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be TorchScripted
try:
scripted = torch.jit.script(traced_model)
except Exception as e:
self.fail(f"Could not TorchScript the traced model: {e}")
scripted_output = scripted(**filtered_inputs)
scripted_output = flatten_output(scripted_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], scripted_output[i]),
f"scripted {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
try:
with open(pkl_file_name, "wb") as f:
pickle.dump(traced_model, f)
with open(pkl_file_name, "rb") as f:
loaded = pickle.load(f)
except Exception as e:
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
loaded_output = loaded(**filtered_inputs)
loaded_output = flatten_output(loaded_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], loaded_output[i]),
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
)
def test_headmasking(self):
if not self.test_head_masking:
return