FX support for ConvNext, Wav2Vec2 and ResNet (#19053)
* Support for ConvNext * Support for Wav2Vec2 * Support for Resnet * Fix small issue in test_modeling_convnext
This commit is contained in:
@@ -960,7 +960,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||||||
# take argmax in non-differentiable way
|
# take argmax in non-differentiable way
|
||||||
# comptute hard codevector distribution (one hot)
|
# comptute hard codevector distribution (one hot)
|
||||||
codevector_idx = hidden_states.argmax(dim=-1)
|
codevector_idx = hidden_states.argmax(dim=-1)
|
||||||
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
|
codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
|
||||||
-1, codevector_idx.view(-1, 1), 1.0
|
-1, codevector_idx.view(-1, 1), 1.0
|
||||||
)
|
)
|
||||||
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
||||||
|
|||||||
@@ -1023,7 +1023,7 @@ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
|
|||||||
# take argmax in non-differentiable way
|
# take argmax in non-differentiable way
|
||||||
# comptute hard codevector distribution (one hot)
|
# comptute hard codevector distribution (one hot)
|
||||||
codevector_idx = hidden_states.argmax(dim=-1)
|
codevector_idx = hidden_states.argmax(dim=-1)
|
||||||
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
|
codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
|
||||||
-1, codevector_idx.view(-1, 1), 1.0
|
-1, codevector_idx.view(-1, 1), 1.0
|
||||||
)
|
)
|
||||||
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
||||||
|
|||||||
@@ -104,6 +104,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
|||||||
"blenderbot-small",
|
"blenderbot-small",
|
||||||
"bloom",
|
"bloom",
|
||||||
"clip",
|
"clip",
|
||||||
|
"convnext",
|
||||||
"deberta",
|
"deberta",
|
||||||
"deberta-v2",
|
"deberta-v2",
|
||||||
"distilbert",
|
"distilbert",
|
||||||
@@ -125,6 +126,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
|||||||
"opt",
|
"opt",
|
||||||
"pegasus",
|
"pegasus",
|
||||||
"plbart",
|
"plbart",
|
||||||
|
"resnet",
|
||||||
"roberta",
|
"roberta",
|
||||||
"speech_to_text",
|
"speech_to_text",
|
||||||
"speech_to_text_2",
|
"speech_to_text_2",
|
||||||
@@ -133,6 +135,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
|||||||
"trocr",
|
"trocr",
|
||||||
"vit",
|
"vit",
|
||||||
"xglm",
|
"xglm",
|
||||||
|
"wav2vec2",
|
||||||
# "xlnet",
|
# "xlnet",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -743,7 +746,7 @@ class HFTracer(Tracer):
|
|||||||
elif hasattr(model.config, "encoder"):
|
elif hasattr(model.config, "encoder"):
|
||||||
image_size = model.config.encoder.image_size
|
image_size = model.config.encoder.image_size
|
||||||
else:
|
else:
|
||||||
raise AttributeError('Could not find the "image_size" field in the model config')
|
image_size = (_generate_random_int(), _generate_random_int())
|
||||||
|
|
||||||
# If no num_channels is in the config, use some arbitrary value.
|
# If no num_channels is in the config, use some arbitrary value.
|
||||||
num_channels = getattr(model.config, "num_channels", 3)
|
num_channels = getattr(model.config, "num_channels", 3)
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else ()
|
all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else ()
|
||||||
|
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|||||||
@@ -15,6 +15,9 @@
|
|||||||
""" Testing suite for the PyTorch Wav2Vec2 model. """
|
""" Testing suite for the PyTorch Wav2Vec2 model. """
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -32,6 +35,7 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import is_torch_fx_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import (
|
from ...test_modeling_common import (
|
||||||
@@ -72,6 +76,10 @@ if is_pyctcdecode_available():
|
|||||||
from transformers import Wav2Vec2ProcessorWithLM
|
from transformers import Wav2Vec2ProcessorWithLM
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_fx_available():
|
||||||
|
from transformers.utils.fx import symbolic_trace
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2ModelTester:
|
class Wav2Vec2ModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -411,6 +419,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
|
|
||||||
@@ -633,6 +642,106 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
# Wav2Vec2 cannot be torchscripted because of group norm.
|
||||||
|
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||||
|
if not is_torch_fx_available() or not self.fx_compatible:
|
||||||
|
return
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||||
|
configs_no_init.return_dict = False
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||||
|
|
||||||
|
try:
|
||||||
|
input_names = [
|
||||||
|
"attention_mask",
|
||||||
|
"bbox",
|
||||||
|
"input_features",
|
||||||
|
"input_ids",
|
||||||
|
"input_values",
|
||||||
|
"pixel_values",
|
||||||
|
"token_type_ids",
|
||||||
|
"visual_feats",
|
||||||
|
"visual_pos",
|
||||||
|
]
|
||||||
|
|
||||||
|
labels = inputs.get("labels", None)
|
||||||
|
start_positions = inputs.get("start_positions", None)
|
||||||
|
end_positions = inputs.get("end_positions", None)
|
||||||
|
if labels is not None:
|
||||||
|
input_names.append("labels")
|
||||||
|
if start_positions is not None:
|
||||||
|
input_names.append("start_positions")
|
||||||
|
if end_positions is not None:
|
||||||
|
input_names.append("end_positions")
|
||||||
|
|
||||||
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(model, Wav2Vec2ForSequenceClassification)
|
||||||
|
and not hasattr(model.config, "problem_type")
|
||||||
|
or model.config.problem_type is None
|
||||||
|
):
|
||||||
|
model.config.problem_type = "single_label_classification"
|
||||||
|
|
||||||
|
traced_model = symbolic_trace(model, input_names)
|
||||||
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.fail(f"Couldn't trace module: {e}")
|
||||||
|
|
||||||
|
def flatten_output(output):
|
||||||
|
flatten = []
|
||||||
|
for x in output:
|
||||||
|
if isinstance(x, (tuple, list)):
|
||||||
|
flatten += flatten_output(x)
|
||||||
|
elif not isinstance(x, torch.Tensor):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
flatten.append(x)
|
||||||
|
return flatten
|
||||||
|
|
||||||
|
model_output = flatten_output(model_output)
|
||||||
|
traced_output = flatten_output(traced_output)
|
||||||
|
num_outputs = len(model_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], traced_output[i]),
|
||||||
|
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the model can be serialized and restored properly
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||||
|
try:
|
||||||
|
with open(pkl_file_name, "wb") as f:
|
||||||
|
pickle.dump(traced_model, f)
|
||||||
|
with open(pkl_file_name, "rb") as f:
|
||||||
|
loaded = pickle.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||||
|
|
||||||
|
loaded_output = loaded(**filtered_inputs)
|
||||||
|
loaded_output = flatten_output(loaded_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], loaded_output[i]),
|
||||||
|
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||||
|
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||||
|
self.clear_torch_jit_class_registry()
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user