diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 4f202800b7..509d7250b7 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -187,7 +187,7 @@ class DecisionTransformerGPT2Attention(nn.Module): if not self.is_cross_attention: # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) if attention_mask is not None: diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 5ef86541f2..a93c08345a 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -211,7 +211,7 @@ class MultiHeadSelfAttention(nn.Module): q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) - scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length) + scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length) weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length) weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 5e981bf9f2..776a69230b 100644 --- 
a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -198,7 +198,7 @@ class GPT2Attention(nn.Module): if not self.is_cross_attention: # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) if attention_mask is not None: @@ -1410,7 +1410,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel): "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 8d4dcd9a7c..2ee4a0df8e 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -147,8 +147,8 @@ class GPTNeoSelfAttention(nn.Module): self.register_buffer("bias", bias) self.register_buffer("masked_bias", torch.tensor(-1e9)) - self.attn_dropout = nn.Dropout(config.attention_dropout) - self.resid_dropout = nn.Dropout(config.resid_dropout) + self.attn_dropout = nn.Dropout(float(config.attention_dropout)) + self.resid_dropout = nn.Dropout(float(config.resid_dropout)) self.embed_dim = config.hidden_size self.num_heads = config.num_heads @@ -188,7 +188,7 @@ class GPTNeoSelfAttention(nn.Module): attn_weights = torch.matmul(query, key.transpose(-1, -2)) query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + causal_mask 
= self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) if attention_mask is not None: @@ -290,7 +290,7 @@ class GPTNeoMLP(nn.Module): self.c_fc = nn.Linear(embed_dim, intermediate_size) self.c_proj = nn.Linear(intermediate_size, embed_dim) self.act = ACT2FN[config.activation_function] - self.dropout = nn.Dropout(config.resid_dropout) + self.dropout = nn.Dropout(float(config.resid_dropout)) def forward(self, hidden_states): hidden_states = self.c_fc(hidden_states) @@ -475,7 +475,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): self.embed_dim = config.hidden_size self.wte = nn.Embedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.drop = nn.Dropout(config.embed_dropout) + self.drop = nn.Dropout(float(config.embed_dropout)) self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -887,7 +887,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel): "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index db58113d96..f57e3e74ec 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None): def rotate_every_two(x): x1 = x[:, :, :, ::2] x2 = x[:, :, :, 1::2] - x = torch.stack((-x2, x1), axis=-1) + x = torch.stack((-x2, x1), dim=-1) return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... 
(d j)') @@ -163,7 +163,7 @@ class GPTJAttention(nn.Module): # compute causal mask from causal mask buffer query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) # Keep the attention weights computation in fp32 to avoid overflow issues query = query.to(torch.float32) @@ -971,7 +971,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel): "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 4e4b0d963b..6bc306a6e0 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -226,9 +226,9 @@ class MobileBertEmbeddings(nn.Module): # dimensional output. 
inputs_embeds = torch.cat( [ - nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0), + nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0), inputs_embeds, - nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0), + nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0), ], dim=2, ) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 83fbee36c3..9516253789 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -18,6 +18,7 @@ import collections import functools import inspect import math +import operator import random import warnings from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union @@ -26,6 +27,7 @@ import torch from packaging import version from torch import nn from torch.fx import Graph, GraphModule, Proxy, Tracer +from torch.fx.proxy import ParameterProxy from .. import ( CONFIG_MAPPING, @@ -126,45 +128,45 @@ _SUPPORTED_MODELS = tuple( ) -def embedding_override(self, input): +def torch_nn_embedding(self, input): return torch.empty(*input.shape, self.weight.shape[-1], device="meta") -def torch_nn_layernorm_override(self, input): +def torch_nn_layernorm(self, input): return input -def torch_nn_linear_override(self, input): +def torch_nn_linear(self, input): return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") -def torch_relu_override(x): +def torch_relu(x): return x -def torch_nn_relu_override(self, x): +def torch_nn_relu(self, x): return x -def torch_nn_functional_relu_override(x, inplace=False): +def torch_nn_functional_relu(x, inplace=False): - if not inplace: + if inplace: raise ValueError("Don't support in-place functional.relu for MetaTensor analysis") return x -def torch_where_override(condition, x, y): +def torch_where(condition, x, y): # torch.where returns the broadcasted tensor of condition, x, and y, # so hack it by using addition return condition.to(device="meta") + x.to(device="meta") + 
y.to(device="meta") -def torch_abs_override(input, *, out=None): - if out is None: +def torch_abs(input, *, out=None): + if out is not None: raise ValueError("Don't support in-place abs for MetaTensor analysis") return input -def torch_arange_override(*args, **kwargs): +def torch_arange(*args, **kwargs): n = len(args) step = 1 if n == 1: @@ -179,7 +181,7 @@ def torch_arange_override(*args, **kwargs): return torch.empty((end - start) // step, dtype=dtype, device="meta") -def torch_cat_override(tensors, dim=None, axis=None, *, out=None): +def torch_cat(tensors, dim=None, axis=None, *, out=None): if dim is None and axis is None: dim = 0 if dim is None and axis is not None: @@ -193,7 +195,7 @@ def torch_cat_override(tensors, dim=None, axis=None, *, out=None): return torch.empty(final_shape, device="meta") -def torch_stack_override(tensors, dim=None, axis=None, *, out=None): +def torch_stack(tensors, dim=None, axis=None, *, out=None): if dim is None and axis is None: dim = 0 if dim is None and axis is not None: @@ -205,7 +207,7 @@ def torch_stack_override(tensors, dim=None, axis=None, *, out=None): return torch.empty(shape, device="meta") -def torch_add_override(input, other, *, alpha=1, out=None): +def torch_add(input, other, *, alpha=1, out=None): if not isinstance(input, torch.Tensor): return torch.empty_like(other, device="meta") if not isinstance(other, torch.Tensor): @@ -219,15 +221,15 @@ def torch_add_override(input, other, *, alpha=1, out=None): return torch.empty(shape, device="meta") -def torch_mul_override(input, other, *, out=None): - return torch_add_override(input, other, out=out) +def torch_mul(input, other, *, out=None): + return torch_add(input, other, out=out) -def torch_tensor_mul_override(self, other): - return torch_mul_override(self, other) +def torch_tensor_mul(self, other): + return torch_mul(self, other) -def torch_matmul_override(input, other, *, out=None): +def torch_matmul(input, other, *, out=None): d1 = input.dim() d2 = other.dim() shape = 
None @@ -263,7 +265,13 @@ def torch_matmul_override(input, other, *, out=None): return torch.empty(*shape, device="meta") -def torch_tensor_repeat_override(self, *sizes): +def torch_einsum(equation, *operands): + # TODO: infer shape without performing the computation, this might be quite hard. + concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands) + return torch.einsum(equation, *concrete_operands).to("meta") + + +def torch_tensor_repeat(self, *sizes): shape = list(self.shape) for i, x in enumerate(sizes): shape[i] *= x @@ -305,6 +313,18 @@ def torch_nn_conv2d(self, input): return torch.empty(shape, device="meta") +def torch_unsqueeze(input, dim): + shape = list(input.shape) + if dim < 0: + dim = input.dim() + 1 + dim + shape.insert(dim, 1) + return torch.empty(shape, device="meta") + + +def torch_tensor_unsqueeze(self, dim): + return torch_unsqueeze(self, dim) + + def torch_nn_mseloss(self, input, target): if self.reduction == "none": shape = target.shape @@ -329,31 +349,42 @@ def torch_nn_bcewithlogitsloss(self, input, target): return torch.empty(shape, device="meta") +def operator_getitem(a, b): + if isinstance(a, torch.Tensor): + # TODO: infer shape without performing the computation. 
+ return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta") + return operator.getitem(a, b) + + _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { - torch.nn.Embedding: embedding_override, - torch.nn.LayerNorm: torch_nn_layernorm_override, - torch.nn.Linear: torch_nn_linear_override, - torch.relu: torch_relu_override, - torch.nn.functional.relu: torch_nn_functional_relu_override, - torch.nn.ReLU: torch_nn_relu_override, - torch.where: torch_where_override, - torch.abs: torch_abs_override, - torch.arange: torch_arange_override, - torch.cat: torch_cat_override, - torch.stack: torch_stack_override, - torch.add: torch_add_override, - torch.mul: torch_mul_override, - torch.Tensor.mul: torch_tensor_mul_override, - torch.matmul: torch_matmul_override, - torch.Tensor.repeat: torch_tensor_repeat_override, + torch.nn.Embedding: torch_nn_embedding, + torch.nn.LayerNorm: torch_nn_layernorm, + torch.nn.Linear: torch_nn_linear, + torch.relu: torch_relu, + torch.nn.functional.relu: torch_nn_functional_relu, + torch.nn.ReLU: torch_nn_relu, + torch.where: torch_where, + torch.abs: torch_abs, + torch.arange: torch_arange, + torch.cat: torch_cat, + torch.stack: torch_stack, + torch.add: torch_add, + torch.mul: torch_mul, + torch.Tensor.mul: torch_tensor_mul, + torch.matmul: torch_matmul, + torch.einsum: torch_einsum, + torch.Tensor.repeat: torch_tensor_repeat, torch.roll: torch_roll, # TODO: those might not be needed. 
# torch.index_select: torch_index_select, # torch.Tensor.index_select: torch_tensor_index_select, torch.nn.Conv2d: torch_nn_conv2d, + torch.unsqueeze: torch_unsqueeze, + torch.Tensor.unsqueeze: torch_tensor_unsqueeze, torch.nn.MSELoss: torch_nn_mseloss, torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss, torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss, + operator.getitem: operator_getitem, } @@ -371,7 +402,6 @@ class HFProxy(Proxy): @property def dtype(self): - return self.tracer.root.dtype if hasattr(self, "_metadata") and self._metadata is not None: return self._metadata.dtype return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {}) @@ -400,7 +430,7 @@ class HFProxy(Proxy): return HFAttribute(self, k) def __setitem__(self, indices, values): - return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {}) + return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {}) def __contains__(self, key): # To handle cases such as : @@ -480,14 +510,14 @@ class HFTracer(Tracer): regular PyTorch torch.fx.Proxy. 
""" + # Feature flag for proxying accesses to buffer values + proxy_buffer_attributes: bool = True allow_insert_stateless_mods: bool = True _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"] - def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False): + def __init__(self, autowrap_modules=(math,), autowrap_functions=()): - super().__init__( - autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching - ) + super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions) if not is_torch_fx_available(): torch_version = version.parse(importlib_metadata.version("torch")) @@ -500,7 +530,9 @@ class HFTracer(Tracer): self, model: PreTrainedModel, input_name: str, shape: List[int] ) -> Dict[str, torch.Tensor]: """Generates dummy input for model inference recording.""" - model_class = model.__class__ + # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored + # from pickle, or from the "__class__" attribute in the general case. 
+ model_class = getattr(model, "class_for_deserialization", model.__class__) device = model.device inputs_dict = {} @@ -641,7 +673,38 @@ class HFTracer(Tracer): if getattr(self, "_disable_module_getattr", False): return attr_val else: - return super()._module_getattr(attr, attr_val, parameter_proxy_cache) + # return super()._module_getattr(attr, attr_val, parameter_proxy_cache) + def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + return attr_val def call_module(self, m, forward, args, kwargs): self.orig_forward = forward @@ -693,17 +756,29 @@ class HFTracer(Tracer): for name, (_, orig) in self.patched_torch_methods.items(): setattr(torch, name, orig) - # TODO: keep this until necessary. # This is necessary because concrete args are added as input to the traced module since # https://github.com/pytorch/pytorch/pull/55888. - # A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet. 
for node in self.graph.nodes: if node.op == "placeholder": # Removing default values for inputs as the forward pass will fail with them. if node.target in input_names: node.args = () + # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor]. + # It cannot infer on the attributes and methods the input should have, and fails. + node.type = torch.Tensor # It is a concrete arg so it is not used and should be removed. else: + if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): + # Newer versions of torch.fx emit an assert statement + # for concrete arguments; delete those before we delete + # the concrete arg. + to_delete = [] + for user in node.users: + if user.target == torch.fx._symbolic_trace._assert_is_none: + to_delete.append(user) + for user in to_delete: + self.graph.erase_node(user) + self.graph.erase_node(node) # TODO: solves GraphModule creation. @@ -809,4 +884,10 @@ def symbolic_trace( traced_graph = tracer.trace(model, concrete_args=concrete_args) traced = torch.fx.GraphModule(model, traced_graph) + traced.config = model.config + # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus + # _generate_dummy_input, where the model class is needed. 
+ traced.class_for_deserialization = model.__class__ + traced.device = model.device + return traced diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index ce1d43cc78..1c6fda55f5 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -325,7 +325,7 @@ torch_version = None _torch_fx_available = _torch_onnx_dict_inputs_support_available = False if _torch_available: torch_version = version.parse(importlib_metadata.version("torch")) - _torch_fx_available = (torch_version.major, torch_version.minor) == ( + _torch_fx_available = (torch_version.major, torch_version.minor) >= ( TORCH_FX_REQUIRED_VERSION.major, TORCH_FX_REQUIRED_VERSION.minor, ) diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index 367c118c05..47f219d482 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -16,11 +16,14 @@ import copy import inspect +import os +import pickle +import tempfile import unittest from transformers import SwinConfig from transformers.testing_utils import require_torch, require_vision, slow, torch_device -from transformers.utils import cached_property, is_torch_available, is_vision_available +from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -38,6 +41,9 @@ if is_vision_available(): from transformers import AutoFeatureExtractor +if is_torch_fx_available(): + from transformers.utils.fx import symbolic_trace + def _config_zero_init(config): configs_no_init = copy.deepcopy(config) @@ -381,6 +387,97 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, 
output_loss=False): + if not is_torch_fx_available() or not self.fx_compatible: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.return_dict = False + + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) + + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] + if labels is not None: + input_names.append("labels") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + else: + input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"] + + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = filtered_inputs.keys() + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + + except RuntimeError as e: + self.fail(f"Couldn't trace module: {e}") + + def flatten_output(output): + flatten = [] + for x in output: + if isinstance(x, (tuple, list)): + flatten += flatten_output(x) + elif not isinstance(x, torch.Tensor): + continue + else: + flatten.append(x) 
+ return flatten + + model_output = flatten_output(model_output) + traced_output = flatten_output(traced_output) + num_outputs = len(model_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], traced_output[i]), + f"traced {i}th output doesn't match model {i}th output for {model_class}", + ) + + # Test that the model can be serialized and restored properly + with tempfile.TemporaryDirectory() as tmp_dir_name: + pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") + try: + with open(pkl_file_name, "wb") as f: + pickle.dump(traced_model, f) + with open(pkl_file_name, "rb") as f: + loaded = pickle.load(f) + except Exception as e: + self.fail(f"Couldn't serialize / deserialize the traced model: {e}") + + loaded_output = loaded(**filtered_inputs) + loaded_output = flatten_output(loaded_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], loaded_output[i]), + f"serialized model {i}th output doesn't match model {i}th output for {model_class}", + ) + @require_vision @require_torch diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3c01286c6b..a7d2bd9f2b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -19,6 +19,7 @@ import inspect import json import os import os.path +import pickle import random import sys import tempfile @@ -758,8 +759,8 @@ class ModelTesterMixin: traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) - except RuntimeError: - self.fail("Couldn't trace module.") + except RuntimeError as e: + self.fail(f"Couldn't trace module: {e}") def flatten_output(output): flatten = [] @@ -782,6 +783,40 @@ class ModelTesterMixin: f"traced {i}th output doesn't match model {i}th output for {model_class}", ) + # Test that the model can be TorchScripted + try: + scripted = torch.jit.script(traced_model) + except Exception as e: + self.fail(f"Could not TorchScript the traced model: {e}") 
+ scripted_output = scripted(**filtered_inputs) + scripted_output = flatten_output(scripted_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], scripted_output[i]), + f"scripted {i}th output doesn't match model {i}th output for {model_class}", + ) + + # Test that the model can be serialized and restored properly + with tempfile.TemporaryDirectory() as tmp_dir_name: + pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") + try: + with open(pkl_file_name, "wb") as f: + pickle.dump(traced_model, f) + with open(pkl_file_name, "rb") as f: + loaded = pickle.load(f) + except Exception as e: + self.fail(f"Couldn't serialize / deserialize the traced model: {e}") + + loaded_output = loaded(**filtered_inputs) + loaded_output = flatten_output(loaded_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], loaded_output[i]), + f"serialized model {i}th output doesn't match model {i}th output for {model_class}", + ) + def test_headmasking(self): if not self.test_head_masking: return