diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index a1df51ac63..e66241bd56 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -104,9 +104,9 @@ class XSoftmax(torch.autograd.Function): @staticmethod def forward(self, input, mask, dim): self.dim = dim - rmask = ~(mask.bool()) + rmask = ~(mask.to(torch.bool)) - output = input.masked_fill(rmask, float("-inf")) + output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) output = torch.softmax(output, self.dim) output.masked_fill_(rmask, 0) self.save_for_backward(output) @@ -129,7 +129,7 @@ class XSoftmax(torch.autograd.Function): g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx["Byte"], ) - output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf")))) + output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min))) output = softmax(g, output, dim) return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) @@ -152,7 +152,7 @@ def get_mask(input, local_context): mask = local_context.mask if local_context.reuse_mask else None if dropout > 0 and mask is None: - mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool) if isinstance(local_context, DropoutContext): if local_context.mask is None: @@ -564,7 +564,7 @@ class DisentangledSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -652,7 +652,7 @@ class DisentangledSelfAttention(nn.Module): context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (-1,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) if output_attentions: return (context_layer, attention_probs) else: diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 8ebd6384db..3e57666acf 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -107,9 +107,9 @@ class XSoftmax(torch.autograd.Function): @staticmethod def forward(self, input, mask, dim): self.dim = dim - rmask = ~(mask.bool()) + rmask = ~(mask.to(torch.bool)) - output = input.masked_fill(rmask, float("-inf")) + output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) output = torch.softmax(output, self.dim) output.masked_fill_(rmask, 0) self.save_for_backward(output) @@ -132,7 +132,7 @@ class XSoftmax(torch.autograd.Function): g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx["Byte"], ) - output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf")))) + output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min))) output = softmax(g, output, dim) return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) @@ -157,7 +157,7 @@ def get_mask(input, local_context): mask = local_context.mask if local_context.reuse_mask else None if dropout > 0 and mask is None: - mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool) if isinstance(local_context, DropoutContext): if local_context.mask is None: @@ -638,7 +638,7 @@ class DisentangledSelfAttention(nn.Module): def transpose_for_scores(self, x, attention_heads): new_x_shape = x.size()[:-1] + (attention_heads, -1) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) def forward( @@ -719,7 +719,7 @@ class DisentangledSelfAttention(nn.Module): .contiguous() ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) if output_attentions: return (context_layer, attention_probs) else: diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index 823fcdb545..34b90441fe 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -336,7 +336,7 @@ class LxmertAttention(nn.Module): self.num_attention_heads, self.attention_head_size, ) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward(self, hidden_states, context, attention_mask=None, output_attentions=False): @@ -365,7 +365,7 @@ class LxmertAttention(nn.Module): context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs @@ -1253,7 +1253,7 @@ class LxmertForPreTraining(LxmertPreTrainedModel): visual_prediction_scores = visual_prediction_scores_dict[key] visual_loss = visual_loss_fct( visual_prediction_scores.view(-1, output_dim), - label.view(*label_shape), + label.view(label_shape), ) if visual_loss.dim() > 1: # Regression Losses visual_loss = visual_loss.mean(1) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index a297e4c7b2..defdd71584 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -261,7 +261,7 @@ def get_mask(input, local_context): mask = local_context.mask if local_context.reuse_mask else None if dropout > 0 and mask is None: - mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool) if isinstance(local_context, DropoutContext): if local_context.mask is None: @@ -532,9 +532,9 @@ class XSoftmax(torch.autograd.Function): @staticmethod def forward(self, input, mask, dim): self.dim = dim - rmask = ~(mask.bool()) + rmask = ~(mask.to(torch.bool)) - output = input.masked_fill(rmask, float("-inf")) + output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) output = torch.softmax(output, self.dim) output.masked_fill_(rmask, 0) self.save_for_backward(output) @@ -557,7 +557,7 @@ class XSoftmax(torch.autograd.Function): g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx["Byte"], ) - output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf")))) + output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min))) output = softmax(g, output, dim) return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) @@ -711,7 +711,7 @@ class DisentangledSelfAttention(nn.Module): def transpose_for_scores(self, x, attention_heads): new_x_shape = x.size()[:-1] + (attention_heads, -1) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) def forward( @@ -792,7 +792,7 @@ class DisentangledSelfAttention(nn.Module): .contiguous() ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) if output_attentions: return (context_layer, attention_probs) else: diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index bbf32f9c6f..fc1bec9b91 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -32,7 +32,9 @@ from torch.fx.proxy import ParameterProxy from .. import PretrainedConfig, PreTrainedModel, logging from ..models.auto import get_values from ..models.auto.modeling_auto import ( + MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES, @@ -72,6 +74,8 @@ def _generate_supported_model_class_names( "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + "ctc": MODEL_FOR_CTC_MAPPING_NAMES, + "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, } if supported_tasks is None: @@ -95,12 +99,16 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ "blenderbot", "blenderbot-small", "clip", + "deberta", + "deberta-v2", "distilbert", "electra", "gpt2", "gpt_neo", "gptj", + "hubert", "layoutlm", + "lxmert", "m2m_100", "marian", "mbart", @@ -118,8 +126,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ "trocr", "vit", "xglm", - # TODO: add support for them as it should be quite easy to do so (small blocking issues). # "xlnet", + # TODO: add support for them as it should be quite easy to do so (small blocking issues). ] _REGULAR_SUPPORTED_MODELS = [] @@ -155,6 +163,10 @@ def torch_nn_layernorm(self, input): return input +def torch_nn_groupnorm(self, input): + return input + + def torch_nn_linear(self, input): return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") @@ -372,6 +384,27 @@ def torch_nn_conv2d(self, input): return torch.empty(shape, device="meta") +def torch_squeeze(input, dim=None): + shape = list(input.shape) + if dim is not None: + if dim < 0: + dim = input.dim() + dim + if shape[dim] == 1: + shape.pop(dim) + else: + new_shape = [] + for dim_value in shape: + if dim_value == 1: + continue + new_shape.append(dim_value) + shape = new_shape + return torch.empty(shape, device="meta") + + +def torch_tensor_squeeze(self, dim=None): + return torch_squeeze(self, dim) + + def torch_unsqueeze(input, dim): shape = list(input.shape) if dim < 0: @@ -446,6 +479,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { torch.nn.Embedding: torch_nn_embedding, torch.nn.functional.embedding: torch_nn_functional_embedding, torch.nn.LayerNorm: torch_nn_layernorm, + torch.nn.GroupNorm: torch_nn_groupnorm, torch.nn.Linear: torch_nn_linear, torch.relu: torch_relu, torch.nn.functional.relu: torch_nn_functional_relu, @@ -469,6 +503,8 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { torch.Tensor.index_select: torch_tensor_index_select, torch.nn.Conv1d: torch_nn_conv1d, torch.nn.Conv2d: torch_nn_conv2d, + torch.squeeze: torch_squeeze, + torch.Tensor.squeeze: torch_tensor_squeeze, torch.unsqueeze: torch_unsqueeze, torch.Tensor.unsqueeze: torch_tensor_unsqueeze, torch.unique_consecutive: torch_unique_consecutive, @@ -605,7 +641,7 @@ class HFTracer(Tracer): # Feature flag for proxying accesses to buffer values proxy_buffer_attributes: bool = True allow_insert_stateless_mods: bool = True - _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"] + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"] def __init__(self, autowrap_modules=(math,), autowrap_functions=()): @@ -704,8 +740,31 @@ class HFTracer(Tracer): inputs_dict[input_name] = torch.zeros( *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device ) + elif "visual_feats" in input_name: + inputs_dict[input_name] = torch.zeros( + shape + + [ + model.config.visual_feat_dim, + ], + dtype=torch.float, + device=device, + ) + elif "visual_pos" in input_name: + inputs_dict[input_name] = torch.zeros( + shape + + [ + model.config.visual_pos_dim, + ], + dtype=torch.float, + device=device, + ) elif "inputs" in input_name: inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device) + elif "input_values" in input_name: + batch_size, _ = shape + # Generating big sequence length for audio inputs. + seq_length = _generate_random_int(low=10000, high=20000) + inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device) elif "mask" in input_name or "ids" in input_name: inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) else: diff --git a/tests/models/deberta/test_modeling_deberta.py b/tests/models/deberta/test_modeling_deberta.py index 8d2e2dd020..ee1ba57414 100644 --- a/tests/models/deberta/test_modeling_deberta.py +++ b/tests/models/deberta/test_modeling_deberta.py @@ -222,6 +222,7 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase): else () ) + fx_compatible = True test_torchscript = False test_pruning = False test_head_masking = False diff --git a/tests/models/deberta_v2/test_modeling_deberta_v2.py b/tests/models/deberta_v2/test_modeling_deberta_v2.py index ef8e6e8e52..93436b901b 100644 --- a/tests/models/deberta_v2/test_modeling_deberta_v2.py +++ b/tests/models/deberta_v2/test_modeling_deberta_v2.py @@ -241,6 +241,7 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase): else () ) + fx_compatible = True test_torchscript = False test_pruning = False test_head_masking = False diff --git a/tests/models/hubert/test_modeling_hubert.py b/tests/models/hubert/test_modeling_hubert.py index 0055b8346a..1e27690bd4 100644 --- a/tests/models/hubert/test_modeling_hubert.py +++ b/tests/models/hubert/test_modeling_hubert.py @@ -16,12 +16,16 @@ import math +import os +import pickle +import tempfile import unittest import pytest from transformers import HubertConfig, is_torch_available from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device +from transformers.utils import is_torch_fx_available from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( @@ -45,6 +49,9 @@ if is_torch_available(): ) from transformers.models.hubert.modeling_hubert import _compute_mask_indices +if is_torch_fx_available(): + from transformers.utils.fx import symbolic_trace + class HubertModelTester: def __init__( @@ -299,6 +306,7 @@ class HubertModelTester: @require_torch class HubertModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else () + fx_compatible = True test_pruning = False test_headmasking = False @@ -417,6 +425,117 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + # Hubert cannot be TorchScripted because of torch.nn.utils.weight_norm + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): + if not is_torch_fx_available() or not self.fx_compatible: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.return_dict = False + + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) + + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = [ + "attention_mask", + "decoder_attention_mask", + "decoder_input_ids", + "input_features", + "input_ids", + "input_values", + ] + if labels is not None: + input_names.append("labels") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + else: + input_names = [ + "attention_mask", + "bbox", + "input_features", + "input_ids", + "input_values", + "pixel_values", + "token_type_ids", + "visual_feats", + "visual_pos", + ] + + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + + except Exception as e: + self.fail(f"Couldn't trace module: {e}") + + def flatten_output(output): + flatten = [] + for x in output: + if isinstance(x, (tuple, list)): + flatten += flatten_output(x) + elif not isinstance(x, torch.Tensor): + continue + else: + flatten.append(x) + return flatten + + model_output = flatten_output(model_output) + traced_output = flatten_output(traced_output) + num_outputs = len(model_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], traced_output[i]), + f"traced {i}th output doesn't match model {i}th output for {model_class}", + ) + + # Test that the model can be serialized and restored properly + with tempfile.TemporaryDirectory() as tmp_dir_name: + pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") + try: + with open(pkl_file_name, "wb") as f: + pickle.dump(traced_model, f) + with open(pkl_file_name, "rb") as f: + loaded = pickle.load(f) + except Exception as e: + self.fail(f"Couldn't serialize / deserialize the traced model: {e}") + + loaded_output = loaded(**filtered_inputs) + loaded_output = flatten_output(loaded_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], loaded_output[i]), + f"serialized model {i}th output doesn't match model {i}th output for {model_class}", + ) + # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: diff --git a/tests/models/lxmert/test_modeling_lxmert.py b/tests/models/lxmert/test_modeling_lxmert.py index 7061aaa7d3..1c51d02e96 100644 --- a/tests/models/lxmert/test_modeling_lxmert.py +++ b/tests/models/lxmert/test_modeling_lxmert.py @@ -535,6 +535,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (LxmertModel, LxmertForPreTraining, LxmertForQuestionAnswering) if is_torch_available() else () + fx_compatible = True test_head_masking = False test_pruning = False test_torchscript = False diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 927eccefd4..83927bb27f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -740,11 +740,12 @@ class ModelTesterMixin: model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward labels = inputs.get("labels", None) input_names = [ - "input_ids", "attention_mask", - "decoder_input_ids", "decoder_attention_mask", + "decoder_input_ids", "input_features", + "input_ids", + "input_values", ] if labels is not None: input_names.append("labels") @@ -758,12 +759,15 @@ class ModelTesterMixin: traced_output = traced_model(**filtered_inputs) else: input_names = [ - "input_ids", "attention_mask", - "token_type_ids", - "pixel_values", "bbox", "input_features", + "input_ids", + "input_values", + "pixel_values", + "token_type_ids", + "visual_feats", + "visual_pos", ] labels = inputs.get("labels", None) @@ -781,10 +785,17 @@ class ModelTesterMixin: model_output = model(**filtered_inputs) + if ( + isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values())) + and not hasattr(model.config, "problem_type") + or model.config.problem_type is None + ): + model.config.problem_type = "single_label_classification" + traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) - except RuntimeError as e: + except Exception as e: self.fail(f"Couldn't trace module: {e}") def flatten_output(output):