Fix for Neuron (#30259)

This commit is contained in:
Michael Benayoun
2024-05-02 10:24:47 +02:00
committed by ArthurZucker
parent 9fe3f585bb
commit bb98e7ce58
7 changed files with 240 additions and 99 deletions

View File

@@ -18,7 +18,6 @@ import gc
import inspect
import os
import os.path
import pickle
import random
import re
import tempfile
@@ -1279,26 +1278,6 @@ class ModelTesterMixin:
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
try:
with open(pkl_file_name, "wb") as f:
pickle.dump(traced_model, f)
with open(pkl_file_name, "rb") as f:
loaded = pickle.load(f)
except Exception as e:
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
loaded_output = loaded(**filtered_inputs)
loaded_output = flatten_output(loaded_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], loaded_output[i]),
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
)
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
# (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry()