Fx support for Deberta-v[1-2], Hubert and LXMERT (#17539)
* Support for deberta and deberta-v2 * Support for LXMert * Support for Hubert * Fix for pt1.11 * Trigger CI
This commit is contained in:
@@ -104,9 +104,9 @@ class XSoftmax(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(self, input, mask, dim):
|
def forward(self, input, mask, dim):
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
rmask = ~(mask.bool())
|
rmask = ~(mask.to(torch.bool))
|
||||||
|
|
||||||
output = input.masked_fill(rmask, float("-inf"))
|
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
|
||||||
output = torch.softmax(output, self.dim)
|
output = torch.softmax(output, self.dim)
|
||||||
output.masked_fill_(rmask, 0)
|
output.masked_fill_(rmask, 0)
|
||||||
self.save_for_backward(output)
|
self.save_for_backward(output)
|
||||||
@@ -129,7 +129,7 @@ class XSoftmax(torch.autograd.Function):
|
|||||||
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
||||||
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
|
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
|
||||||
)
|
)
|
||||||
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
|
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
|
||||||
output = softmax(g, output, dim)
|
output = softmax(g, output, dim)
|
||||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ def get_mask(input, local_context):
|
|||||||
mask = local_context.mask if local_context.reuse_mask else None
|
mask = local_context.mask if local_context.reuse_mask else None
|
||||||
|
|
||||||
if dropout > 0 and mask is None:
|
if dropout > 0 and mask is None:
|
||||||
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
|
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
|
||||||
|
|
||||||
if isinstance(local_context, DropoutContext):
|
if isinstance(local_context, DropoutContext):
|
||||||
if local_context.mask is None:
|
if local_context.mask is None:
|
||||||
@@ -564,7 +564,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -652,7 +652,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||||||
context_layer = torch.matmul(attention_probs, value_layer)
|
context_layer = torch.matmul(attention_probs, value_layer)
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
return (context_layer, attention_probs)
|
return (context_layer, attention_probs)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -107,9 +107,9 @@ class XSoftmax(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(self, input, mask, dim):
|
def forward(self, input, mask, dim):
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
rmask = ~(mask.bool())
|
rmask = ~(mask.to(torch.bool))
|
||||||
|
|
||||||
output = input.masked_fill(rmask, float("-inf"))
|
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
|
||||||
output = torch.softmax(output, self.dim)
|
output = torch.softmax(output, self.dim)
|
||||||
output.masked_fill_(rmask, 0)
|
output.masked_fill_(rmask, 0)
|
||||||
self.save_for_backward(output)
|
self.save_for_backward(output)
|
||||||
@@ -132,7 +132,7 @@ class XSoftmax(torch.autograd.Function):
|
|||||||
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
||||||
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
|
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
|
||||||
)
|
)
|
||||||
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
|
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
|
||||||
output = softmax(g, output, dim)
|
output = softmax(g, output, dim)
|
||||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
||||||
|
|
||||||
@@ -157,7 +157,7 @@ def get_mask(input, local_context):
|
|||||||
mask = local_context.mask if local_context.reuse_mask else None
|
mask = local_context.mask if local_context.reuse_mask else None
|
||||||
|
|
||||||
if dropout > 0 and mask is None:
|
if dropout > 0 and mask is None:
|
||||||
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
|
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
|
||||||
|
|
||||||
if isinstance(local_context, DropoutContext):
|
if isinstance(local_context, DropoutContext):
|
||||||
if local_context.mask is None:
|
if local_context.mask is None:
|
||||||
@@ -638,7 +638,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x, attention_heads):
|
def transpose_for_scores(self, x, attention_heads):
|
||||||
new_x_shape = x.size()[:-1] + (attention_heads, -1)
|
new_x_shape = x.size()[:-1] + (attention_heads, -1)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
|
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -719,7 +719,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||||||
.contiguous()
|
.contiguous()
|
||||||
)
|
)
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
return (context_layer, attention_probs)
|
return (context_layer, attention_probs)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -336,7 +336,7 @@ class LxmertAttention(nn.Module):
|
|||||||
self.num_attention_heads,
|
self.num_attention_heads,
|
||||||
self.attention_head_size,
|
self.attention_head_size,
|
||||||
)
|
)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
|
def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
|
||||||
@@ -365,7 +365,7 @@ class LxmertAttention(nn.Module):
|
|||||||
context_layer = torch.matmul(attention_probs, value_layer)
|
context_layer = torch.matmul(attention_probs, value_layer)
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
return outputs
|
return outputs
|
||||||
@@ -1253,7 +1253,7 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
|
|||||||
visual_prediction_scores = visual_prediction_scores_dict[key]
|
visual_prediction_scores = visual_prediction_scores_dict[key]
|
||||||
visual_loss = visual_loss_fct(
|
visual_loss = visual_loss_fct(
|
||||||
visual_prediction_scores.view(-1, output_dim),
|
visual_prediction_scores.view(-1, output_dim),
|
||||||
label.view(*label_shape),
|
label.view(label_shape),
|
||||||
)
|
)
|
||||||
if visual_loss.dim() > 1: # Regression Losses
|
if visual_loss.dim() > 1: # Regression Losses
|
||||||
visual_loss = visual_loss.mean(1)
|
visual_loss = visual_loss.mean(1)
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ def get_mask(input, local_context):
|
|||||||
mask = local_context.mask if local_context.reuse_mask else None
|
mask = local_context.mask if local_context.reuse_mask else None
|
||||||
|
|
||||||
if dropout > 0 and mask is None:
|
if dropout > 0 and mask is None:
|
||||||
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
|
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
|
||||||
|
|
||||||
if isinstance(local_context, DropoutContext):
|
if isinstance(local_context, DropoutContext):
|
||||||
if local_context.mask is None:
|
if local_context.mask is None:
|
||||||
@@ -532,9 +532,9 @@ class XSoftmax(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(self, input, mask, dim):
|
def forward(self, input, mask, dim):
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
rmask = ~(mask.bool())
|
rmask = ~(mask.to(torch.bool))
|
||||||
|
|
||||||
output = input.masked_fill(rmask, float("-inf"))
|
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
|
||||||
output = torch.softmax(output, self.dim)
|
output = torch.softmax(output, self.dim)
|
||||||
output.masked_fill_(rmask, 0)
|
output.masked_fill_(rmask, 0)
|
||||||
self.save_for_backward(output)
|
self.save_for_backward(output)
|
||||||
@@ -557,7 +557,7 @@ class XSoftmax(torch.autograd.Function):
|
|||||||
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
||||||
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
|
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
|
||||||
)
|
)
|
||||||
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
|
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
|
||||||
output = softmax(g, output, dim)
|
output = softmax(g, output, dim)
|
||||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
||||||
|
|
||||||
@@ -711,7 +711,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x, attention_heads):
|
def transpose_for_scores(self, x, attention_heads):
|
||||||
new_x_shape = x.size()[:-1] + (attention_heads, -1)
|
new_x_shape = x.size()[:-1] + (attention_heads, -1)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
|
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -792,7 +792,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||||||
.contiguous()
|
.contiguous()
|
||||||
)
|
)
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
return (context_layer, attention_probs)
|
return (context_layer, attention_probs)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -32,7 +32,9 @@ from torch.fx.proxy import ParameterProxy
|
|||||||
from .. import PretrainedConfig, PreTrainedModel, logging
|
from .. import PretrainedConfig, PreTrainedModel, logging
|
||||||
from ..models.auto import get_values
|
from ..models.auto import get_values
|
||||||
from ..models.auto.modeling_auto import (
|
from ..models.auto.modeling_auto import (
|
||||||
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_CTC_MAPPING_NAMES,
|
||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||||
@@ -72,6 +74,8 @@ def _generate_supported_model_class_names(
|
|||||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||||
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
|
||||||
|
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||||
}
|
}
|
||||||
|
|
||||||
if supported_tasks is None:
|
if supported_tasks is None:
|
||||||
@@ -95,12 +99,16 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
|||||||
"blenderbot",
|
"blenderbot",
|
||||||
"blenderbot-small",
|
"blenderbot-small",
|
||||||
"clip",
|
"clip",
|
||||||
|
"deberta",
|
||||||
|
"deberta-v2",
|
||||||
"distilbert",
|
"distilbert",
|
||||||
"electra",
|
"electra",
|
||||||
"gpt2",
|
"gpt2",
|
||||||
"gpt_neo",
|
"gpt_neo",
|
||||||
"gptj",
|
"gptj",
|
||||||
|
"hubert",
|
||||||
"layoutlm",
|
"layoutlm",
|
||||||
|
"lxmert",
|
||||||
"m2m_100",
|
"m2m_100",
|
||||||
"marian",
|
"marian",
|
||||||
"mbart",
|
"mbart",
|
||||||
@@ -118,8 +126,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
|||||||
"trocr",
|
"trocr",
|
||||||
"vit",
|
"vit",
|
||||||
"xglm",
|
"xglm",
|
||||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
|
||||||
# "xlnet",
|
# "xlnet",
|
||||||
|
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||||
]
|
]
|
||||||
|
|
||||||
_REGULAR_SUPPORTED_MODELS = []
|
_REGULAR_SUPPORTED_MODELS = []
|
||||||
@@ -155,6 +163,10 @@ def torch_nn_layernorm(self, input):
|
|||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
def torch_nn_groupnorm(self, input):
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_linear(self, input):
|
def torch_nn_linear(self, input):
|
||||||
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
|
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
|
||||||
|
|
||||||
@@ -372,6 +384,27 @@ def torch_nn_conv2d(self, input):
|
|||||||
return torch.empty(shape, device="meta")
|
return torch.empty(shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_squeeze(input, dim=None):
|
||||||
|
shape = list(input.shape)
|
||||||
|
if dim is not None:
|
||||||
|
if dim < 0:
|
||||||
|
dim = input.dim() + dim
|
||||||
|
if shape[dim] == 1:
|
||||||
|
shape.pop(dim)
|
||||||
|
else:
|
||||||
|
new_shape = []
|
||||||
|
for dim_value in shape:
|
||||||
|
if dim_value == 1:
|
||||||
|
continue
|
||||||
|
new_shape.append(dim_value)
|
||||||
|
shape = new_shape
|
||||||
|
return torch.empty(shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_tensor_squeeze(self, dim=None):
|
||||||
|
return torch_squeeze(self, dim)
|
||||||
|
|
||||||
|
|
||||||
def torch_unsqueeze(input, dim):
|
def torch_unsqueeze(input, dim):
|
||||||
shape = list(input.shape)
|
shape = list(input.shape)
|
||||||
if dim < 0:
|
if dim < 0:
|
||||||
@@ -446,6 +479,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
|||||||
torch.nn.Embedding: torch_nn_embedding,
|
torch.nn.Embedding: torch_nn_embedding,
|
||||||
torch.nn.functional.embedding: torch_nn_functional_embedding,
|
torch.nn.functional.embedding: torch_nn_functional_embedding,
|
||||||
torch.nn.LayerNorm: torch_nn_layernorm,
|
torch.nn.LayerNorm: torch_nn_layernorm,
|
||||||
|
torch.nn.GroupNorm: torch_nn_groupnorm,
|
||||||
torch.nn.Linear: torch_nn_linear,
|
torch.nn.Linear: torch_nn_linear,
|
||||||
torch.relu: torch_relu,
|
torch.relu: torch_relu,
|
||||||
torch.nn.functional.relu: torch_nn_functional_relu,
|
torch.nn.functional.relu: torch_nn_functional_relu,
|
||||||
@@ -469,6 +503,8 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
|||||||
torch.Tensor.index_select: torch_tensor_index_select,
|
torch.Tensor.index_select: torch_tensor_index_select,
|
||||||
torch.nn.Conv1d: torch_nn_conv1d,
|
torch.nn.Conv1d: torch_nn_conv1d,
|
||||||
torch.nn.Conv2d: torch_nn_conv2d,
|
torch.nn.Conv2d: torch_nn_conv2d,
|
||||||
|
torch.squeeze: torch_squeeze,
|
||||||
|
torch.Tensor.squeeze: torch_tensor_squeeze,
|
||||||
torch.unsqueeze: torch_unsqueeze,
|
torch.unsqueeze: torch_unsqueeze,
|
||||||
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
|
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
|
||||||
torch.unique_consecutive: torch_unique_consecutive,
|
torch.unique_consecutive: torch_unique_consecutive,
|
||||||
@@ -605,7 +641,7 @@ class HFTracer(Tracer):
|
|||||||
# Feature flag for proxying accesses to buffer values
|
# Feature flag for proxying accesses to buffer values
|
||||||
proxy_buffer_attributes: bool = True
|
proxy_buffer_attributes: bool = True
|
||||||
allow_insert_stateless_mods: bool = True
|
allow_insert_stateless_mods: bool = True
|
||||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"]
|
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
|
||||||
|
|
||||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||||
|
|
||||||
@@ -704,8 +740,31 @@ class HFTracer(Tracer):
|
|||||||
inputs_dict[input_name] = torch.zeros(
|
inputs_dict[input_name] = torch.zeros(
|
||||||
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
|
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
|
||||||
)
|
)
|
||||||
|
elif "visual_feats" in input_name:
|
||||||
|
inputs_dict[input_name] = torch.zeros(
|
||||||
|
shape
|
||||||
|
+ [
|
||||||
|
model.config.visual_feat_dim,
|
||||||
|
],
|
||||||
|
dtype=torch.float,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
elif "visual_pos" in input_name:
|
||||||
|
inputs_dict[input_name] = torch.zeros(
|
||||||
|
shape
|
||||||
|
+ [
|
||||||
|
model.config.visual_pos_dim,
|
||||||
|
],
|
||||||
|
dtype=torch.float,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
elif "inputs" in input_name:
|
elif "inputs" in input_name:
|
||||||
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
|
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
|
||||||
|
elif "input_values" in input_name:
|
||||||
|
batch_size, _ = shape
|
||||||
|
# Generating big sequence length for audio inputs.
|
||||||
|
seq_length = _generate_random_int(low=10000, high=20000)
|
||||||
|
inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
|
||||||
elif "mask" in input_name or "ids" in input_name:
|
elif "mask" in input_name or "ids" in input_name:
|
||||||
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
fx_compatible = True
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|||||||
@@ -241,6 +241,7 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
fx_compatible = True
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|||||||
@@ -16,12 +16,16 @@
|
|||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import HubertConfig, is_torch_available
|
from transformers import HubertConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
||||||
|
from transformers.utils import is_torch_fx_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import (
|
from ...test_modeling_common import (
|
||||||
@@ -45,6 +49,9 @@ if is_torch_available():
|
|||||||
)
|
)
|
||||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||||
|
|
||||||
|
if is_torch_fx_available():
|
||||||
|
from transformers.utils.fx import symbolic_trace
|
||||||
|
|
||||||
|
|
||||||
class HubertModelTester:
|
class HubertModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -299,6 +306,7 @@ class HubertModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
|
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
|
|
||||||
@@ -417,6 +425,117 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Hubert cannot be TorchScripted because of torch.nn.utils.weight_norm
|
||||||
|
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||||
|
if not is_torch_fx_available() or not self.fx_compatible:
|
||||||
|
return
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||||
|
configs_no_init.return_dict = False
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
|
labels = inputs.get("labels", None)
|
||||||
|
input_names = [
|
||||||
|
"attention_mask",
|
||||||
|
"decoder_attention_mask",
|
||||||
|
"decoder_input_ids",
|
||||||
|
"input_features",
|
||||||
|
"input_ids",
|
||||||
|
"input_values",
|
||||||
|
]
|
||||||
|
if labels is not None:
|
||||||
|
input_names.append("labels")
|
||||||
|
|
||||||
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
traced_model = symbolic_trace(model, input_names)
|
||||||
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
else:
|
||||||
|
input_names = [
|
||||||
|
"attention_mask",
|
||||||
|
"bbox",
|
||||||
|
"input_features",
|
||||||
|
"input_ids",
|
||||||
|
"input_values",
|
||||||
|
"pixel_values",
|
||||||
|
"token_type_ids",
|
||||||
|
"visual_feats",
|
||||||
|
"visual_pos",
|
||||||
|
]
|
||||||
|
|
||||||
|
labels = inputs.get("labels", None)
|
||||||
|
start_positions = inputs.get("start_positions", None)
|
||||||
|
end_positions = inputs.get("end_positions", None)
|
||||||
|
if labels is not None:
|
||||||
|
input_names.append("labels")
|
||||||
|
if start_positions is not None:
|
||||||
|
input_names.append("start_positions")
|
||||||
|
if end_positions is not None:
|
||||||
|
input_names.append("end_positions")
|
||||||
|
|
||||||
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
traced_model = symbolic_trace(model, input_names)
|
||||||
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.fail(f"Couldn't trace module: {e}")
|
||||||
|
|
||||||
|
def flatten_output(output):
|
||||||
|
flatten = []
|
||||||
|
for x in output:
|
||||||
|
if isinstance(x, (tuple, list)):
|
||||||
|
flatten += flatten_output(x)
|
||||||
|
elif not isinstance(x, torch.Tensor):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
flatten.append(x)
|
||||||
|
return flatten
|
||||||
|
|
||||||
|
model_output = flatten_output(model_output)
|
||||||
|
traced_output = flatten_output(traced_output)
|
||||||
|
num_outputs = len(model_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], traced_output[i]),
|
||||||
|
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the model can be serialized and restored properly
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||||
|
try:
|
||||||
|
with open(pkl_file_name, "wb") as f:
|
||||||
|
pickle.dump(traced_model, f)
|
||||||
|
with open(pkl_file_name, "rb") as f:
|
||||||
|
loaded = pickle.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||||
|
|
||||||
|
loaded_output = loaded(**filtered_inputs)
|
||||||
|
loaded_output = flatten_output(loaded_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], loaded_output[i]),
|
||||||
|
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
# overwrite from test_modeling_common
|
# overwrite from test_modeling_common
|
||||||
def _mock_init_weights(self, module):
|
def _mock_init_weights(self, module):
|
||||||
if hasattr(module, "weight") and module.weight is not None:
|
if hasattr(module, "weight") and module.weight is not None:
|
||||||
|
|||||||
@@ -535,6 +535,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (LxmertModel, LxmertForPreTraining, LxmertForQuestionAnswering) if is_torch_available() else ()
|
all_model_classes = (LxmertModel, LxmertForPreTraining, LxmertForQuestionAnswering) if is_torch_available() else ()
|
||||||
|
|
||||||
|
fx_compatible = True
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|||||||
@@ -740,11 +740,12 @@ class ModelTesterMixin:
|
|||||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
labels = inputs.get("labels", None)
|
labels = inputs.get("labels", None)
|
||||||
input_names = [
|
input_names = [
|
||||||
"input_ids",
|
|
||||||
"attention_mask",
|
"attention_mask",
|
||||||
"decoder_input_ids",
|
|
||||||
"decoder_attention_mask",
|
"decoder_attention_mask",
|
||||||
|
"decoder_input_ids",
|
||||||
"input_features",
|
"input_features",
|
||||||
|
"input_ids",
|
||||||
|
"input_values",
|
||||||
]
|
]
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
input_names.append("labels")
|
input_names.append("labels")
|
||||||
@@ -758,12 +759,15 @@ class ModelTesterMixin:
|
|||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
else:
|
else:
|
||||||
input_names = [
|
input_names = [
|
||||||
"input_ids",
|
|
||||||
"attention_mask",
|
"attention_mask",
|
||||||
"token_type_ids",
|
|
||||||
"pixel_values",
|
|
||||||
"bbox",
|
"bbox",
|
||||||
"input_features",
|
"input_features",
|
||||||
|
"input_ids",
|
||||||
|
"input_values",
|
||||||
|
"pixel_values",
|
||||||
|
"token_type_ids",
|
||||||
|
"visual_feats",
|
||||||
|
"visual_pos",
|
||||||
]
|
]
|
||||||
|
|
||||||
labels = inputs.get("labels", None)
|
labels = inputs.get("labels", None)
|
||||||
@@ -781,10 +785,17 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
|
||||||
|
and not hasattr(model.config, "problem_type")
|
||||||
|
or model.config.problem_type is None
|
||||||
|
):
|
||||||
|
model.config.problem_type = "single_label_classification"
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
traced_model = symbolic_trace(model, input_names)
|
||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
|
||||||
except RuntimeError as e:
|
except Exception as e:
|
||||||
self.fail(f"Couldn't trace module: {e}")
|
self.fail(f"Couldn't trace module: {e}")
|
||||||
|
|
||||||
def flatten_output(output):
|
def flatten_output(output):
|
||||||
|
|||||||
Reference in New Issue
Block a user