Fx support for Deberta-v[1-2], Hubert and LXMERT (#17539)
* Support for deberta and deberta-v2 * Support for LXMert * Support for Hubert * Fix for pt1.11 * Trigger CI
This commit is contained in:
@@ -104,9 +104,9 @@ class XSoftmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(self, input, mask, dim):
|
||||
self.dim = dim
|
||||
rmask = ~(mask.bool())
|
||||
rmask = ~(mask.to(torch.bool))
|
||||
|
||||
output = input.masked_fill(rmask, float("-inf"))
|
||||
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
|
||||
output = torch.softmax(output, self.dim)
|
||||
output.masked_fill_(rmask, 0)
|
||||
self.save_for_backward(output)
|
||||
@@ -129,7 +129,7 @@ class XSoftmax(torch.autograd.Function):
|
||||
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
||||
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
|
||||
)
|
||||
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
|
||||
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
|
||||
output = softmax(g, output, dim)
|
||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
||||
|
||||
@@ -152,7 +152,7 @@ def get_mask(input, local_context):
|
||||
mask = local_context.mask if local_context.reuse_mask else None
|
||||
|
||||
if dropout > 0 and mask is None:
|
||||
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
|
||||
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
|
||||
|
||||
if isinstance(local_context, DropoutContext):
|
||||
if local_context.mask is None:
|
||||
@@ -564,7 +564,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
@@ -652,7 +652,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
if output_attentions:
|
||||
return (context_layer, attention_probs)
|
||||
else:
|
||||
|
||||
@@ -107,9 +107,9 @@ class XSoftmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(self, input, mask, dim):
|
||||
self.dim = dim
|
||||
rmask = ~(mask.bool())
|
||||
rmask = ~(mask.to(torch.bool))
|
||||
|
||||
output = input.masked_fill(rmask, float("-inf"))
|
||||
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
|
||||
output = torch.softmax(output, self.dim)
|
||||
output.masked_fill_(rmask, 0)
|
||||
self.save_for_backward(output)
|
||||
@@ -132,7 +132,7 @@ class XSoftmax(torch.autograd.Function):
|
||||
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
||||
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
|
||||
)
|
||||
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
|
||||
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
|
||||
output = softmax(g, output, dim)
|
||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
||||
|
||||
@@ -157,7 +157,7 @@ def get_mask(input, local_context):
|
||||
mask = local_context.mask if local_context.reuse_mask else None
|
||||
|
||||
if dropout > 0 and mask is None:
|
||||
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
|
||||
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
|
||||
|
||||
if isinstance(local_context, DropoutContext):
|
||||
if local_context.mask is None:
|
||||
@@ -638,7 +638,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x, attention_heads):
|
||||
new_x_shape = x.size()[:-1] + (attention_heads, -1)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
|
||||
|
||||
def forward(
|
||||
@@ -719,7 +719,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
.contiguous()
|
||||
)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
if output_attentions:
|
||||
return (context_layer, attention_probs)
|
||||
else:
|
||||
|
||||
@@ -336,7 +336,7 @@ class LxmertAttention(nn.Module):
|
||||
self.num_attention_heads,
|
||||
self.attention_head_size,
|
||||
)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
|
||||
@@ -365,7 +365,7 @@ class LxmertAttention(nn.Module):
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
return outputs
|
||||
@@ -1253,7 +1253,7 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
|
||||
visual_prediction_scores = visual_prediction_scores_dict[key]
|
||||
visual_loss = visual_loss_fct(
|
||||
visual_prediction_scores.view(-1, output_dim),
|
||||
label.view(*label_shape),
|
||||
label.view(label_shape),
|
||||
)
|
||||
if visual_loss.dim() > 1: # Regression Losses
|
||||
visual_loss = visual_loss.mean(1)
|
||||
|
||||
@@ -261,7 +261,7 @@ def get_mask(input, local_context):
|
||||
mask = local_context.mask if local_context.reuse_mask else None
|
||||
|
||||
if dropout > 0 and mask is None:
|
||||
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
|
||||
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
|
||||
|
||||
if isinstance(local_context, DropoutContext):
|
||||
if local_context.mask is None:
|
||||
@@ -532,9 +532,9 @@ class XSoftmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(self, input, mask, dim):
|
||||
self.dim = dim
|
||||
rmask = ~(mask.bool())
|
||||
rmask = ~(mask.to(torch.bool))
|
||||
|
||||
output = input.masked_fill(rmask, float("-inf"))
|
||||
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
|
||||
output = torch.softmax(output, self.dim)
|
||||
output.masked_fill_(rmask, 0)
|
||||
self.save_for_backward(output)
|
||||
@@ -557,7 +557,7 @@ class XSoftmax(torch.autograd.Function):
|
||||
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
|
||||
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
|
||||
)
|
||||
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
|
||||
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
|
||||
output = softmax(g, output, dim)
|
||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
||||
|
||||
@@ -711,7 +711,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x, attention_heads):
|
||||
new_x_shape = x.size()[:-1] + (attention_heads, -1)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
|
||||
|
||||
def forward(
|
||||
@@ -792,7 +792,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
.contiguous()
|
||||
)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
if output_attentions:
|
||||
return (context_layer, attention_probs)
|
||||
else:
|
||||
|
||||
@@ -32,7 +32,9 @@ from torch.fx.proxy import ParameterProxy
|
||||
from .. import PretrainedConfig, PreTrainedModel, logging
|
||||
from ..models.auto import get_values
|
||||
from ..models.auto.modeling_auto import (
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
MODEL_FOR_CTC_MAPPING_NAMES,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||
@@ -72,6 +74,8 @@ def _generate_supported_model_class_names(
|
||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
|
||||
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
}
|
||||
|
||||
if supported_tasks is None:
|
||||
@@ -95,12 +99,16 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
||||
"blenderbot",
|
||||
"blenderbot-small",
|
||||
"clip",
|
||||
"deberta",
|
||||
"deberta-v2",
|
||||
"distilbert",
|
||||
"electra",
|
||||
"gpt2",
|
||||
"gpt_neo",
|
||||
"gptj",
|
||||
"hubert",
|
||||
"layoutlm",
|
||||
"lxmert",
|
||||
"m2m_100",
|
||||
"marian",
|
||||
"mbart",
|
||||
@@ -118,8 +126,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
||||
"trocr",
|
||||
"vit",
|
||||
"xglm",
|
||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||
# "xlnet",
|
||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||
]
|
||||
|
||||
_REGULAR_SUPPORTED_MODELS = []
|
||||
@@ -155,6 +163,10 @@ def torch_nn_layernorm(self, input):
|
||||
return input
|
||||
|
||||
|
||||
def torch_nn_groupnorm(self, input):
|
||||
return input
|
||||
|
||||
|
||||
def torch_nn_linear(self, input):
|
||||
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
|
||||
|
||||
@@ -372,6 +384,27 @@ def torch_nn_conv2d(self, input):
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def torch_squeeze(input, dim=None):
|
||||
shape = list(input.shape)
|
||||
if dim is not None:
|
||||
if dim < 0:
|
||||
dim = input.dim() + dim
|
||||
if shape[dim] == 1:
|
||||
shape.pop(dim)
|
||||
else:
|
||||
new_shape = []
|
||||
for dim_value in shape:
|
||||
if dim_value == 1:
|
||||
continue
|
||||
new_shape.append(dim_value)
|
||||
shape = new_shape
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def torch_tensor_squeeze(self, dim=None):
|
||||
return torch_squeeze(self, dim)
|
||||
|
||||
|
||||
def torch_unsqueeze(input, dim):
|
||||
shape = list(input.shape)
|
||||
if dim < 0:
|
||||
@@ -446,6 +479,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||
torch.nn.Embedding: torch_nn_embedding,
|
||||
torch.nn.functional.embedding: torch_nn_functional_embedding,
|
||||
torch.nn.LayerNorm: torch_nn_layernorm,
|
||||
torch.nn.GroupNorm: torch_nn_groupnorm,
|
||||
torch.nn.Linear: torch_nn_linear,
|
||||
torch.relu: torch_relu,
|
||||
torch.nn.functional.relu: torch_nn_functional_relu,
|
||||
@@ -469,6 +503,8 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||
torch.Tensor.index_select: torch_tensor_index_select,
|
||||
torch.nn.Conv1d: torch_nn_conv1d,
|
||||
torch.nn.Conv2d: torch_nn_conv2d,
|
||||
torch.squeeze: torch_squeeze,
|
||||
torch.Tensor.squeeze: torch_tensor_squeeze,
|
||||
torch.unsqueeze: torch_unsqueeze,
|
||||
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
|
||||
torch.unique_consecutive: torch_unique_consecutive,
|
||||
@@ -605,7 +641,7 @@ class HFTracer(Tracer):
|
||||
# Feature flag for proxying accesses to buffer values
|
||||
proxy_buffer_attributes: bool = True
|
||||
allow_insert_stateless_mods: bool = True
|
||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"]
|
||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
|
||||
|
||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||
|
||||
@@ -704,8 +740,31 @@ class HFTracer(Tracer):
|
||||
inputs_dict[input_name] = torch.zeros(
|
||||
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
|
||||
)
|
||||
elif "visual_feats" in input_name:
|
||||
inputs_dict[input_name] = torch.zeros(
|
||||
shape
|
||||
+ [
|
||||
model.config.visual_feat_dim,
|
||||
],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
)
|
||||
elif "visual_pos" in input_name:
|
||||
inputs_dict[input_name] = torch.zeros(
|
||||
shape
|
||||
+ [
|
||||
model.config.visual_pos_dim,
|
||||
],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
)
|
||||
elif "inputs" in input_name:
|
||||
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
|
||||
elif "input_values" in input_name:
|
||||
batch_size, _ = shape
|
||||
# Generating big sequence length for audio inputs.
|
||||
seq_length = _generate_random_int(low=10000, high=20000)
|
||||
inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
|
||||
elif "mask" in input_name or "ids" in input_name:
|
||||
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||
else:
|
||||
|
||||
@@ -222,6 +222,7 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
|
||||
fx_compatible = True
|
||||
test_torchscript = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
@@ -241,6 +241,7 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
|
||||
fx_compatible = True
|
||||
test_torchscript = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
@@ -16,12 +16,16 @@
|
||||
|
||||
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import HubertConfig, is_torch_available
|
||||
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
||||
from transformers.utils import is_torch_fx_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
@@ -45,6 +49,9 @@ if is_torch_available():
|
||||
)
|
||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||
|
||||
if is_torch_fx_available():
|
||||
from transformers.utils.fx import symbolic_trace
|
||||
|
||||
|
||||
class HubertModelTester:
|
||||
def __init__(
|
||||
@@ -299,6 +306,7 @@ class HubertModelTester:
|
||||
@require_torch
|
||||
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
|
||||
@@ -417,6 +425,117 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# Hubert cannot be TorchScripted because of torch.nn.utils.weight_norm
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||
if not is_torch_fx_available() or not self.fx_compatible:
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.return_dict = False
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = [
|
||||
"attention_mask",
|
||||
"decoder_attention_mask",
|
||||
"decoder_input_ids",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = [
|
||||
"attention_mask",
|
||||
"bbox",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
"pixel_values",
|
||||
"token_type_ids",
|
||||
"visual_feats",
|
||||
"visual_pos",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
if start_positions is not None:
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
for x in output:
|
||||
if isinstance(x, (tuple, list)):
|
||||
flatten += flatten_output(x)
|
||||
elif not isinstance(x, torch.Tensor):
|
||||
continue
|
||||
else:
|
||||
flatten.append(x)
|
||||
return flatten
|
||||
|
||||
model_output = flatten_output(model_output)
|
||||
traced_output = flatten_output(traced_output)
|
||||
num_outputs = len(model_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], traced_output[i]),
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
|
||||
@@ -535,6 +535,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (LxmertModel, LxmertForPreTraining, LxmertForQuestionAnswering) if is_torch_available() else ()
|
||||
|
||||
fx_compatible = True
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
|
||||
@@ -740,11 +740,12 @@ class ModelTesterMixin:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"decoder_input_ids",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
@@ -758,12 +759,15 @@ class ModelTesterMixin:
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"token_type_ids",
|
||||
"pixel_values",
|
||||
"bbox",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
"pixel_values",
|
||||
"token_type_ids",
|
||||
"visual_feats",
|
||||
"visual_pos",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
@@ -781,10 +785,17 @@ class ModelTesterMixin:
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
if (
|
||||
isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
|
||||
and not hasattr(model.config, "problem_type")
|
||||
or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError as e:
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
|
||||
Reference in New Issue
Block a user