Fix ONNX exports for Optimum compatible models (#31311)

* fixed models

* format with bumped ruff version on my local

* fix copies

* add tracing checks

* format

* Update src/transformers/utils/generic.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* format

* style fix

* Update modeling_mobilevit.py

* add docstring and change name

* Update __init__.py

* Update __init__.py

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Merve Noyan
2024-06-27 12:46:36 +03:00
committed by GitHub
parent dc76e9fa7f
commit c9f191a0b7
10 changed files with 72 additions and 17 deletions

View File

@@ -37,6 +37,7 @@ from ...utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
torch_int,
) )
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
@@ -590,8 +591,10 @@ class ClapAudioLayer(nn.Module):
def set_shift_and_window_size(self, input_resolution): def set_shift_and_window_size(self, input_resolution):
if min(input_resolution) <= self.window_size: if min(input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows # if window size is larger than input resolution, we don't partition windows
self.shift_size = 0 self.shift_size = torch_int(0)
self.window_size = min(input_resolution) self.window_size = (
torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
)
def get_attn_mask(self, height, width, dtype, device): def get_attn_mask(self, height, width, dtype, device):
if self.shift_size > 0: if self.shift_size > 0:

View File

@@ -35,6 +35,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
logging, logging,
torch_int,
) )
from .configuration_donut_swin import DonutSwinConfig from .configuration_donut_swin import DonutSwinConfig
@@ -562,8 +563,10 @@ class DonutSwinLayer(nn.Module):
def set_shift_and_window_size(self, input_resolution): def set_shift_and_window_size(self, input_resolution):
if min(input_resolution) <= self.window_size: if min(input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows # if window size is larger than input resolution, we don't partition windows
self.shift_size = 0 self.shift_size = torch_int(0)
self.window_size = min(input_resolution) self.window_size = (
torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
)
def get_attn_mask(self, height, width, dtype, device): def get_attn_mask(self, height, width, dtype, device):
if self.shift_size > 0: if self.shift_size > 0:

View File

@@ -39,7 +39,7 @@ from ...file_utils import (
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, logging from ...utils import ModelOutput, logging, torch_int
from ...utils.backbone_utils import load_backbone from ...utils.backbone_utils import load_backbone
from .configuration_dpt import DPTConfig from .configuration_dpt import DPTConfig
@@ -226,7 +226,7 @@ class DPTViTEmbeddings(nn.Module):
posemb_tok = posemb[:, :start_index] posemb_tok = posemb[:, :start_index]
posemb_grid = posemb[0, start_index:] posemb_grid = posemb[0, start_index:]
old_grid_size = int(math.sqrt(len(posemb_grid))) old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)
posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2) posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear") posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")

View File

@@ -33,7 +33,13 @@ from ...modeling_outputs import (
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_float,
)
from .configuration_imagegpt import ImageGPTConfig from .configuration_imagegpt import ImageGPTConfig
@@ -229,7 +235,7 @@ class ImageGPTAttention(nn.Module):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights: if self.scale_attn_weights:
attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)
# Layer-wise attention scaling # Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx: if self.scale_attn_by_inverse_layer_idx:

View File

@@ -33,7 +33,13 @@ from ...modeling_outputs import (
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward from ...pytorch_utils import apply_chunking_to_forward
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_layoutlmv3 import LayoutLMv3Config from .configuration_layoutlmv3 import LayoutLMv3Config
@@ -910,8 +916,8 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
patch_height = patch_width = None patch_height = patch_width = None
if pixel_values is not None: if pixel_values is not None:
patch_height, patch_width = ( patch_height, patch_width = (
int(pixel_values.shape[2] / self.config.patch_size), torch_int(pixel_values.shape[2] / self.config.patch_size),
int(pixel_values.shape[3] / self.config.patch_size), torch_int(pixel_values.shape[3] / self.config.patch_size),
) )
visual_embeddings = self.forward_image(pixel_values) visual_embeddings = self.forward_image(pixel_values)
visual_attention_mask = torch.ones( visual_attention_mask = torch.ones(

View File

@@ -39,6 +39,7 @@ from ...utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
torch_int,
) )
from .configuration_mobilevit import MobileViTConfig from .configuration_mobilevit import MobileViTConfig
@@ -437,8 +438,16 @@ class MobileViTLayer(nn.Module):
batch_size, channels, orig_height, orig_width = features.shape batch_size, channels, orig_height, orig_width = features.shape
new_height = int(math.ceil(orig_height / patch_height) * patch_height) new_height = (
new_width = int(math.ceil(orig_width / patch_width) * patch_width) torch_int(torch.ceil(orig_height / patch_height) * patch_height)
if torch.jit.is_tracing()
else int(math.ceil(orig_height / patch_height) * patch_height)
)
new_width = (
torch_int(torch.ceil(orig_width / patch_width) * patch_width)
if torch.jit.is_tracing()
else int(math.ceil(orig_width / patch_width) * patch_width)
)
interpolate = False interpolate = False
if new_width != orig_width or new_height != orig_height: if new_width != orig_width or new_height != orig_height:

View File

@@ -15,7 +15,6 @@
"""PyTorch SAM model.""" """PyTorch SAM model."""
import collections import collections
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
@@ -232,7 +231,7 @@ class SamAttention(nn.Module):
# SamAttention # SamAttention
_, _, _, c_per_head = query.shape _, _, _, c_per_head = query.shape
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head) attn = attn / (c_per_head**0.5)
attn = torch.softmax(attn, dim=-1) attn = torch.softmax(attn, dim=-1)
if attention_similarity is not None: if attention_similarity is not None:

View File

@@ -36,6 +36,7 @@ from ...utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
torch_int,
) )
from ...utils.backbone_utils import BackboneMixin from ...utils.backbone_utils import BackboneMixin
from .configuration_swin import SwinConfig from .configuration_swin import SwinConfig
@@ -639,8 +640,10 @@ class SwinLayer(nn.Module):
def set_shift_and_window_size(self, input_resolution): def set_shift_and_window_size(self, input_resolution):
if min(input_resolution) <= self.window_size: if min(input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows # if window size is larger than input resolution, we don't partition windows
self.shift_size = 0 self.shift_size = torch_int(0)
self.window_size = min(input_resolution) self.window_size = (
torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
)
def get_attn_mask(self, height, width, dtype, device): def get_attn_mask(self, height, width, dtype, device):
if self.shift_size > 0: if self.shift_size > 0:

View File

@@ -60,6 +60,8 @@ from .generic import (
tensor_size, tensor_size,
to_numpy, to_numpy,
to_py_obj, to_py_obj,
torch_float,
torch_int,
transpose, transpose,
working_or_temp_dir, working_or_temp_dir,
) )

View File

@@ -753,6 +753,30 @@ def infer_framework(model_class):
raise TypeError(f"Could not infer framework from class {model_class}.") raise TypeError(f"Could not infer framework from class {model_class}.")
def torch_int(x):
"""
Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
"""
if not is_torch_available():
return int(x)
import torch
return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
def torch_float(x):
"""
Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float.
"""
if not is_torch_available():
return int(x)
import torch
return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
def filter_out_non_signature_kwargs(extra: Optional[list] = None): def filter_out_non_signature_kwargs(extra: Optional[list] = None):
""" """
Decorator to filter out named arguments that are not in the function signature. Decorator to filter out named arguments that are not in the function signature.