ViT and Swin symbolic tracing with torch.fx (#17182)
* Support tracing for ViT * Swin support * Fix copies * Fix type annotation issue * Removed unused import
This commit is contained in:
@@ -168,7 +168,7 @@ class DeiTSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -200,7 +200,7 @@ class DeiTSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ class DPTViTSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -209,7 +209,7 @@ class DPTViTSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -496,7 +496,7 @@ def window_reverse(windows, window_size, height, width):
|
|||||||
"""
|
"""
|
||||||
Merges windows to produce higher resolution features.
|
Merges windows to produce higher resolution features.
|
||||||
"""
|
"""
|
||||||
batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
|
batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
|
||||||
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
|
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
|
||||||
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
|
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
|
||||||
return windows
|
return windows
|
||||||
@@ -697,7 +697,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -750,7 +750,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
|
|||||||
context_layer = torch.matmul(attention_probs, value_layer)
|
context_layer = torch.matmul(attention_probs, value_layer)
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ def window_reverse(windows, window_size, height, width):
|
|||||||
"""
|
"""
|
||||||
Merges windows to produce higher resolution features.
|
Merges windows to produce higher resolution features.
|
||||||
"""
|
"""
|
||||||
batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
|
batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
|
||||||
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
|
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
|
||||||
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
|
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
|
||||||
return windows
|
return windows
|
||||||
@@ -435,7 +435,7 @@ class SwinSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -488,7 +488,7 @@ class SwinSelfAttention(nn.Module):
|
|||||||
context_layer = torch.matmul(attention_probs, value_layer)
|
context_layer = torch.matmul(attention_probs, value_layer)
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
@@ -1071,7 +1071,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
|||||||
# Reshape to (batch_size, num_channels, height, width)
|
# Reshape to (batch_size, num_channels, height, width)
|
||||||
sequence_output = sequence_output.transpose(1, 2)
|
sequence_output = sequence_output.transpose(1, 2)
|
||||||
batch_size, num_channels, sequence_length = sequence_output.shape
|
batch_size, num_channels, sequence_length = sequence_output.shape
|
||||||
height = width = int(sequence_length**0.5)
|
height = width = math.floor(sequence_length**0.5)
|
||||||
sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
|
sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
|
||||||
|
|
||||||
# Reconstruct pixel values
|
# Reconstruct pixel values
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ class ViTSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -245,7 +245,7 @@ class ViTSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
@@ -687,7 +687,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
|||||||
# Reshape to (batch_size, num_channels, height, width)
|
# Reshape to (batch_size, num_channels, height, width)
|
||||||
sequence_output = sequence_output[:, 1:]
|
sequence_output = sequence_output[:, 1:]
|
||||||
batch_size, sequence_length, num_channels = sequence_output.shape
|
batch_size, sequence_length, num_channels = sequence_output.shape
|
||||||
height = width = int(sequence_length**0.5)
|
height = width = math.floor(sequence_length**0.5)
|
||||||
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
|
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
|
||||||
|
|
||||||
# Reconstruct pixel values
|
# Reconstruct pixel values
|
||||||
|
|||||||
@@ -342,7 +342,7 @@ class ViTMAESelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -374,7 +374,7 @@ class ViTMAESelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -280,7 +280,7 @@ class YolosSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -312,7 +312,7 @@ class YolosSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -14,12 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import collections
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -31,6 +31,7 @@ from .. import (
|
|||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
@@ -71,6 +72,7 @@ def _generate_supported_model_classes(
|
|||||||
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
|
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,6 +102,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
|||||||
"gpt_neo",
|
"gpt_neo",
|
||||||
"t5",
|
"t5",
|
||||||
"roberta",
|
"roberta",
|
||||||
|
"vit",
|
||||||
|
"swin",
|
||||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||||
# "layoutlm",
|
# "layoutlm",
|
||||||
# "xlnet",
|
# "xlnet",
|
||||||
@@ -276,6 +280,31 @@ def torch_tensor_index_select(self, dim, index):
|
|||||||
return torch_tensor_index_select(self, dim, index)
|
return torch_tensor_index_select(self, dim, index)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_roll(input, shifts, dims=None):
    """Meta override registered for ``torch.roll``.

    Hands back ``input`` itself. ``shifts`` and ``dims`` are accepted only so the
    signature lines up with ``torch.roll``; neither is read.
    """
    # Nothing is computed here: the very same tensor object is returned.
    return input
|
||||||
|
|
||||||
|
|
||||||
|
def torch_nn_conv2d(self, input):
    """Meta override for calling a ``torch.nn.Conv2d`` module.

    Builds an empty tensor on the "meta" device whose shape is the shape the
    convolution would produce, without running the convolution.

    Args:
        self: The ``torch.nn.Conv2d`` module; its ``padding``, ``dilation``,
            ``kernel_size``, ``stride`` and ``out_channels`` attributes are read.
        input: The input tensor. The last two dimensions are treated as height and
            width, and the third-from-last as the channel dimension.

    Returns:
        An uninitialized tensor on the "meta" device with the output shape and the
        dtype of ``input``.
    """
    shape = list(input.shape)
    padding = self.padding

    # With "same" padding the spatial dimensions are left untouched; every other
    # padding mode goes through the standard output-size formula.
    if padding != "same":
        if padding == "valid":
            padding = (0, 0)
        h_in, w_in = input.shape[-2:]
        # Integer floor division keeps the arithmetic exact (no float round trip).
        shape[-2] = (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[
            0
        ] + 1
        shape[-1] = (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[
            1
        ] + 1

    shape[-3] = self.out_channels
    # Carry the input dtype so the meta tensor does not silently become float32.
    return torch.empty(shape, dtype=input.dtype, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_mseloss(self, input, target):
|
def torch_nn_mseloss(self, input, target):
|
||||||
if self.reduction == "none":
|
if self.reduction == "none":
|
||||||
shape = target.shape
|
shape = target.shape
|
||||||
@@ -317,9 +346,11 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
|||||||
torch.Tensor.mul: torch_tensor_mul_override,
|
torch.Tensor.mul: torch_tensor_mul_override,
|
||||||
torch.matmul: torch_matmul_override,
|
torch.matmul: torch_matmul_override,
|
||||||
torch.Tensor.repeat: torch_tensor_repeat_override,
|
torch.Tensor.repeat: torch_tensor_repeat_override,
|
||||||
|
torch.roll: torch_roll,
|
||||||
# TODO: those might not be needed.
|
# TODO: those might not be needed.
|
||||||
# torch.index_select: torch_index_select,
|
# torch.index_select: torch_index_select,
|
||||||
# torch.Tensor.index_select: torch_tensor_index_select,
|
# torch.Tensor.index_select: torch_tensor_index_select,
|
||||||
|
torch.nn.Conv2d: torch_nn_conv2d,
|
||||||
torch.nn.MSELoss: torch_nn_mseloss,
|
torch.nn.MSELoss: torch_nn_mseloss,
|
||||||
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
|
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
|
||||||
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
|
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
|
||||||
@@ -368,6 +399,9 @@ class HFProxy(Proxy):
|
|||||||
# we peephole optimize to the method invocation
|
# we peephole optimize to the method invocation
|
||||||
return HFAttribute(self, k)
|
return HFAttribute(self, k)
|
||||||
|
|
||||||
|
def __setitem__(self, indices, values):
    """Turn ``self[indices] = values`` into a "call_method" proxy made by the tracer."""
    call_args = (self, indices, values)
    call_kwargs = {}
    return self.tracer.create_proxy("call_method", "__setitem__", call_args, call_kwargs)
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
# To handle cases such as :
|
# To handle cases such as :
|
||||||
# `"some_key" in kwargs`
|
# `"some_key" in kwargs`
|
||||||
@@ -521,6 +555,15 @@ class HFTracer(Tracer):
|
|||||||
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{model_class} not supported yet.")
|
raise NotImplementedError(f"{model_class} not supported yet.")
|
||||||
|
elif "pixel_values" in input_name:
|
||||||
|
batch_size = shape[0]
|
||||||
|
image_size = model.config.image_size
|
||||||
|
if not isinstance(image_size, collections.abc.Iterable):
|
||||||
|
image_size = (image_size, image_size)
|
||||||
|
height, width = image_size
|
||||||
|
inputs_dict[input_name] = torch.zeros(
|
||||||
|
batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
|
||||||
|
)
|
||||||
|
|
||||||
elif "mask" in input_name or "ids" in input_name:
|
elif "mask" in input_name or "ids" in input_name:
|
||||||
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||||
@@ -663,6 +706,11 @@ class HFTracer(Tracer):
|
|||||||
else:
|
else:
|
||||||
self.graph.erase_node(node)
|
self.graph.erase_node(node)
|
||||||
|
|
||||||
|
# TODO: solves GraphModule creation.
|
||||||
|
# Without this, return type annotation "Tuple" is causing code execution failure.
|
||||||
|
if node.op == "output":
|
||||||
|
node.type = None
|
||||||
|
|
||||||
return self.graph
|
return self.graph
|
||||||
|
|
||||||
def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
|
def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
|
||||||
@@ -761,12 +809,4 @@ def symbolic_trace(
|
|||||||
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
||||||
traced = torch.fx.GraphModule(model, traced_graph)
|
traced = torch.fx.GraphModule(model, traced_graph)
|
||||||
|
|
||||||
# Copy all the original attributes to the traced GraphModule.
|
|
||||||
regular_module_attributes = dir(nn.Module())
|
|
||||||
for name in dir(model):
|
|
||||||
attr = getattr(model, name)
|
|
||||||
if name.startswith("_") or name in regular_module_attributes:
|
|
||||||
continue
|
|
||||||
setattr(traced, name, deepcopy(attr))
|
|
||||||
|
|
||||||
return traced
|
return traced
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
fx_compatible = True
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
fx_compatible = True
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
|
|||||||
@@ -738,8 +738,7 @@ class ModelTesterMixin:
|
|||||||
traced_model = symbolic_trace(model, input_names)
|
traced_model = symbolic_trace(model, input_names)
|
||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
else:
|
else:
|
||||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
|
||||||
input_ids = inputs["input_ids"]
|
|
||||||
|
|
||||||
labels = inputs.get("labels", None)
|
labels = inputs.get("labels", None)
|
||||||
start_positions = inputs.get("start_positions", None)
|
start_positions = inputs.get("start_positions", None)
|
||||||
@@ -756,12 +755,6 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
rank = len(input_ids.shape)
|
|
||||||
if rank not in [2, 3]:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
|
|
||||||
)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
traced_model = symbolic_trace(model, input_names)
|
||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user