Traced models serialization and torchscripting fix (#17206)
* Fix torch.jit.script and pickling issues * Fix get_attr issues * Fix import in function * Fix GPT-J and T5 tracing for torch=1.11 * Gate graph surgery on torch version * Modeling minor changes to enable TorchScripting * Model serialization / deserialization test * Remove _assert_is_none users
This commit is contained in:
@@ -19,6 +19,7 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
import os.path
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -758,8 +759,8 @@ class ModelTesterMixin:
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
except RuntimeError as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
@@ -782,6 +783,40 @@ class ModelTesterMixin:
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be TorchScripted
|
||||
try:
|
||||
scripted = torch.jit.script(traced_model)
|
||||
except Exception as e:
|
||||
self.fail(f"Could not TorchScript the traced model: {e}")
|
||||
scripted_output = scripted(**filtered_inputs)
|
||||
scripted_output = flatten_output(scripted_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], scripted_output[i]),
|
||||
f"scripted {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user