Traced models serialization and torchscripting fix (#17206)
* Fix torch.jit.script and pickling issues * Fix get_attr issues * Fix import in function * Fix GPT-J and T5 tracing for torch=1.11 * Gate graph surgery on torch version * Modeling minor changes to enable TorchScripting * Model serialization / deserialization test * Remove _assert_is_none users
This commit is contained in:
@@ -187,7 +187,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
||||
if not self.is_cross_attention:
|
||||
# if only "normal" attention layer implements causal mask
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||
|
||||
if attention_mask is not None:
|
||||
|
||||
@@ -211,7 +211,7 @@ class MultiHeadSelfAttention(nn.Module):
|
||||
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
|
||||
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
|
||||
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
|
||||
scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length)
|
||||
scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length)
|
||||
|
||||
weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
|
||||
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
|
||||
|
||||
@@ -198,7 +198,7 @@ class GPT2Attention(nn.Module):
|
||||
if not self.is_cross_attention:
|
||||
# if only "normal" attention layer implements causal mask
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||
|
||||
if attention_mask is not None:
|
||||
@@ -1410,7 +1410,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
||||
@@ -147,8 +147,8 @@ class GPTNeoSelfAttention(nn.Module):
|
||||
self.register_buffer("bias", bias)
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e9))
|
||||
|
||||
self.attn_dropout = nn.Dropout(config.attention_dropout)
|
||||
self.resid_dropout = nn.Dropout(config.resid_dropout)
|
||||
self.attn_dropout = nn.Dropout(float(config.attention_dropout))
|
||||
self.resid_dropout = nn.Dropout(float(config.resid_dropout))
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
@@ -188,7 +188,7 @@ class GPTNeoSelfAttention(nn.Module):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||
|
||||
if attention_mask is not None:
|
||||
@@ -290,7 +290,7 @@ class GPTNeoMLP(nn.Module):
|
||||
self.c_fc = nn.Linear(embed_dim, intermediate_size)
|
||||
self.c_proj = nn.Linear(intermediate_size, embed_dim)
|
||||
self.act = ACT2FN[config.activation_function]
|
||||
self.dropout = nn.Dropout(config.resid_dropout)
|
||||
self.dropout = nn.Dropout(float(config.resid_dropout))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
@@ -475,7 +475,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
||||
self.embed_dim = config.hidden_size
|
||||
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.drop = nn.Dropout(config.embed_dropout)
|
||||
self.drop = nn.Dropout(float(config.embed_dropout))
|
||||
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
@@ -887,7 +887,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
||||
@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
|
||||
def rotate_every_two(x):
|
||||
x1 = x[:, :, :, ::2]
|
||||
x2 = x[:, :, :, 1::2]
|
||||
x = torch.stack((-x2, x1), axis=-1)
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
|
||||
|
||||
|
||||
@@ -163,7 +163,7 @@ class GPTJAttention(nn.Module):
|
||||
|
||||
# compute causal mask from causal mask buffer
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||
|
||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||
query = query.to(torch.float32)
|
||||
@@ -971,7 +971,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
||||
@@ -226,9 +226,9 @@ class MobileBertEmbeddings(nn.Module):
|
||||
# dimensional output.
|
||||
inputs_embeds = torch.cat(
|
||||
[
|
||||
nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0),
|
||||
nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
|
||||
inputs_embeds,
|
||||
nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0),
|
||||
nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ import collections
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
import operator
|
||||
import random
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||
@@ -26,6 +27,7 @@ import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
||||
from torch.fx.proxy import ParameterProxy
|
||||
|
||||
from .. import (
|
||||
CONFIG_MAPPING,
|
||||
@@ -126,45 +128,45 @@ _SUPPORTED_MODELS = tuple(
|
||||
)
|
||||
|
||||
|
||||
def embedding_override(self, input):
|
||||
def torch_nn_embedding(self, input):
|
||||
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
|
||||
|
||||
|
||||
def torch_nn_layernorm_override(self, input):
|
||||
def torch_nn_layernorm(self, input):
|
||||
return input
|
||||
|
||||
|
||||
def torch_nn_linear_override(self, input):
|
||||
def torch_nn_linear(self, input):
|
||||
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
|
||||
|
||||
|
||||
def torch_relu_override(x):
|
||||
def torch_relu(x):
|
||||
return x
|
||||
|
||||
|
||||
def torch_nn_relu_override(self, x):
|
||||
def torch_nn_relu(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def torch_nn_functional_relu_override(x, inplace=False):
|
||||
def torch_nn_functional_relu(x, inplace=False):
|
||||
if not inplace:
|
||||
raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
|
||||
return x
|
||||
|
||||
|
||||
def torch_where_override(condition, x, y):
|
||||
def torch_where(condition, x, y):
|
||||
# torch.where returns the broadcasted tensor of condition, x, and y,
|
||||
# so hack it by using addition
|
||||
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
|
||||
|
||||
|
||||
def torch_abs_override(input, *, out=None):
|
||||
if out is None:
|
||||
def torch_abs(input, *, out=None):
|
||||
if out is not None:
|
||||
raise ValueError("Don't support in-place abs for MetaTensor analysis")
|
||||
return input
|
||||
|
||||
|
||||
def torch_arange_override(*args, **kwargs):
|
||||
def torch_arange(*args, **kwargs):
|
||||
n = len(args)
|
||||
step = 1
|
||||
if n == 1:
|
||||
@@ -179,7 +181,7 @@ def torch_arange_override(*args, **kwargs):
|
||||
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
||||
|
||||
|
||||
def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
|
||||
def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
||||
if dim is None and axis is None:
|
||||
dim = 0
|
||||
if dim is None and axis is not None:
|
||||
@@ -193,7 +195,7 @@ def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
|
||||
return torch.empty(final_shape, device="meta")
|
||||
|
||||
|
||||
def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
|
||||
def torch_stack(tensors, dim=None, axis=None, *, out=None):
|
||||
if dim is None and axis is None:
|
||||
dim = 0
|
||||
if dim is None and axis is not None:
|
||||
@@ -205,7 +207,7 @@ def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def torch_add_override(input, other, *, alpha=1, out=None):
|
||||
def torch_add(input, other, *, alpha=1, out=None):
|
||||
if not isinstance(input, torch.Tensor):
|
||||
return torch.empty_like(other, device="meta")
|
||||
if not isinstance(other, torch.Tensor):
|
||||
@@ -219,15 +221,15 @@ def torch_add_override(input, other, *, alpha=1, out=None):
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def torch_mul_override(input, other, *, out=None):
|
||||
return torch_add_override(input, other, out=out)
|
||||
def torch_mul(input, other, *, out=None):
|
||||
return torch_add(input, other, out=out)
|
||||
|
||||
|
||||
def torch_tensor_mul_override(self, other):
|
||||
return torch_mul_override(self, other)
|
||||
def torch_tensor_mul(self, other):
|
||||
return torch_mul(self, other)
|
||||
|
||||
|
||||
def torch_matmul_override(input, other, *, out=None):
|
||||
def torch_matmul(input, other, *, out=None):
|
||||
d1 = input.dim()
|
||||
d2 = other.dim()
|
||||
shape = None
|
||||
@@ -263,7 +265,13 @@ def torch_matmul_override(input, other, *, out=None):
|
||||
return torch.empty(*shape, device="meta")
|
||||
|
||||
|
||||
def torch_tensor_repeat_override(self, *sizes):
|
||||
def torch_einsum(equation, *operands):
|
||||
# TODO: infer shape without performing the computation, this might be quite hard.
|
||||
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
|
||||
return torch.einsum(equation, *concrete_operands).to("meta")
|
||||
|
||||
|
||||
def torch_tensor_repeat(self, *sizes):
|
||||
shape = list(self.shape)
|
||||
for i, x in enumerate(sizes):
|
||||
shape[i] *= x
|
||||
@@ -305,6 +313,18 @@ def torch_nn_conv2d(self, input):
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def torch_unsqueeze(input, dim):
|
||||
shape = list(input.shape)
|
||||
if dim < 0:
|
||||
dim = input.dim() + 1 + dim
|
||||
shape.insert(dim, 1)
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def torch_tensor_unsqueeze(self, dim):
|
||||
return torch_unsqueeze(self, dim)
|
||||
|
||||
|
||||
def torch_nn_mseloss(self, input, target):
|
||||
if self.reduction == "none":
|
||||
shape = target.shape
|
||||
@@ -329,31 +349,42 @@ def torch_nn_bcewithlogitsloss(self, input, target):
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def operator_getitem(a, b):
|
||||
if isinstance(a, torch.Tensor):
|
||||
# TODO: infer shape without performing the computation.
|
||||
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
|
||||
return operator.getitem(a, b)
|
||||
|
||||
|
||||
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||
torch.nn.Embedding: embedding_override,
|
||||
torch.nn.LayerNorm: torch_nn_layernorm_override,
|
||||
torch.nn.Linear: torch_nn_linear_override,
|
||||
torch.relu: torch_relu_override,
|
||||
torch.nn.functional.relu: torch_nn_functional_relu_override,
|
||||
torch.nn.ReLU: torch_nn_relu_override,
|
||||
torch.where: torch_where_override,
|
||||
torch.abs: torch_abs_override,
|
||||
torch.arange: torch_arange_override,
|
||||
torch.cat: torch_cat_override,
|
||||
torch.stack: torch_stack_override,
|
||||
torch.add: torch_add_override,
|
||||
torch.mul: torch_mul_override,
|
||||
torch.Tensor.mul: torch_tensor_mul_override,
|
||||
torch.matmul: torch_matmul_override,
|
||||
torch.Tensor.repeat: torch_tensor_repeat_override,
|
||||
torch.nn.Embedding: torch_nn_embedding,
|
||||
torch.nn.LayerNorm: torch_nn_layernorm,
|
||||
torch.nn.Linear: torch_nn_linear,
|
||||
torch.relu: torch_relu,
|
||||
torch.nn.functional.relu: torch_nn_functional_relu,
|
||||
torch.nn.ReLU: torch_nn_relu,
|
||||
torch.where: torch_where,
|
||||
torch.abs: torch_abs,
|
||||
torch.arange: torch_arange,
|
||||
torch.cat: torch_cat,
|
||||
torch.stack: torch_stack,
|
||||
torch.add: torch_add,
|
||||
torch.mul: torch_mul,
|
||||
torch.Tensor.mul: torch_tensor_mul,
|
||||
torch.matmul: torch_matmul,
|
||||
torch.einsum: torch_einsum,
|
||||
torch.Tensor.repeat: torch_tensor_repeat,
|
||||
torch.roll: torch_roll,
|
||||
# TODO: those might not be needed.
|
||||
# torch.index_select: torch_index_select,
|
||||
# torch.Tensor.index_select: torch_tensor_index_select,
|
||||
torch.nn.Conv2d: torch_nn_conv2d,
|
||||
torch.unsqueeze: torch_unsqueeze,
|
||||
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
|
||||
torch.nn.MSELoss: torch_nn_mseloss,
|
||||
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
|
||||
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
|
||||
operator.getitem: operator_getitem,
|
||||
}
|
||||
|
||||
|
||||
@@ -371,7 +402,6 @@ class HFProxy(Proxy):
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tracer.root.dtype
|
||||
if hasattr(self, "_metadata") and self._metadata is not None:
|
||||
return self._metadata.dtype
|
||||
return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
|
||||
@@ -400,7 +430,7 @@ class HFProxy(Proxy):
|
||||
return HFAttribute(self, k)
|
||||
|
||||
def __setitem__(self, indices, values):
|
||||
return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
|
||||
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
||||
|
||||
def __contains__(self, key):
|
||||
# To handle cases such as :
|
||||
@@ -480,14 +510,14 @@ class HFTracer(Tracer):
|
||||
regular PyTorch torch.fx.Proxy.
|
||||
"""
|
||||
|
||||
# Feature flag for proxying accesses to buffer values
|
||||
proxy_buffer_attributes: bool = True
|
||||
allow_insert_stateless_mods: bool = True
|
||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
|
||||
|
||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
|
||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||
|
||||
super().__init__(
|
||||
autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
|
||||
)
|
||||
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
|
||||
|
||||
if not is_torch_fx_available():
|
||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
||||
@@ -500,7 +530,9 @@ class HFTracer(Tracer):
|
||||
self, model: PreTrainedModel, input_name: str, shape: List[int]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Generates dummy input for model inference recording."""
|
||||
model_class = model.__class__
|
||||
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
||||
# from pickle, or from the "__class__" attribute in the general case.
|
||||
model_class = getattr(model, "class_for_deserialization", model.__class__)
|
||||
device = model.device
|
||||
inputs_dict = {}
|
||||
|
||||
@@ -641,7 +673,38 @@ class HFTracer(Tracer):
|
||||
if getattr(self, "_disable_module_getattr", False):
|
||||
return attr_val
|
||||
else:
|
||||
return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
|
||||
for n, p in collection_to_search:
|
||||
if attr_val is p:
|
||||
if n not in parameter_proxy_cache:
|
||||
kwargs = {}
|
||||
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
|
||||
kwargs["proxy_factory_fn"] = (
|
||||
None
|
||||
if not self.param_shapes_constant
|
||||
else lambda node: ParameterProxy(self, node, n, attr_val)
|
||||
)
|
||||
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
|
||||
parameter_proxy_cache[n] = val_proxy
|
||||
return parameter_proxy_cache[n]
|
||||
return None
|
||||
|
||||
if isinstance(attr_val, torch.nn.Parameter):
|
||||
maybe_parameter_proxy = maybe_get_proxy_for_attr(
|
||||
attr_val, self.root.named_parameters(), parameter_proxy_cache
|
||||
)
|
||||
if maybe_parameter_proxy is not None:
|
||||
return maybe_parameter_proxy
|
||||
|
||||
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
|
||||
maybe_buffer_proxy = maybe_get_proxy_for_attr(
|
||||
attr_val, self.root.named_buffers(), parameter_proxy_cache
|
||||
)
|
||||
if maybe_buffer_proxy is not None:
|
||||
return maybe_buffer_proxy
|
||||
|
||||
return attr_val
|
||||
|
||||
def call_module(self, m, forward, args, kwargs):
|
||||
self.orig_forward = forward
|
||||
@@ -693,17 +756,29 @@ class HFTracer(Tracer):
|
||||
for name, (_, orig) in self.patched_torch_methods.items():
|
||||
setattr(torch, name, orig)
|
||||
|
||||
# TODO: keep this until necessary.
|
||||
# This is necessary because concrete args are added as input to the traced module since
|
||||
# https://github.com/pytorch/pytorch/pull/55888.
|
||||
# A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet.
|
||||
for node in self.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
# Removing default values for inputs as the forward pass will fail with them.
|
||||
if node.target in input_names:
|
||||
node.args = ()
|
||||
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
|
||||
# It cannot infer on the attributes and methods the input should have, and fails.
|
||||
node.type = torch.Tensor
|
||||
# It is a concrete arg so it is not used and should be removed.
|
||||
else:
|
||||
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
|
||||
# Newer versions of torch.fx emit an assert statement
|
||||
# for concrete arguments; delete those before we delete
|
||||
# the concrete arg.
|
||||
to_delete = []
|
||||
for user in node.users:
|
||||
if user.target == torch.fx._symbolic_trace._assert_is_none:
|
||||
to_delete.append(user)
|
||||
for user in to_delete:
|
||||
self.graph.erase_node(user)
|
||||
|
||||
self.graph.erase_node(node)
|
||||
|
||||
# TODO: solves GraphModule creation.
|
||||
@@ -809,4 +884,10 @@ def symbolic_trace(
|
||||
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
||||
traced = torch.fx.GraphModule(model, traced_graph)
|
||||
|
||||
traced.config = model.config
|
||||
# The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
|
||||
# _generate_dummy_input, where the model class is needed.
|
||||
traced.class_for_deserialization = model.__class__
|
||||
traced.device = model.device
|
||||
|
||||
return traced
|
||||
|
||||
@@ -325,7 +325,7 @@ torch_version = None
|
||||
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
||||
if _torch_available:
|
||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
||||
_torch_fx_available = (torch_version.major, torch_version.minor) == (
|
||||
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
|
||||
TORCH_FX_REQUIRED_VERSION.major,
|
||||
TORCH_FX_REQUIRED_VERSION.minor,
|
||||
)
|
||||
|
||||
@@ -16,11 +16,14 @@
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SwinConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
@@ -38,6 +41,9 @@ if is_vision_available():
|
||||
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
if is_torch_fx_available():
|
||||
from transformers.utils.fx import symbolic_trace
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@@ -381,6 +387,97 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||
if not is_torch_fx_available() or not self.fx_compatible:
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.return_dict = False
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
if start_positions is not None:
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = filtered_inputs.keys()
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
for x in output:
|
||||
if isinstance(x, (tuple, list)):
|
||||
flatten += flatten_output(x)
|
||||
elif not isinstance(x, torch.Tensor):
|
||||
continue
|
||||
else:
|
||||
flatten.append(x)
|
||||
return flatten
|
||||
|
||||
model_output = flatten_output(model_output)
|
||||
traced_output = flatten_output(traced_output)
|
||||
num_outputs = len(model_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], traced_output[i]),
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
|
||||
@@ -19,6 +19,7 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
import os.path
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -758,8 +759,8 @@ class ModelTesterMixin:
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
except RuntimeError as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
@@ -782,6 +783,40 @@ class ModelTesterMixin:
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be TorchScripted
|
||||
try:
|
||||
scripted = torch.jit.script(traced_model)
|
||||
except Exception as e:
|
||||
self.fail(f"Could not TorchScript the traced model: {e}")
|
||||
scripted_output = scripted(**filtered_inputs)
|
||||
scripted_output = flatten_output(scripted_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], scripted_output[i]),
|
||||
f"scripted {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user