Remove _create_and_check_torch_fx_tracing in specific test files (#18667)
* Remove _create_and_check_torch_fx_tracing defined in specific model test files
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -16,14 +16,11 @@
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
|
||||||
import pickle
|
|
||||||
import tempfile
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import DonutSwinConfig
|
from transformers import DonutSwinConfig
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
from transformers.utils import is_torch_available, is_torch_fx_available
|
from transformers.utils import is_torch_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||||
@@ -36,9 +33,6 @@ if is_torch_available():
|
|||||||
from transformers import DonutSwinModel
|
from transformers import DonutSwinModel
|
||||||
from transformers.models.donut.modeling_donut_swin import DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.donut.modeling_donut_swin import DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
if is_torch_fx_available():
|
|
||||||
from transformers.utils.fx import symbolic_trace
|
|
||||||
|
|
||||||
|
|
||||||
class DonutSwinModelTester:
|
class DonutSwinModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -369,96 +363,3 @@ class DonutSwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
[0.0, 1.0],
|
[0.0, 1.0],
|
||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
|
||||||
if not is_torch_fx_available() or not self.fx_compatible:
|
|
||||||
return
|
|
||||||
|
|
||||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
|
||||||
configs_no_init.return_dict = False
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
model = model_class(config=configs_no_init)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
|
||||||
labels = inputs.get("labels", None)
|
|
||||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
|
||||||
if labels is not None:
|
|
||||||
input_names.append("labels")
|
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
|
||||||
input_names = list(filtered_inputs.keys())
|
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
|
||||||
else:
|
|
||||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
|
|
||||||
|
|
||||||
labels = inputs.get("labels", None)
|
|
||||||
start_positions = inputs.get("start_positions", None)
|
|
||||||
end_positions = inputs.get("end_positions", None)
|
|
||||||
if labels is not None:
|
|
||||||
input_names.append("labels")
|
|
||||||
if start_positions is not None:
|
|
||||||
input_names.append("start_positions")
|
|
||||||
if end_positions is not None:
|
|
||||||
input_names.append("end_positions")
|
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
|
||||||
input_names = list(filtered_inputs.keys())
|
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
self.fail(f"Couldn't trace module: {e}")
|
|
||||||
|
|
||||||
def flatten_output(output):
|
|
||||||
flatten = []
|
|
||||||
for x in output:
|
|
||||||
if isinstance(x, (tuple, list)):
|
|
||||||
flatten += flatten_output(x)
|
|
||||||
elif not isinstance(x, torch.Tensor):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
flatten.append(x)
|
|
||||||
return flatten
|
|
||||||
|
|
||||||
model_output = flatten_output(model_output)
|
|
||||||
traced_output = flatten_output(traced_output)
|
|
||||||
num_outputs = len(model_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], traced_output[i]),
|
|
||||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test that the model can be serialized and restored properly
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
|
||||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
|
||||||
try:
|
|
||||||
with open(pkl_file_name, "wb") as f:
|
|
||||||
pickle.dump(traced_model, f)
|
|
||||||
with open(pkl_file_name, "rb") as f:
|
|
||||||
loaded = pickle.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
|
||||||
|
|
||||||
loaded_output = loaded(**filtered_inputs)
|
|
||||||
loaded_output = flatten_output(loaded_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], loaded_output[i]),
|
|
||||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -31,7 +30,7 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import cached_property, is_torch_fx_available
|
from transformers.utils import cached_property
|
||||||
|
|
||||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -44,9 +43,6 @@ if is_torch_available():
|
|||||||
from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
|
from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
|
||||||
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
|
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
|
||||||
|
|
||||||
if is_torch_fx_available():
|
|
||||||
from transformers.utils.fx import symbolic_trace
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_speech_to_text_inputs_dict(
|
def prepare_speech_to_text_inputs_dict(
|
||||||
config,
|
config,
|
||||||
@@ -720,105 +716,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
|||||||
|
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
|
||||||
if not is_torch_fx_available() or not self.fx_compatible:
|
|
||||||
return
|
|
||||||
|
|
||||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
|
||||||
configs_no_init.return_dict = False
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
model = model_class(config=configs_no_init)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
|
||||||
labels = inputs.get("labels", None)
|
|
||||||
input_names = [
|
|
||||||
"input_ids",
|
|
||||||
"attention_mask",
|
|
||||||
"decoder_input_ids",
|
|
||||||
"decoder_attention_mask",
|
|
||||||
"input_features",
|
|
||||||
]
|
|
||||||
if labels is not None:
|
|
||||||
input_names.append("labels")
|
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
|
||||||
input_names = list(filtered_inputs.keys())
|
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
|
||||||
else:
|
|
||||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"]
|
|
||||||
|
|
||||||
labels = inputs.get("labels", None)
|
|
||||||
start_positions = inputs.get("start_positions", None)
|
|
||||||
end_positions = inputs.get("end_positions", None)
|
|
||||||
if labels is not None:
|
|
||||||
input_names.append("labels")
|
|
||||||
if start_positions is not None:
|
|
||||||
input_names.append("start_positions")
|
|
||||||
if end_positions is not None:
|
|
||||||
input_names.append("end_positions")
|
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
|
||||||
input_names = list(filtered_inputs.keys())
|
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
self.fail(f"Couldn't trace module: {e}")
|
|
||||||
|
|
||||||
def flatten_output(output):
|
|
||||||
flatten = []
|
|
||||||
for x in output:
|
|
||||||
if isinstance(x, (tuple, list)):
|
|
||||||
flatten += flatten_output(x)
|
|
||||||
elif not isinstance(x, torch.Tensor):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
flatten.append(x)
|
|
||||||
return flatten
|
|
||||||
|
|
||||||
model_output = flatten_output(model_output)
|
|
||||||
traced_output = flatten_output(traced_output)
|
|
||||||
num_outputs = len(model_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], traced_output[i]),
|
|
||||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test that the model can be serialized and restored properly
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
|
||||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
|
||||||
try:
|
|
||||||
with open(pkl_file_name, "wb") as f:
|
|
||||||
pickle.dump(traced_model, f)
|
|
||||||
with open(pkl_file_name, "rb") as f:
|
|
||||||
loaded = pickle.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
|
||||||
|
|
||||||
loaded_output = loaded(**filtered_inputs)
|
|
||||||
loaded_output = flatten_output(loaded_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], loaded_output[i]),
|
|
||||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
|
|||||||
@@ -16,14 +16,11 @@
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
|
||||||
import pickle
|
|
||||||
import tempfile
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import SwinConfig
|
from transformers import SwinConfig
|
||||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||||
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
|
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||||
@@ -41,9 +38,6 @@ if is_vision_available():
|
|||||||
|
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
if is_torch_fx_available():
|
|
||||||
from transformers.utils.fx import symbolic_trace
|
|
||||||
|
|
||||||
|
|
||||||
class SwinModelTester:
|
class SwinModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -428,99 +422,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
|
||||||
if not is_torch_fx_available() or not self.fx_compatible:
|
|
||||||
return
|
|
||||||
|
|
||||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
|
||||||
configs_no_init.return_dict = False
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
model = model_class(config=configs_no_init)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
|
||||||
labels = inputs.get("labels", None)
|
|
||||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
|
||||||
if labels is not None:
|
|
||||||
input_names.append("labels")
|
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
|
||||||
input_names = list(filtered_inputs.keys())
|
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
|
||||||
else:
|
|
||||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
|
|
||||||
|
|
||||||
labels = inputs.get("labels", None)
|
|
||||||
start_positions = inputs.get("start_positions", None)
|
|
||||||
end_positions = inputs.get("end_positions", None)
|
|
||||||
if labels is not None:
|
|
||||||
input_names.append("labels")
|
|
||||||
if start_positions is not None:
|
|
||||||
input_names.append("start_positions")
|
|
||||||
if end_positions is not None:
|
|
||||||
input_names.append("end_positions")
|
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
|
||||||
input_names = list(filtered_inputs.keys())
|
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
self.fail(f"Couldn't trace module: {e}")
|
|
||||||
|
|
||||||
def flatten_output(output):
|
|
||||||
flatten = []
|
|
||||||
for x in output:
|
|
||||||
if isinstance(x, (tuple, list)):
|
|
||||||
flatten += flatten_output(x)
|
|
||||||
elif not isinstance(x, torch.Tensor):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
flatten.append(x)
|
|
||||||
return flatten
|
|
||||||
|
|
||||||
model_output = flatten_output(model_output)
|
|
||||||
traced_output = flatten_output(traced_output)
|
|
||||||
num_outputs = len(model_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], traced_output[i]),
|
|
||||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test that the model can be serialized and restored properly
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
|
||||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
|
||||||
try:
|
|
||||||
with open(pkl_file_name, "wb") as f:
|
|
||||||
pickle.dump(traced_model, f)
|
|
||||||
with open(pkl_file_name, "rb") as f:
|
|
||||||
loaded = pickle.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
|
||||||
|
|
||||||
loaded_output = loaded(**filtered_inputs)
|
|
||||||
loaded_output = flatten_output(loaded_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], loaded_output[i]),
|
|
||||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -15,24 +15,14 @@
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
import pickle
|
|
||||||
import tempfile
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import XGLMConfig, is_torch_available
|
from transformers import XGLMConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
from transformers.utils import is_torch_fx_available
|
|
||||||
|
|
||||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import (
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||||
ModelTesterMixin,
|
|
||||||
_config_zero_init,
|
|
||||||
floats_tensor,
|
|
||||||
ids_tensor,
|
|
||||||
random_attention_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -40,9 +30,6 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
|
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
|
||||||
|
|
||||||
if is_torch_fx_available():
|
|
||||||
from transformers.utils.fx import symbolic_trace
|
|
||||||
|
|
||||||
|
|
||||||
class XGLMModelTester:
|
class XGLMModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -350,112 +337,6 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
|
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
|
||||||
|
|
||||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
|
||||||
if not is_torch_fx_available() or not self.fx_compatible:
|
|
||||||
return
|
|
||||||
|
|
||||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
|
||||||
configs_no_init.return_dict = False
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
model = model_class(config=configs_no_init)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
|
||||||
labels = inputs.get("labels", None)
|
|
||||||
input_names = [
|
|
||||||
"input_ids",
|
|
||||||
"attention_mask",
|
|
||||||
"decoder_input_ids",
|
|
||||||
"decoder_attention_mask",
|
|
||||||
"input_features",
|
|
||||||
]
|
|
||||||
if labels is not None:
|
|
||||||
input_names.append("labels")
|
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
|
||||||
input_names = list(filtered_inputs.keys())
|
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
|
||||||
else:
|
|
||||||
input_names = [
|
|
||||||
"input_ids",
|
|
||||||
"attention_mask",
|
|
||||||
"token_type_ids",
|
|
||||||
"pixel_values",
|
|
||||||
"bbox",
|
|
||||||
"input_features",
|
|
||||||
]
|
|
||||||
|
|
||||||
labels = inputs.get("labels", None)
|
|
||||||
start_positions = inputs.get("start_positions", None)
|
|
||||||
end_positions = inputs.get("end_positions", None)
|
|
||||||
if labels is not None:
|
|
||||||
input_names.append("labels")
|
|
||||||
if start_positions is not None:
|
|
||||||
input_names.append("start_positions")
|
|
||||||
if end_positions is not None:
|
|
||||||
input_names.append("end_positions")
|
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
|
||||||
input_names = list(filtered_inputs.keys())
|
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
self.fail(f"Couldn't trace module: {e}")
|
|
||||||
|
|
||||||
def flatten_output(output):
|
|
||||||
flatten = []
|
|
||||||
for x in output:
|
|
||||||
if isinstance(x, (tuple, list)):
|
|
||||||
flatten += flatten_output(x)
|
|
||||||
elif not isinstance(x, torch.Tensor):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
flatten.append(x)
|
|
||||||
return flatten
|
|
||||||
|
|
||||||
model_output = flatten_output(model_output)
|
|
||||||
traced_output = flatten_output(traced_output)
|
|
||||||
num_outputs = len(model_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], traced_output[i]),
|
|
||||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test that the model can be serialized and restored properly
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
|
||||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
|
||||||
try:
|
|
||||||
with open(pkl_file_name, "wb") as f:
|
|
||||||
pickle.dump(traced_model, f)
|
|
||||||
with open(pkl_file_name, "rb") as f:
|
|
||||||
loaded = pickle.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
|
||||||
|
|
||||||
loaded_output = loaded(**filtered_inputs)
|
|
||||||
loaded_output = flatten_output(loaded_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], loaded_output[i]),
|
|
||||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_batch_generation(self):
|
def test_batch_generation(self):
|
||||||
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
||||||
|
|||||||
Reference in New Issue
Block a user