Traced models serialization and torchscripting fix (#17206)
* Fix torch.jit.script and pickling issues
* Fix get_attr issues
* Fix import in function
* Fix GPT-J and T5 tracing for torch=1.11
* Gate graph surgery on torch version
* Modeling minor changes to enable TorchScripting
* Model serialization / deserialization test
* Remove _assert_is_none users
This commit is contained in:
@@ -187,7 +187,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
|||||||
if not self.is_cross_attention:
|
if not self.is_cross_attention:
|
||||||
# if only "normal" attention layer implements causal mask
|
# if only "normal" attention layer implements causal mask
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||||
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
|
|||||||
@@ -211,7 +211,7 @@ class MultiHeadSelfAttention(nn.Module):
|
|||||||
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
|
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
|
||||||
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
|
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
|
||||||
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
|
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
|
||||||
scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length)
|
scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length)
|
||||||
|
|
||||||
weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
|
weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
|
||||||
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
|
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ class GPT2Attention(nn.Module):
|
|||||||
if not self.is_cross_attention:
|
if not self.is_cross_attention:
|
||||||
# if only "normal" attention layer implements causal mask
|
# if only "normal" attention layer implements causal mask
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||||
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
@@ -1410,7 +1410,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
|
|||||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
)
|
)
|
||||||
|
|
||||||
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -147,8 +147,8 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
self.register_buffer("bias", bias)
|
self.register_buffer("bias", bias)
|
||||||
self.register_buffer("masked_bias", torch.tensor(-1e9))
|
self.register_buffer("masked_bias", torch.tensor(-1e9))
|
||||||
|
|
||||||
self.attn_dropout = nn.Dropout(config.attention_dropout)
|
self.attn_dropout = nn.Dropout(float(config.attention_dropout))
|
||||||
self.resid_dropout = nn.Dropout(config.resid_dropout)
|
self.resid_dropout = nn.Dropout(float(config.resid_dropout))
|
||||||
|
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
self.num_heads = config.num_heads
|
self.num_heads = config.num_heads
|
||||||
@@ -188,7 +188,7 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||||
|
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||||
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
@@ -290,7 +290,7 @@ class GPTNeoMLP(nn.Module):
|
|||||||
self.c_fc = nn.Linear(embed_dim, intermediate_size)
|
self.c_fc = nn.Linear(embed_dim, intermediate_size)
|
||||||
self.c_proj = nn.Linear(intermediate_size, embed_dim)
|
self.c_proj = nn.Linear(intermediate_size, embed_dim)
|
||||||
self.act = ACT2FN[config.activation_function]
|
self.act = ACT2FN[config.activation_function]
|
||||||
self.dropout = nn.Dropout(config.resid_dropout)
|
self.dropout = nn.Dropout(float(config.resid_dropout))
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
hidden_states = self.c_fc(hidden_states)
|
hidden_states = self.c_fc(hidden_states)
|
||||||
@@ -475,7 +475,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
||||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||||
self.drop = nn.Dropout(config.embed_dropout)
|
self.drop = nn.Dropout(float(config.embed_dropout))
|
||||||
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
|
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
@@ -887,7 +887,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
|||||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
)
|
)
|
||||||
|
|
||||||
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
|
|||||||
def rotate_every_two(x):
|
def rotate_every_two(x):
|
||||||
x1 = x[:, :, :, ::2]
|
x1 = x[:, :, :, ::2]
|
||||||
x2 = x[:, :, :, 1::2]
|
x2 = x[:, :, :, 1::2]
|
||||||
x = torch.stack((-x2, x1), axis=-1)
|
x = torch.stack((-x2, x1), dim=-1)
|
||||||
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
|
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
|
||||||
|
|
||||||
|
|
||||||
@@ -163,7 +163,7 @@ class GPTJAttention(nn.Module):
|
|||||||
|
|
||||||
# compute causal mask from causal mask buffer
|
# compute causal mask from causal mask buffer
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||||
|
|
||||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||||
query = query.to(torch.float32)
|
query = query.to(torch.float32)
|
||||||
@@ -971,7 +971,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
|
|||||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
)
|
)
|
||||||
|
|
||||||
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -226,9 +226,9 @@ class MobileBertEmbeddings(nn.Module):
|
|||||||
# dimensional output.
|
# dimensional output.
|
||||||
inputs_embeds = torch.cat(
|
inputs_embeds = torch.cat(
|
||||||
[
|
[
|
||||||
nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0),
|
nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0),
|
nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
|
||||||
],
|
],
|
||||||
dim=2,
|
dim=2,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import collections
|
|||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
|
import operator
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||||
@@ -26,6 +27,7 @@ import torch
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
||||||
|
from torch.fx.proxy import ParameterProxy
|
||||||
|
|
||||||
from .. import (
|
from .. import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
@@ -126,45 +128,45 @@ _SUPPORTED_MODELS = tuple(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def embedding_override(self, input):
|
def torch_nn_embedding(self, input):
|
||||||
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
|
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_layernorm_override(self, input):
|
def torch_nn_layernorm(self, input):
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_linear_override(self, input):
|
def torch_nn_linear(self, input):
|
||||||
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
|
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_relu_override(x):
|
def torch_relu(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_relu_override(self, x):
|
def torch_nn_relu(self, x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_functional_relu(x, inplace=False):
    """Meta override for ``torch.nn.functional.relu``.

    The input is returned unchanged, which is what the other ReLU overrides
    in this file do as well.

    Args:
        x: The value passed to ``relu``.
        inplace: Whether the in-place variant was requested.

    Returns:
        ``x`` itself.

    Raises:
        ValueError: If ``inplace`` is true; the in-place variant is rejected
            for MetaTensor analysis.
    """
    # The previous guard was inverted: it raised on the default
    # (out-of-place) call and silently accepted the in-place one.
    if inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x
|
||||||
|
|
||||||
|
|
||||||
def torch_where_override(condition, x, y):
|
def torch_where(condition, x, y):
|
||||||
# torch.where returns the broadcasted tensor of condition, x, and y,
|
# torch.where returns the broadcasted tensor of condition, x, and y,
|
||||||
# so hack it by using addition
|
# so hack it by using addition
|
||||||
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
|
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_abs_override(input, *, out=None):
|
def torch_abs(input, *, out=None):
|
||||||
if out is None:
|
if out is not None:
|
||||||
raise ValueError("Don't support in-place abs for MetaTensor analysis")
|
raise ValueError("Don't support in-place abs for MetaTensor analysis")
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
def torch_arange_override(*args, **kwargs):
|
def torch_arange(*args, **kwargs):
|
||||||
n = len(args)
|
n = len(args)
|
||||||
step = 1
|
step = 1
|
||||||
if n == 1:
|
if n == 1:
|
||||||
@@ -179,7 +181,7 @@ def torch_arange_override(*args, **kwargs):
|
|||||||
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
|
def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
||||||
if dim is None and axis is None:
|
if dim is None and axis is None:
|
||||||
dim = 0
|
dim = 0
|
||||||
if dim is None and axis is not None:
|
if dim is None and axis is not None:
|
||||||
@@ -193,7 +195,7 @@ def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
|
|||||||
return torch.empty(final_shape, device="meta")
|
return torch.empty(final_shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
|
def torch_stack(tensors, dim=None, axis=None, *, out=None):
|
||||||
if dim is None and axis is None:
|
if dim is None and axis is None:
|
||||||
dim = 0
|
dim = 0
|
||||||
if dim is None and axis is not None:
|
if dim is None and axis is not None:
|
||||||
@@ -205,7 +207,7 @@ def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
|
|||||||
return torch.empty(shape, device="meta")
|
return torch.empty(shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_add_override(input, other, *, alpha=1, out=None):
|
def torch_add(input, other, *, alpha=1, out=None):
|
||||||
if not isinstance(input, torch.Tensor):
|
if not isinstance(input, torch.Tensor):
|
||||||
return torch.empty_like(other, device="meta")
|
return torch.empty_like(other, device="meta")
|
||||||
if not isinstance(other, torch.Tensor):
|
if not isinstance(other, torch.Tensor):
|
||||||
@@ -219,15 +221,15 @@ def torch_add_override(input, other, *, alpha=1, out=None):
|
|||||||
return torch.empty(shape, device="meta")
|
return torch.empty(shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_mul_override(input, other, *, out=None):
|
def torch_mul(input, other, *, out=None):
|
||||||
return torch_add_override(input, other, out=out)
|
return torch_add(input, other, out=out)
|
||||||
|
|
||||||
|
|
||||||
def torch_tensor_mul_override(self, other):
|
def torch_tensor_mul(self, other):
|
||||||
return torch_mul_override(self, other)
|
return torch_mul(self, other)
|
||||||
|
|
||||||
|
|
||||||
def torch_matmul_override(input, other, *, out=None):
|
def torch_matmul(input, other, *, out=None):
|
||||||
d1 = input.dim()
|
d1 = input.dim()
|
||||||
d2 = other.dim()
|
d2 = other.dim()
|
||||||
shape = None
|
shape = None
|
||||||
@@ -263,7 +265,13 @@ def torch_matmul_override(input, other, *, out=None):
|
|||||||
return torch.empty(*shape, device="meta")
|
return torch.empty(*shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_tensor_repeat_override(self, *sizes):
|
def torch_einsum(equation, *operands):
    """Meta override for ``torch.einsum``.

    Builds uninitialized CPU tensors shaped like each operand, runs the real
    ``torch.einsum`` on them and moves the result to the ``"meta"`` device.

    Args:
        equation: The einsum equation string, forwarded unchanged.
        *operands: The tensors taking part in the contraction.

    Returns:
        The einsum result converted to the ``"meta"`` device.
    """
    # TODO: infer shape without performing the computation, this might be quite hard.
    cpu_operands = [torch.empty_like(op, device="cpu") for op in operands]
    result = torch.einsum(equation, *cpu_operands)
    return result.to("meta")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_tensor_repeat(self, *sizes):
|
||||||
shape = list(self.shape)
|
shape = list(self.shape)
|
||||||
for i, x in enumerate(sizes):
|
for i, x in enumerate(sizes):
|
||||||
shape[i] *= x
|
shape[i] *= x
|
||||||
@@ -305,6 +313,18 @@ def torch_nn_conv2d(self, input):
|
|||||||
return torch.empty(shape, device="meta")
|
return torch.empty(shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_unsqueeze(input, dim):
    """Meta override for ``torch.unsqueeze``.

    Args:
        input: Tensor whose shape is extended.
        dim: Index at which the size-1 dimension is inserted. A negative
            value is shifted by ``input.dim() + 1`` before insertion.

    Returns:
        An empty tensor on the ``"meta"`` device with the extended shape.
    """
    insert_at = input.dim() + 1 + dim if dim < 0 else dim
    new_shape = list(input.shape)
    new_shape.insert(insert_at, 1)
    return torch.empty(new_shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_tensor_unsqueeze(self, dim):
    """Meta override for ``torch.Tensor.unsqueeze``; delegates to ``torch_unsqueeze``."""
    return torch_unsqueeze(input=self, dim=dim)
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_mseloss(self, input, target):
|
def torch_nn_mseloss(self, input, target):
|
||||||
if self.reduction == "none":
|
if self.reduction == "none":
|
||||||
shape = target.shape
|
shape = target.shape
|
||||||
@@ -329,31 +349,42 @@ def torch_nn_bcewithlogitsloss(self, input, target):
|
|||||||
return torch.empty(shape, device="meta")
|
return torch.empty(shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
def operator_getitem(a, b):
    """Meta override for ``operator.getitem``.

    Args:
        a: The object being indexed.
        b: The index / key / slice.

    Returns:
        For a ``torch.Tensor`` ``a``: the result of indexing an uninitialized
        CPU tensor shaped like ``a``, moved to the ``"meta"`` device.
        For anything else: ``operator.getitem(a, b)`` unchanged.
    """
    if not isinstance(a, torch.Tensor):
        return operator.getitem(a, b)
    # TODO: infer shape without performing the computation.
    concrete = torch.empty_like(a, device="cpu")
    return operator.getitem(concrete, b).to("meta")
|
||||||
|
|
||||||
|
|
||||||
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||||
torch.nn.Embedding: embedding_override,
|
torch.nn.Embedding: torch_nn_embedding,
|
||||||
torch.nn.LayerNorm: torch_nn_layernorm_override,
|
torch.nn.LayerNorm: torch_nn_layernorm,
|
||||||
torch.nn.Linear: torch_nn_linear_override,
|
torch.nn.Linear: torch_nn_linear,
|
||||||
torch.relu: torch_relu_override,
|
torch.relu: torch_relu,
|
||||||
torch.nn.functional.relu: torch_nn_functional_relu_override,
|
torch.nn.functional.relu: torch_nn_functional_relu,
|
||||||
torch.nn.ReLU: torch_nn_relu_override,
|
torch.nn.ReLU: torch_nn_relu,
|
||||||
torch.where: torch_where_override,
|
torch.where: torch_where,
|
||||||
torch.abs: torch_abs_override,
|
torch.abs: torch_abs,
|
||||||
torch.arange: torch_arange_override,
|
torch.arange: torch_arange,
|
||||||
torch.cat: torch_cat_override,
|
torch.cat: torch_cat,
|
||||||
torch.stack: torch_stack_override,
|
torch.stack: torch_stack,
|
||||||
torch.add: torch_add_override,
|
torch.add: torch_add,
|
||||||
torch.mul: torch_mul_override,
|
torch.mul: torch_mul,
|
||||||
torch.Tensor.mul: torch_tensor_mul_override,
|
torch.Tensor.mul: torch_tensor_mul,
|
||||||
torch.matmul: torch_matmul_override,
|
torch.matmul: torch_matmul,
|
||||||
torch.Tensor.repeat: torch_tensor_repeat_override,
|
torch.einsum: torch_einsum,
|
||||||
|
torch.Tensor.repeat: torch_tensor_repeat,
|
||||||
torch.roll: torch_roll,
|
torch.roll: torch_roll,
|
||||||
# TODO: those might not be needed.
|
# TODO: those might not be needed.
|
||||||
# torch.index_select: torch_index_select,
|
# torch.index_select: torch_index_select,
|
||||||
# torch.Tensor.index_select: torch_tensor_index_select,
|
# torch.Tensor.index_select: torch_tensor_index_select,
|
||||||
torch.nn.Conv2d: torch_nn_conv2d,
|
torch.nn.Conv2d: torch_nn_conv2d,
|
||||||
|
torch.unsqueeze: torch_unsqueeze,
|
||||||
|
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
|
||||||
torch.nn.MSELoss: torch_nn_mseloss,
|
torch.nn.MSELoss: torch_nn_mseloss,
|
||||||
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
|
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
|
||||||
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
|
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
|
||||||
|
operator.getitem: operator_getitem,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -371,7 +402,6 @@ class HFProxy(Proxy):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
return self.tracer.root.dtype
|
|
||||||
if hasattr(self, "_metadata") and self._metadata is not None:
|
if hasattr(self, "_metadata") and self._metadata is not None:
|
||||||
return self._metadata.dtype
|
return self._metadata.dtype
|
||||||
return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
|
return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
|
||||||
@@ -400,7 +430,7 @@ class HFProxy(Proxy):
|
|||||||
return HFAttribute(self, k)
|
return HFAttribute(self, k)
|
||||||
|
|
||||||
def __setitem__(self, indices, values):
|
def __setitem__(self, indices, values):
|
||||||
return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
|
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
# To handle cases such as :
|
# To handle cases such as :
|
||||||
@@ -480,14 +510,14 @@ class HFTracer(Tracer):
|
|||||||
regular PyTorch torch.fx.Proxy.
|
regular PyTorch torch.fx.Proxy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Feature flag for proxying accesses to buffer values
|
||||||
|
proxy_buffer_attributes: bool = True
|
||||||
allow_insert_stateless_mods: bool = True
|
allow_insert_stateless_mods: bool = True
|
||||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
|
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
|
||||||
|
|
||||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
|
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
|
||||||
autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_torch_fx_available():
|
if not is_torch_fx_available():
|
||||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
torch_version = version.parse(importlib_metadata.version("torch"))
|
||||||
@@ -500,7 +530,9 @@ class HFTracer(Tracer):
|
|||||||
self, model: PreTrainedModel, input_name: str, shape: List[int]
|
self, model: PreTrainedModel, input_name: str, shape: List[int]
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
"""Generates dummy input for model inference recording."""
|
"""Generates dummy input for model inference recording."""
|
||||||
model_class = model.__class__
|
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
||||||
|
# from pickle, or from the "__class__" attribute in the general case.
|
||||||
|
model_class = getattr(model, "class_for_deserialization", model.__class__)
|
||||||
device = model.device
|
device = model.device
|
||||||
inputs_dict = {}
|
inputs_dict = {}
|
||||||
|
|
||||||
@@ -641,7 +673,38 @@ class HFTracer(Tracer):
|
|||||||
if getattr(self, "_disable_module_getattr", False):
|
if getattr(self, "_disable_module_getattr", False):
|
||||||
return attr_val
|
return attr_val
|
||||||
else:
|
else:
|
||||||
return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
|
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||||
|
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
|
||||||
|
for n, p in collection_to_search:
|
||||||
|
if attr_val is p:
|
||||||
|
if n not in parameter_proxy_cache:
|
||||||
|
kwargs = {}
|
||||||
|
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
|
||||||
|
kwargs["proxy_factory_fn"] = (
|
||||||
|
None
|
||||||
|
if not self.param_shapes_constant
|
||||||
|
else lambda node: ParameterProxy(self, node, n, attr_val)
|
||||||
|
)
|
||||||
|
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
|
||||||
|
parameter_proxy_cache[n] = val_proxy
|
||||||
|
return parameter_proxy_cache[n]
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(attr_val, torch.nn.Parameter):
|
||||||
|
maybe_parameter_proxy = maybe_get_proxy_for_attr(
|
||||||
|
attr_val, self.root.named_parameters(), parameter_proxy_cache
|
||||||
|
)
|
||||||
|
if maybe_parameter_proxy is not None:
|
||||||
|
return maybe_parameter_proxy
|
||||||
|
|
||||||
|
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
|
||||||
|
maybe_buffer_proxy = maybe_get_proxy_for_attr(
|
||||||
|
attr_val, self.root.named_buffers(), parameter_proxy_cache
|
||||||
|
)
|
||||||
|
if maybe_buffer_proxy is not None:
|
||||||
|
return maybe_buffer_proxy
|
||||||
|
|
||||||
|
return attr_val
|
||||||
|
|
||||||
def call_module(self, m, forward, args, kwargs):
|
def call_module(self, m, forward, args, kwargs):
|
||||||
self.orig_forward = forward
|
self.orig_forward = forward
|
||||||
@@ -693,17 +756,29 @@ class HFTracer(Tracer):
|
|||||||
for name, (_, orig) in self.patched_torch_methods.items():
|
for name, (_, orig) in self.patched_torch_methods.items():
|
||||||
setattr(torch, name, orig)
|
setattr(torch, name, orig)
|
||||||
|
|
||||||
# TODO: keep this until necessary.
|
|
||||||
# This is necessary because concrete args are added as input to the traced module since
|
# This is necessary because concrete args are added as input to the traced module since
|
||||||
# https://github.com/pytorch/pytorch/pull/55888.
|
# https://github.com/pytorch/pytorch/pull/55888.
|
||||||
# A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet.
|
|
||||||
for node in self.graph.nodes:
|
for node in self.graph.nodes:
|
||||||
if node.op == "placeholder":
|
if node.op == "placeholder":
|
||||||
# Removing default values for inputs as the forward pass will fail with them.
|
# Removing default values for inputs as the forward pass will fail with them.
|
||||||
if node.target in input_names:
|
if node.target in input_names:
|
||||||
node.args = ()
|
node.args = ()
|
||||||
|
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
|
||||||
|
# It cannot infer on the attributes and methods the input should have, and fails.
|
||||||
|
node.type = torch.Tensor
|
||||||
# It is a concrete arg so it is not used and should be removed.
|
# It is a concrete arg so it is not used and should be removed.
|
||||||
else:
|
else:
|
||||||
|
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
|
||||||
|
# Newer versions of torch.fx emit an assert statement
|
||||||
|
# for concrete arguments; delete those before we delete
|
||||||
|
# the concrete arg.
|
||||||
|
to_delete = []
|
||||||
|
for user in node.users:
|
||||||
|
if user.target == torch.fx._symbolic_trace._assert_is_none:
|
||||||
|
to_delete.append(user)
|
||||||
|
for user in to_delete:
|
||||||
|
self.graph.erase_node(user)
|
||||||
|
|
||||||
self.graph.erase_node(node)
|
self.graph.erase_node(node)
|
||||||
|
|
||||||
# TODO: solves GraphModule creation.
|
# TODO: solves GraphModule creation.
|
||||||
@@ -809,4 +884,10 @@ def symbolic_trace(
|
|||||||
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
||||||
traced = torch.fx.GraphModule(model, traced_graph)
|
traced = torch.fx.GraphModule(model, traced_graph)
|
||||||
|
|
||||||
|
traced.config = model.config
|
||||||
|
# The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
|
||||||
|
# _generate_dummy_input, where the model class is needed.
|
||||||
|
traced.class_for_deserialization = model.__class__
|
||||||
|
traced.device = model.device
|
||||||
|
|
||||||
return traced
|
return traced
|
||||||
|
|||||||
@@ -325,7 +325,7 @@ torch_version = None
|
|||||||
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
||||||
if _torch_available:
|
if _torch_available:
|
||||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
torch_version = version.parse(importlib_metadata.version("torch"))
|
||||||
_torch_fx_available = (torch_version.major, torch_version.minor) == (
|
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
|
||||||
TORCH_FX_REQUIRED_VERSION.major,
|
TORCH_FX_REQUIRED_VERSION.major,
|
||||||
TORCH_FX_REQUIRED_VERSION.minor,
|
TORCH_FX_REQUIRED_VERSION.minor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,11 +16,14 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import SwinConfig
|
from transformers import SwinConfig
|
||||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
@@ -38,6 +41,9 @@ if is_vision_available():
|
|||||||
|
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
|
if is_torch_fx_available():
|
||||||
|
from transformers.utils.fx import symbolic_trace
|
||||||
|
|
||||||
|
|
||||||
def _config_zero_init(config):
|
def _config_zero_init(config):
|
||||||
configs_no_init = copy.deepcopy(config)
|
configs_no_init = copy.deepcopy(config)
|
||||||
@@ -381,6 +387,97 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
    """Trace every model class with torch.fx and compare it against eager execution.

    The check is skipped when torch.fx is unavailable or when the test class
    is not flagged as ``fx_compatible``. For each class in
    ``self.all_model_classes`` the model is built from a zero-initialised copy
    of ``config`` (with ``return_dict`` disabled), run eagerly, traced with
    ``symbolic_trace`` and run again; all tensor outputs must match under
    ``torch.allclose``. The traced module is then pickled to a temporary
    file, reloaded, and the reloaded module's outputs are compared against
    the eager outputs in the same way.

    Args:
        config: Model configuration; a deep-copied, zero-init variant is used.
        inputs_dict: Raw inputs, adapted per model class via
            ``self._prepare_for_class``.
        output_loss: Forwarded as ``return_labels`` so that label inputs are
            prepared (and therefore traced) when set.
    """
    if not (is_torch_fx_available() and self.fx_compatible):
        return

    def flatten_output(output):
        # Recursively collect tensors from nested tuples/lists; every
        # non-tensor leaf (e.g. None) is dropped.
        tensors = []
        for item in output:
            if isinstance(item, (tuple, list)):
                tensors.extend(flatten_output(item))
            elif isinstance(item, torch.Tensor):
                tensors.append(item)
        return tensors

    # Zero-init the weights to be sure no NaN shows up in the outputs.
    zero_init_config = _config_zero_init(config)
    zero_init_config.return_dict = False

    for model_class in self.all_model_classes:
        model = model_class(config=zero_init_config)
        model.to(torch_device)
        model.eval()
        inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)

        try:
            if model.config.is_encoder_decoder:
                # FSTM still requires this hack -> FSTM should probably be
                # refactored similar to BART afterward
                model.config.use_cache = False
                allowed_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
                if inputs.get("labels", None) is not None:
                    allowed_names.append("labels")
                filtered_inputs = {name: value for name, value in inputs.items() if name in allowed_names}
                # Encoder-decoder models are traced with the full candidate list.
                trace_names = allowed_names
            else:
                allowed_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
                for optional_name in ("labels", "start_positions", "end_positions"):
                    if inputs.get(optional_name, None) is not None:
                        allowed_names.append(optional_name)
                filtered_inputs = {name: value for name, value in inputs.items() if name in allowed_names}
                # Other models are traced only with the inputs actually present.
                trace_names = filtered_inputs.keys()

            model_output = model(**filtered_inputs)
            traced_model = symbolic_trace(model, trace_names)
            traced_output = traced_model(**filtered_inputs)
        except RuntimeError as err:
            self.fail(f"Couldn't trace module: {err}")

        model_output = flatten_output(model_output)
        traced_output = flatten_output(traced_output)
        num_outputs = len(model_output)

        for idx, expected in enumerate(model_output):
            self.assertTrue(
                torch.allclose(expected, traced_output[idx]),
                f"traced {idx}th output doesn't match model {idx}th output for {model_class}",
            )

        # Test that the model can be serialized and restored properly
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
            try:
                with open(pkl_file_name, "wb") as sink:
                    pickle.dump(traced_model, sink)
                # The pickle read back here was written by this very test.
                with open(pkl_file_name, "rb") as source:
                    loaded = pickle.load(source)
            except Exception as err:
                self.fail(f"Couldn't serialize / deserialize the traced model: {err}")

            loaded_output = flatten_output(loaded(**filtered_inputs))

            for idx in range(num_outputs):
                self.assertTrue(
                    torch.allclose(model_output[idx], loaded_output[idx]),
                    f"serialized model {idx}th output doesn't match model {idx}th output for {model_class}",
                )
|
||||||
|
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
|
import pickle
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -758,8 +759,8 @@ class ModelTesterMixin:
|
|||||||
traced_model = symbolic_trace(model, input_names)
|
traced_model = symbolic_trace(model, input_names)
|
||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
|
||||||
except RuntimeError:
|
except RuntimeError as e:
|
||||||
self.fail("Couldn't trace module.")
|
self.fail(f"Couldn't trace module: {e}")
|
||||||
|
|
||||||
def flatten_output(output):
|
def flatten_output(output):
|
||||||
flatten = []
|
flatten = []
|
||||||
@@ -782,6 +783,40 @@ class ModelTesterMixin:
|
|||||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test that the model can be TorchScripted
|
||||||
|
try:
|
||||||
|
scripted = torch.jit.script(traced_model)
|
||||||
|
except Exception as e:
|
||||||
|
self.fail(f"Could not TorchScript the traced model: {e}")
|
||||||
|
scripted_output = scripted(**filtered_inputs)
|
||||||
|
scripted_output = flatten_output(scripted_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], scripted_output[i]),
|
||||||
|
f"scripted {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the model can be serialized and restored properly
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||||
|
try:
|
||||||
|
with open(pkl_file_name, "wb") as f:
|
||||||
|
pickle.dump(traced_model, f)
|
||||||
|
with open(pkl_file_name, "rb") as f:
|
||||||
|
loaded = pickle.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||||
|
|
||||||
|
loaded_output = loaded(**filtered_inputs)
|
||||||
|
loaded_output = flatten_output(loaded_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], loaded_output[i]),
|
||||||
|
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
def test_headmasking(self):
|
def test_headmasking(self):
|
||||||
if not self.test_head_masking:
|
if not self.test_head_masking:
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user