Fix ONNX exports for Optimum compatible models (#31311)
* fixed models
* format with bumped ruff version on my local
* fix copies
* add tracing checks
* format
* Update src/transformers/utils/generic.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* format
* style fix
* Update modeling_mobilevit.py
* add docstring and change name
* Update __init__.py
* Update __init__.py

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -37,6 +37,7 @@ from ...utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
|
torch_int,
|
||||||
)
|
)
|
||||||
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
|
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
|
||||||
|
|
||||||
@@ -590,8 +591,10 @@ class ClapAudioLayer(nn.Module):
|
|||||||
def set_shift_and_window_size(self, input_resolution):
|
def set_shift_and_window_size(self, input_resolution):
|
||||||
if min(input_resolution) <= self.window_size:
|
if min(input_resolution) <= self.window_size:
|
||||||
# if window size is larger than input resolution, we don't partition windows
|
# if window size is larger than input resolution, we don't partition windows
|
||||||
self.shift_size = 0
|
self.shift_size = torch_int(0)
|
||||||
self.window_size = min(input_resolution)
|
self.window_size = (
|
||||||
|
torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
|
||||||
|
)
|
||||||
|
|
||||||
def get_attn_mask(self, height, width, dtype, device):
|
def get_attn_mask(self, height, width, dtype, device):
|
||||||
if self.shift_size > 0:
|
if self.shift_size > 0:
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from ...utils import (
|
|||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
logging,
|
logging,
|
||||||
|
torch_int,
|
||||||
)
|
)
|
||||||
from .configuration_donut_swin import DonutSwinConfig
|
from .configuration_donut_swin import DonutSwinConfig
|
||||||
|
|
||||||
@@ -562,8 +563,10 @@ class DonutSwinLayer(nn.Module):
|
|||||||
def set_shift_and_window_size(self, input_resolution):
|
def set_shift_and_window_size(self, input_resolution):
|
||||||
if min(input_resolution) <= self.window_size:
|
if min(input_resolution) <= self.window_size:
|
||||||
# if window size is larger than input resolution, we don't partition windows
|
# if window size is larger than input resolution, we don't partition windows
|
||||||
self.shift_size = 0
|
self.shift_size = torch_int(0)
|
||||||
self.window_size = min(input_resolution)
|
self.window_size = (
|
||||||
|
torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
|
||||||
|
)
|
||||||
|
|
||||||
def get_attn_mask(self, height, width, dtype, device):
|
def get_attn_mask(self, height, width, dtype, device):
|
||||||
if self.shift_size > 0:
|
if self.shift_size > 0:
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ from ...file_utils import (
|
|||||||
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
|
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from ...utils import ModelOutput, logging
|
from ...utils import ModelOutput, logging, torch_int
|
||||||
from ...utils.backbone_utils import load_backbone
|
from ...utils.backbone_utils import load_backbone
|
||||||
from .configuration_dpt import DPTConfig
|
from .configuration_dpt import DPTConfig
|
||||||
|
|
||||||
@@ -226,7 +226,7 @@ class DPTViTEmbeddings(nn.Module):
|
|||||||
posemb_tok = posemb[:, :start_index]
|
posemb_tok = posemb[:, :start_index]
|
||||||
posemb_grid = posemb[0, start_index:]
|
posemb_grid = posemb[0, start_index:]
|
||||||
|
|
||||||
old_grid_size = int(math.sqrt(len(posemb_grid)))
|
old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)
|
||||||
|
|
||||||
posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
|
posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
|
||||||
posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
|
posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
|
||||||
|
|||||||
@@ -33,7 +33,13 @@ from ...modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import (
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
|
torch_float,
|
||||||
|
)
|
||||||
from .configuration_imagegpt import ImageGPTConfig
|
from .configuration_imagegpt import ImageGPTConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -229,7 +235,7 @@ class ImageGPTAttention(nn.Module):
|
|||||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||||
|
|
||||||
if self.scale_attn_weights:
|
if self.scale_attn_weights:
|
||||||
attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
|
attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)
|
||||||
|
|
||||||
# Layer-wise attention scaling
|
# Layer-wise attention scaling
|
||||||
if self.scale_attn_by_inverse_layer_idx:
|
if self.scale_attn_by_inverse_layer_idx:
|
||||||
|
|||||||
@@ -33,7 +33,13 @@ from ...modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import apply_chunking_to_forward
|
from ...pytorch_utils import apply_chunking_to_forward
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import (
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
|
torch_int,
|
||||||
|
)
|
||||||
from .configuration_layoutlmv3 import LayoutLMv3Config
|
from .configuration_layoutlmv3 import LayoutLMv3Config
|
||||||
|
|
||||||
|
|
||||||
@@ -910,8 +916,8 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
|
|||||||
patch_height = patch_width = None
|
patch_height = patch_width = None
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
patch_height, patch_width = (
|
patch_height, patch_width = (
|
||||||
int(pixel_values.shape[2] / self.config.patch_size),
|
torch_int(pixel_values.shape[2] / self.config.patch_size),
|
||||||
int(pixel_values.shape[3] / self.config.patch_size),
|
torch_int(pixel_values.shape[3] / self.config.patch_size),
|
||||||
)
|
)
|
||||||
visual_embeddings = self.forward_image(pixel_values)
|
visual_embeddings = self.forward_image(pixel_values)
|
||||||
visual_attention_mask = torch.ones(
|
visual_attention_mask = torch.ones(
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from ...utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
|
torch_int,
|
||||||
)
|
)
|
||||||
from .configuration_mobilevit import MobileViTConfig
|
from .configuration_mobilevit import MobileViTConfig
|
||||||
|
|
||||||
@@ -437,8 +438,16 @@ class MobileViTLayer(nn.Module):
|
|||||||
|
|
||||||
batch_size, channels, orig_height, orig_width = features.shape
|
batch_size, channels, orig_height, orig_width = features.shape
|
||||||
|
|
||||||
new_height = int(math.ceil(orig_height / patch_height) * patch_height)
|
new_height = (
|
||||||
new_width = int(math.ceil(orig_width / patch_width) * patch_width)
|
torch_int(torch.ceil(orig_height / patch_height) * patch_height)
|
||||||
|
if torch.jit.is_tracing()
|
||||||
|
else int(math.ceil(orig_height / patch_height) * patch_height)
|
||||||
|
)
|
||||||
|
new_width = (
|
||||||
|
torch_int(torch.ceil(orig_width / patch_width) * patch_width)
|
||||||
|
if torch.jit.is_tracing()
|
||||||
|
else int(math.ceil(orig_width / patch_width) * patch_width)
|
||||||
|
)
|
||||||
|
|
||||||
interpolate = False
|
interpolate = False
|
||||||
if new_width != orig_width or new_height != orig_height:
|
if new_width != orig_width or new_height != orig_height:
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
"""PyTorch SAM model."""
|
"""PyTorch SAM model."""
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import math
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -232,7 +231,7 @@ class SamAttention(nn.Module):
|
|||||||
# SamAttention
|
# SamAttention
|
||||||
_, _, _, c_per_head = query.shape
|
_, _, _, c_per_head = query.shape
|
||||||
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
|
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
|
||||||
attn = attn / math.sqrt(c_per_head)
|
attn = attn / (c_per_head**0.5)
|
||||||
attn = torch.softmax(attn, dim=-1)
|
attn = torch.softmax(attn, dim=-1)
|
||||||
|
|
||||||
if attention_similarity is not None:
|
if attention_similarity is not None:
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from ...utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
|
torch_int,
|
||||||
)
|
)
|
||||||
from ...utils.backbone_utils import BackboneMixin
|
from ...utils.backbone_utils import BackboneMixin
|
||||||
from .configuration_swin import SwinConfig
|
from .configuration_swin import SwinConfig
|
||||||
@@ -639,8 +640,10 @@ class SwinLayer(nn.Module):
|
|||||||
def set_shift_and_window_size(self, input_resolution):
|
def set_shift_and_window_size(self, input_resolution):
|
||||||
if min(input_resolution) <= self.window_size:
|
if min(input_resolution) <= self.window_size:
|
||||||
# if window size is larger than input resolution, we don't partition windows
|
# if window size is larger than input resolution, we don't partition windows
|
||||||
self.shift_size = 0
|
self.shift_size = torch_int(0)
|
||||||
self.window_size = min(input_resolution)
|
self.window_size = (
|
||||||
|
torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
|
||||||
|
)
|
||||||
|
|
||||||
def get_attn_mask(self, height, width, dtype, device):
|
def get_attn_mask(self, height, width, dtype, device):
|
||||||
if self.shift_size > 0:
|
if self.shift_size > 0:
|
||||||
|
|||||||
@@ -60,6 +60,8 @@ from .generic import (
|
|||||||
tensor_size,
|
tensor_size,
|
||||||
to_numpy,
|
to_numpy,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
|
torch_float,
|
||||||
|
torch_int,
|
||||||
transpose,
|
transpose,
|
||||||
working_or_temp_dir,
|
working_or_temp_dir,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -753,6 +753,30 @@ def infer_framework(model_class):
|
|||||||
raise TypeError(f"Could not infer framework from class {model_class}.")
|
raise TypeError(f"Could not infer framework from class {model_class}.")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_int(x):
    """
    Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.

    Args:
        x: A `torch.Tensor` or any value accepted by the builtin `int` (e.g. a Python number).

    Returns:
        `x.to(torch.int64)` when torch is available, tracing is active and `x` is a tensor;
        `int(x)` in every other case.
    """
    if not is_torch_available():
        return int(x)

    import torch

    # Only tensors expose `.to`; callers also pass plain Python numbers (e.g. `torch_int(0)`),
    # which must take the `int` path even while tracing instead of raising AttributeError.
    if torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.int64)
    return int(x)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_float(x):
    """
    Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float.

    Args:
        x: A `torch.Tensor` or any value accepted by the builtin `float` (e.g. a Python number).

    Returns:
        `x.to(torch.float32)` when torch is available, tracing is active and `x` is a tensor;
        `float(x)` in every other case.
    """
    if not is_torch_available():
        # Must be `float`, not `int`: truncating would silently change values such as sqrt(head_dim).
        return float(x)

    import torch

    # Only tensors expose `.to`; plain Python numbers take the `float` path even while tracing.
    if torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.float32)
    return float(x)
|
||||||
|
|
||||||
|
|
||||||
def filter_out_non_signature_kwargs(extra: Optional[list] = None):
|
def filter_out_non_signature_kwargs(extra: Optional[list] = None):
|
||||||
"""
|
"""
|
||||||
Decorator to filter out named arguments that are not in the function signature.
|
Decorator to filter out named arguments that are not in the function signature.
|
||||||
|
|||||||
Reference in New Issue
Block a user