diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 603340bb7a..1c993deac0 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -721,6 +721,8 @@
title: Swin2SR
- local: model_doc/table-transformer
title: Table Transformer
+ - local: model_doc/textnet
+ title: TextNet
- local: model_doc/timm_wrapper
title: Timm Wrapper
- local: model_doc/upernet
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 810ea0565d..127c80e9cf 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -326,6 +326,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Table Transformer](model_doc/table-transformer) | ✅ | ❌ | ❌ |
| [TAPAS](model_doc/tapas) | ✅ | ✅ | ❌ |
| [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ |
+| [TextNet](model_doc/textnet) | ✅ | ❌ | ❌ |
| [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ |
| [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ |
| [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/textnet.md b/docs/source/en/model_doc/textnet.md
new file mode 100644
index 0000000000..d6b431e648
--- /dev/null
+++ b/docs/source/en/model_doc/textnet.md
@@ -0,0 +1,55 @@
+
+
+# TextNet
+
+## Overview
+
+The TextNet model was proposed in [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://arxiv.org/abs/2111.02394) by Zhe Chen, Jiahao Wang, Wenhai Wang, Guo Chen, Enze Xie, Ping Luo, Tong Lu. TextNet is a vision backbone useful for text detection tasks. It is the result of neural architecture search (NAS) on backbones with reward function as text detection task (to provide powerful features for text detection).
+
+
+
+ TextNet backbone as part of FAST. Taken from the original paper.
+
+This model was contributed by [Raghavan](https://huggingface.co/Raghavan), [jadechoghari](https://huggingface.co/jadechoghari) and [nielsr](https://huggingface.co/nielsr).
+
+## Usage tips
+
+TextNet is mainly used as a backbone network for the architecture search of text detection. Each stage of the backbone network is comprised of a stride-2 convolution and searchable blocks.
+Specifically, we present a layer-level candidate set, defined as {conv3×3, conv1×3, conv3×1, identity}. As the 1×3 and 3×1 convolutions have asymmetric kernels and oriented structure priors, they may help to capture the features of extreme aspect-ratio and rotated text lines.
+
+TextNet is the backbone for Fast, but can also be used as an efficient text/image classification, we add a `TextNetForImageClassification` as is it would allow people to train an image classifier on top of the pre-trained textnet weights
+
+## TextNetConfig
+
+[[autodoc]] TextNetConfig
+
+## TextNetImageProcessor
+
+[[autodoc]] TextNetImageProcessor
+ - preprocess
+
+## TextNetModel
+
+[[autodoc]] TextNetModel
+ - forward
+
+## TextNetForImageClassification
+
+[[autodoc]] TextNetForImageClassification
+ - forward
+
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 25898b6520..0c8765e303 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -789,6 +789,7 @@ _import_structure = {
"TapasConfig",
"TapasTokenizer",
],
+ "models.textnet": ["TextNetConfig"],
"models.time_series_transformer": ["TimeSeriesTransformerConfig"],
"models.timesformer": ["TimesformerConfig"],
"models.timm_backbone": ["TimmBackboneConfig"],
@@ -1258,6 +1259,7 @@ else:
_import_structure["models.siglip"].append("SiglipImageProcessor")
_import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])
_import_structure["models.swin2sr"].append("Swin2SRImageProcessor")
+ _import_structure["models.textnet"].extend(["TextNetImageProcessor"])
_import_structure["models.tvp"].append("TvpImageProcessor")
_import_structure["models.video_llava"].append("VideoLlavaImageProcessor")
_import_structure["models.videomae"].extend(["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"])
@@ -3584,6 +3586,14 @@ else:
"load_tf_weights_in_tapas",
]
)
+ _import_structure["models.textnet"].extend(
+ [
+ "TextNetBackbone",
+ "TextNetForImageClassification",
+ "TextNetModel",
+ "TextNetPreTrainedModel",
+ ]
+ )
_import_structure["models.time_series_transformer"].extend(
[
"TimeSeriesTransformerForPrediction",
@@ -5813,6 +5823,7 @@ if TYPE_CHECKING:
TapasConfig,
TapasTokenizer,
)
+ from .models.textnet import TextNetConfig
from .models.time_series_transformer import (
TimeSeriesTransformerConfig,
)
@@ -6293,6 +6304,7 @@ if TYPE_CHECKING:
from .models.siglip import SiglipImageProcessor
from .models.superpoint import SuperPointImageProcessor
from .models.swin2sr import Swin2SRImageProcessor
+ from .models.textnet import TextNetImageProcessor
from .models.tvp import TvpImageProcessor
from .models.video_llava import VideoLlavaImageProcessor
from .models.videomae import VideoMAEFeatureExtractor, VideoMAEImageProcessor
@@ -8155,6 +8167,12 @@ if TYPE_CHECKING:
TapasPreTrainedModel,
load_tf_weights_in_tapas,
)
+ from .models.textnet import (
+ TextNetBackbone,
+ TextNetForImageClassification,
+ TextNetModel,
+ TextNetPreTrainedModel,
+ )
from .models.time_series_transformer import (
TimeSeriesTransformerForPrediction,
TimeSeriesTransformerModel,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 766a0eab94..7b4456240c 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -252,6 +252,7 @@ from . import (
t5,
table_transformer,
tapas,
+ textnet,
time_series_transformer,
timesformer,
timm_backbone,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index b11bd36f79..1a5bb64039 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -279,6 +279,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("t5", "T5Config"),
("table-transformer", "TableTransformerConfig"),
("tapas", "TapasConfig"),
+ ("textnet", "TextNetConfig"),
("time_series_transformer", "TimeSeriesTransformerConfig"),
("timesformer", "TimesformerConfig"),
("timm_backbone", "TimmBackboneConfig"),
@@ -610,6 +611,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("table-transformer", "Table Transformer"),
("tapas", "TAPAS"),
("tapex", "TAPEX"),
+ ("textnet", "TextNet"),
("time_series_transformer", "Time Series Transformer"),
("timesformer", "TimeSformer"),
("timm_backbone", "TimmBackbone"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index d40caeef39..0729caa626 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -257,6 +257,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("t5", "T5Model"),
("table-transformer", "TableTransformerModel"),
("tapas", "TapasModel"),
+ ("textnet", "TextNetModel"),
("time_series_transformer", "TimeSeriesTransformerModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
@@ -703,6 +704,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("swiftformer", "SwiftFormerForImageClassification"),
("swin", "SwinForImageClassification"),
("swinv2", "Swinv2ForImageClassification"),
+ ("textnet", "TextNetForImageClassification"),
("timm_wrapper", "TimmWrapperForImageClassification"),
("van", "VanForImageClassification"),
("vit", "ViTForImageClassification"),
@@ -1391,6 +1393,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("rt_detr_resnet", "RTDetrResNetBackbone"),
("swin", "SwinBackbone"),
("swinv2", "Swinv2Backbone"),
+ ("textnet", "TextNetBackbone"),
("timm_backbone", "TimmBackbone"),
("vitdet", "VitDetBackbone"),
]
diff --git a/src/transformers/models/textnet/__init__.py b/src/transformers/models/textnet/__init__.py
new file mode 100644
index 0000000000..8f04a680b2
--- /dev/null
+++ b/src/transformers/models/textnet/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_textnet import *
+ from .image_processing_textnet import *
+ from .modeling_textnet import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/textnet/configuration_textnet.py b/src/transformers/models/textnet/configuration_textnet.py
new file mode 100644
index 0000000000..61ecaaeba8
--- /dev/null
+++ b/src/transformers/models/textnet/configuration_textnet.py
@@ -0,0 +1,135 @@
+# coding=utf-8
+# Copyright 2024 the Fast authors and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TextNet model configuration"""
+
+from transformers import PretrainedConfig
+from transformers.utils import logging
+from transformers.utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class TextNetConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`TextNextModel`]. It is used to instantiate a
+ TextNext model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the
+ [czczup/textnet-base](https://huggingface.co/czczup/textnet-base). Configuration objects inherit from
+ [`PretrainedConfig`] and can be used to control the model outputs.Read the documentation from [`PretrainedConfig`]
+ for more information.
+
+ Args:
+ stem_kernel_size (`int`, *optional*, defaults to 3):
+ The kernel size for the initial convolution layer.
+ stem_stride (`int`, *optional*, defaults to 2):
+ The stride for the initial convolution layer.
+ stem_num_channels (`int`, *optional*, defaults to 3):
+ The num of channels in input for the initial convolution layer.
+ stem_out_channels (`int`, *optional*, defaults to 64):
+ The num of channels in out for the initial convolution layer.
+ stem_act_func (`str`, *optional*, defaults to `"relu"`):
+ The activation function for the initial convolution layer.
+ image_size (`Tuple[int, int]`, *optional*, defaults to `[640, 640]`):
+ The size (resolution) of each image.
+ conv_layer_kernel_sizes (`List[List[List[int]]]`, *optional*):
+ A list of stage-wise kernel sizes. If `None`, defaults to:
+ `[[[3, 3], [3, 3], [3, 3]], [[3, 3], [1, 3], [3, 3], [3, 1]], [[3, 3], [3, 3], [3, 1], [1, 3]], [[3, 3], [3, 1], [1, 3], [3, 3]]]`.
+ conv_layer_strides (`List[List[int]]`, *optional*):
+ A list of stage-wise strides. If `None`, defaults to:
+ `[[1, 2, 1], [2, 1, 1, 1], [2, 1, 1, 1], [2, 1, 1, 1]]`.
+ hidden_sizes (`List[int]`, *optional*, defaults to `[64, 64, 128, 256, 512]`):
+ Dimensionality (hidden size) at each stage.
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the batch normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ out_features (`List[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+ out_indices (`List[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage.
+
+ Examples:
+
+ ```python
+ >>> from transformers import TextNetConfig, TextNetBackbone
+
+ >>> # Initializing a TextNetConfig
+ >>> configuration = TextNetConfig()
+
+ >>> # Initializing a model (with random weights)
+ >>> model = TextNetBackbone(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "textnet"
+
+ def __init__(
+ self,
+ stem_kernel_size=3,
+ stem_stride=2,
+ stem_num_channels=3,
+ stem_out_channels=64,
+ stem_act_func="relu",
+ image_size=[640, 640],
+ conv_layer_kernel_sizes=None,
+ conv_layer_strides=None,
+ hidden_sizes=[64, 64, 128, 256, 512],
+ batch_norm_eps=1e-5,
+ initializer_range=0.02,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if conv_layer_kernel_sizes is None:
+ conv_layer_kernel_sizes = [
+ [[3, 3], [3, 3], [3, 3]],
+ [[3, 3], [1, 3], [3, 3], [3, 1]],
+ [[3, 3], [3, 3], [3, 1], [1, 3]],
+ [[3, 3], [3, 1], [1, 3], [3, 3]],
+ ]
+ if conv_layer_strides is None:
+ conv_layer_strides = [[1, 2, 1], [2, 1, 1, 1], [2, 1, 1, 1], [2, 1, 1, 1]]
+
+ self.stem_kernel_size = stem_kernel_size
+ self.stem_stride = stem_stride
+ self.stem_num_channels = stem_num_channels
+ self.stem_out_channels = stem_out_channels
+ self.stem_act_func = stem_act_func
+
+ self.image_size = image_size
+ self.conv_layer_kernel_sizes = conv_layer_kernel_sizes
+ self.conv_layer_strides = conv_layer_strides
+
+ self.initializer_range = initializer_range
+ self.hidden_sizes = hidden_sizes
+ self.batch_norm_eps = batch_norm_eps
+
+ self.depths = [len(layer) for layer in self.conv_layer_kernel_sizes]
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, 5)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+__all__ = ["TextNetConfig"]
diff --git a/src/transformers/models/textnet/convert_textnet_to_hf.py b/src/transformers/models/textnet/convert_textnet_to_hf.py
new file mode 100644
index 0000000000..a8a004d18a
--- /dev/null
+++ b/src/transformers/models/textnet/convert_textnet_to_hf.py
@@ -0,0 +1,208 @@
+# coding=utf-8
+# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import logging
+import re
+from collections import OrderedDict
+
+import requests
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from transformers import TextNetBackbone, TextNetConfig, TextNetImageProcessor
+
+
+tiny_config_url = "https://raw.githubusercontent.com/czczup/FAST/main/config/fast/nas-configs/fast_tiny.config"
+small_config_url = "https://raw.githubusercontent.com/czczup/FAST/main/config/fast/nas-configs/fast_small.config"
+base_config_url = "https://raw.githubusercontent.com/czczup/FAST/main/config/fast/nas-configs/fast_base.config"
+
+rename_key_mappings = {
+ "module.backbone": "textnet",
+ "first_conv": "stem",
+ "bn": "batch_norm",
+ "ver": "vertical",
+ "hor": "horizontal",
+}
+
+
+def prepare_config(size_config_url, size):
+ config_dict = json.loads(requests.get(size_config_url).text)
+
+ backbone_config = {}
+ for stage_ix in range(1, 5):
+ stage_config = config_dict[f"stage{stage_ix}"]
+
+ merged_dict = {}
+
+ # Iterate through the list of dictionaries
+ for layer in stage_config:
+ for key, value in layer.items():
+ if key != "name":
+ # Check if the key is already in the merged_dict
+ if key in merged_dict:
+ merged_dict[key].append(value)
+ else:
+ # If the key is not in merged_dict, create a new list with the value
+ merged_dict[key] = [value]
+ backbone_config[f"stage{stage_ix}"] = merged_dict
+
+ neck_in_channels = []
+ neck_out_channels = []
+ neck_kernel_size = []
+ neck_stride = []
+ neck_dilation = []
+ neck_groups = []
+
+ for i in range(1, 5):
+ layer_key = f"reduce_layer{i}"
+ layer_dict = config_dict["neck"].get(layer_key)
+
+ if layer_dict:
+ # Append values to the corresponding lists
+ neck_in_channels.append(layer_dict["in_channels"])
+ neck_out_channels.append(layer_dict["out_channels"])
+ neck_kernel_size.append(layer_dict["kernel_size"])
+ neck_stride.append(layer_dict["stride"])
+ neck_dilation.append(layer_dict["dilation"])
+ neck_groups.append(layer_dict["groups"])
+
+ textnet_config = TextNetConfig(
+ stem_kernel_size=config_dict["first_conv"]["kernel_size"],
+ stem_stride=config_dict["first_conv"]["stride"],
+ stem_num_channels=config_dict["first_conv"]["in_channels"],
+ stem_out_channels=config_dict["first_conv"]["out_channels"],
+ stem_act_func=config_dict["first_conv"]["act_func"],
+ conv_layer_kernel_sizes=[
+ backbone_config["stage1"]["kernel_size"],
+ backbone_config["stage2"]["kernel_size"],
+ backbone_config["stage3"]["kernel_size"],
+ backbone_config["stage4"]["kernel_size"],
+ ],
+ conv_layer_strides=[
+ backbone_config["stage1"]["stride"],
+ backbone_config["stage2"]["stride"],
+ backbone_config["stage3"]["stride"],
+ backbone_config["stage4"]["stride"],
+ ],
+ hidden_sizes=[
+ config_dict["first_conv"]["out_channels"],
+ backbone_config["stage1"]["out_channels"][-1],
+ backbone_config["stage2"]["out_channels"][-1],
+ backbone_config["stage3"]["out_channels"][-1],
+ backbone_config["stage4"]["out_channels"][-1],
+ ],
+ out_features=["stage1", "stage2", "stage3", "stage4"],
+ out_indices=[1, 2, 3, 4],
+ )
+
+ return textnet_config
+
+
+def convert_textnet_checkpoint(checkpoint_url, checkpoint_config_filename, pytorch_dump_folder_path):
+ config_filepath = hf_hub_download(repo_id="Raghavan/fast_model_config_files", filename="fast_model_configs.json")
+
+ with open(config_filepath) as f:
+ content = json.loads(f.read())
+
+ size = content[checkpoint_config_filename]["short_size"]
+
+ if "tiny" in content[checkpoint_config_filename]["config"]:
+ config = prepare_config(tiny_config_url, size)
+ expected_slice_backbone = torch.tensor(
+ [0.0000, 0.0000, 0.0000, 0.0000, 0.5300, 0.0000, 0.0000, 0.0000, 0.0000, 1.1221]
+ )
+ elif "small" in content[checkpoint_config_filename]["config"]:
+ config = prepare_config(small_config_url, size)
+ expected_slice_backbone = torch.tensor(
+ [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1394]
+ )
+ else:
+ config = prepare_config(base_config_url, size)
+ expected_slice_backbone = torch.tensor(
+ [0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000]
+ )
+
+ model = TextNetBackbone(config)
+ textnet_image_processor = TextNetImageProcessor(size={"shortest_edge": size})
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["ema"]
+ state_dict_changed = OrderedDict()
+ for key in state_dict:
+ if "backbone" in key:
+ val = state_dict[key]
+ new_key = key
+ for search, replacement in rename_key_mappings.items():
+ if search in new_key:
+ new_key = new_key.replace(search, replacement)
+
+ pattern = r"textnet\.stage(\d)"
+
+ def adjust_stage(match):
+ stage_number = int(match.group(1)) - 1
+ return f"textnet.encoder.stages.{stage_number}.stage"
+
+ # Using regex to find and replace the pattern in the string
+ new_key = re.sub(pattern, adjust_stage, new_key)
+ state_dict_changed[new_key] = val
+ model.load_state_dict(state_dict_changed)
+ model.eval()
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+ original_pixel_values = torch.tensor(
+ [0.1939, 0.3481, 0.4166, 0.3309, 0.4508, 0.4679, 0.4851, 0.4851, 0.3309, 0.4337]
+ )
+ pixel_values = textnet_image_processor(image, return_tensors="pt").pixel_values
+
+ assert torch.allclose(original_pixel_values, pixel_values[0][0][3][:10], atol=1e-4)
+
+ with torch.no_grad():
+ output = model(pixel_values)
+
+ assert torch.allclose(output["feature_maps"][-1][0][10][12][:10].detach(), expected_slice_backbone, atol=1e-3)
+
+ model.save_pretrained(pytorch_dump_folder_path)
+ textnet_image_processor.save_pretrained(pytorch_dump_folder_path)
+ logging.info("The converted weights are saved here : " + pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint_url",
+ default="https://github.com/czczup/FAST/releases/download/release/fast_base_ic17mlt_640.pth",
+ type=str,
+ help="URL to the original PyTorch checkpoint (.pth file).",
+ )
+ parser.add_argument(
+ "--checkpoint_config_filename",
+ default="fast_base_ic17mlt_640.py",
+ type=str,
+ help="URL to the original PyTorch checkpoint (.pth file).",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
+ )
+ args = parser.parse_args()
+
+ convert_textnet_checkpoint(
+ args.checkpoint_url,
+ args.checkpoint_config_filename,
+ args.pytorch_dump_folder_path,
+ )
diff --git a/src/transformers/models/textnet/image_processing_textnet.py b/src/transformers/models/textnet/image_processing_textnet.py
new file mode 100644
index 0000000000..b3d4250b41
--- /dev/null
+++ b/src/transformers/models/textnet/image_processing_textnet.py
@@ -0,0 +1,355 @@
+# coding=utf-8
+# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for TextNet."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ convert_to_rgb,
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_kwargs,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+if is_vision_available():
+ import PIL
+
+
+class TextNetImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a TextNet image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
+ method.
+ size_divisor (`int`, *optional*, defaults to 32):
+ Ensures height and width are rounded to a multiple of this value after resizing.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `False`):
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+ `preprocess` method.
+ crop_size (`Dict[str, int]` *optional*, defaults to 224):
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+ method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ size_divisor: int = 32,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_center_crop: bool = False,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = IMAGENET_DEFAULT_MEAN,
+ image_std: Optional[Union[float, List[float]]] = IMAGENET_DEFAULT_STD,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.size = size
+ self.size_divisor = size_divisor
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+ self.do_convert_rgb = do_convert_rgb
+
+ self._valid_processor_keys = [
+ "images",
+ "do_resize",
+ "size",
+ "size_divisor",
+ "resample",
+ "do_center_crop",
+ "crop_size",
+ "do_rescale",
+ "rescale_factor",
+ "do_normalize",
+ "image_mean",
+ "image_std",
+ "do_convert_rgb",
+ "return_tensors",
+ "data_format",
+ "input_data_format",
+ ]
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"] , with the longest edge
+ resized to keep the input aspect ratio. Both the height and width are resized to be divisible by 32.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ size_divisor (`int`, *optional*, defaults to `32`):
+ Ensures height and width are rounded to a multiple of this value after resizing.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ default_to_square (`bool`, *optional*, defaults to `False`):
+ The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
+ `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or
+ not.Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
+ """
+ if "shortest_edge" in size:
+ size = size["shortest_edge"]
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+ height, width = get_resize_output_image_size(
+ image, size=size, input_data_format=input_data_format, default_to_square=False
+ )
+ if height % self.size_divisor != 0:
+ height += self.size_divisor - (height % self.size_divisor)
+ if width % self.size_divisor != 0:
+ width += self.size_divisor - (width % self.size_divisor)
+
+ return resize(
+ image,
+ size=(height, width),
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ size_divisor: int = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: int = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ size_divisor (`int`, *optional*, defaults to `32`):
+ Ensures height and width are rounded to a multiple of this value after resizing.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, param_name="size", default_to_square=False)
+ size_divisor = size_divisor if size_divisor is not None else self.size_divisor
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ for image in images:
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ all_images.append(image)
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in all_images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["TextNetImageProcessor"]
diff --git a/src/transformers/models/textnet/modeling_textnet.py b/src/transformers/models/textnet/modeling_textnet.py
new file mode 100644
index 0000000000..c895e66dc1
--- /dev/null
+++ b/src/transformers/models/textnet/modeling_textnet.py
@@ -0,0 +1,487 @@
+# coding=utf-8
+# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch TextNet model."""
+
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers import PreTrainedModel, add_start_docstrings
+from transformers.activations import ACT2CLS
+from transformers.modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+)
+from transformers.models.textnet.configuration_textnet import TextNetConfig
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.utils.backbone_utils import BackboneMixin
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "TextNetConfig"
+_CHECKPOINT_FOR_DOC = "czczup/textnet-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 512, 20, 27]
+
+
+class TextNetConvLayer(nn.Module):
+ def __init__(self, config: TextNetConfig):
+ super().__init__()
+
+ self.kernel_size = config.stem_kernel_size
+ self.stride = config.stem_stride
+ self.activation_function = config.stem_act_func
+
+ padding = (
+ (config.kernel_size[0] // 2, config.kernel_size[1] // 2)
+ if isinstance(config.stem_kernel_size, tuple)
+ else config.stem_kernel_size // 2
+ )
+
+ self.conv = nn.Conv2d(
+ config.stem_num_channels,
+ config.stem_out_channels,
+ kernel_size=config.stem_kernel_size,
+ stride=config.stem_stride,
+ padding=padding,
+ bias=False,
+ )
+ self.batch_norm = nn.BatchNorm2d(config.stem_out_channels, config.batch_norm_eps)
+
+ self.activation = nn.Identity()
+ if self.activation_function is not None:
+ self.activation = ACT2CLS[self.activation_function]()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.batch_norm(hidden_states)
+ return self.activation(hidden_states)
+
+
+class TextNetRepConvLayer(nn.Module):
+ r"""
+ This layer supports re-parameterization by combining multiple convolutional branches
+ (e.g., main convolution, vertical, horizontal, and identity branches) during training.
+ At inference time, these branches can be collapsed into a single convolution for
+ efficiency, as per the re-parameterization paradigm.
+
+ The "Rep" in the name stands for "re-parameterization" (introduced by RepVGG).
+ """
+
+ def __init__(self, config: TextNetConfig, in_channels: int, out_channels: int, kernel_size: int, stride: int):
+ super().__init__()
+
+ self.num_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+
+ padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
+
+ self.activation_function = nn.ReLU()
+
+ self.main_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.main_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
+
+ vertical_padding = ((kernel_size[0] - 1) // 2, 0)
+ horizontal_padding = (0, (kernel_size[1] - 1) // 2)
+
+ if kernel_size[1] != 1:
+ self.vertical_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(kernel_size[0], 1),
+ stride=stride,
+ padding=vertical_padding,
+ bias=False,
+ )
+ self.vertical_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
+ else:
+ self.vertical_conv, self.vertical_batch_norm = None, None
+
+ if kernel_size[0] != 1:
+ self.horizontal_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(1, kernel_size[1]),
+ stride=stride,
+ padding=horizontal_padding,
+ bias=False,
+ )
+ self.horizontal_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
+ else:
+ self.horizontal_conv, self.horizontal_batch_norm = None, None
+
+ self.rbr_identity = (
+ nn.BatchNorm2d(num_features=in_channels, eps=config.batch_norm_eps)
+ if out_channels == in_channels and stride == 1
+ else None
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ main_outputs = self.main_conv(hidden_states)
+ main_outputs = self.main_batch_norm(main_outputs)
+
+ # applies a convolution with a vertical kernel
+ if self.vertical_conv is not None:
+ vertical_outputs = self.vertical_conv(hidden_states)
+ vertical_outputs = self.vertical_batch_norm(vertical_outputs)
+ main_outputs = main_outputs + vertical_outputs
+
+ # applies a convolution with a horizontal kernel
+ if self.horizontal_conv is not None:
+ horizontal_outputs = self.horizontal_conv(hidden_states)
+ horizontal_outputs = self.horizontal_batch_norm(horizontal_outputs)
+ main_outputs = main_outputs + horizontal_outputs
+
+ if self.rbr_identity is not None:
+ id_out = self.rbr_identity(hidden_states)
+ main_outputs = main_outputs + id_out
+
+ return self.activation_function(main_outputs)
+
+
+class TextNetStage(nn.Module):
+ def __init__(self, config: TextNetConfig, depth: int):
+ super().__init__()
+ kernel_size = config.conv_layer_kernel_sizes[depth]
+ stride = config.conv_layer_strides[depth]
+
+ num_layers = len(kernel_size)
+ stage_in_channel_size = config.hidden_sizes[depth]
+ stage_out_channel_size = config.hidden_sizes[depth + 1]
+
+ in_channels = [stage_in_channel_size] + [stage_out_channel_size] * (num_layers - 1)
+ out_channels = [stage_out_channel_size] * num_layers
+
+ stage = []
+ for stage_config in zip(in_channels, out_channels, kernel_size, stride):
+ stage.append(TextNetRepConvLayer(config, *stage_config))
+ self.stage = nn.ModuleList(stage)
+
+ def forward(self, hidden_state):
+ for block in self.stage:
+ hidden_state = block(hidden_state)
+ return hidden_state
+
+
+class TextNetEncoder(nn.Module):
+ def __init__(self, config: TextNetConfig):
+ super().__init__()
+
+ stages = []
+ num_stages = len(config.conv_layer_kernel_sizes)
+ for stage_ix in range(num_stages):
+ stages.append(TextNetStage(config, stage_ix))
+
+ self.stages = nn.ModuleList(stages)
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BaseModelOutputWithNoAttention:
+ hidden_states = [hidden_state]
+ for stage in self.stages:
+ hidden_state = stage(hidden_state)
+ hidden_states.append(hidden_state)
+
+ if not return_dict:
+ output = (hidden_state,)
+ return output + (hidden_states,) if output_hidden_states else output
+
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
+
+
+TEXTNET_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`TextNetConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TEXTNET_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`TextNetImageProcessor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class TextNetPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = TextNetConfig
+ base_model_prefix = "textnet"
+ main_input_name = "pixel_values"
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.BatchNorm2d):
+ module.weight.data.fill_(1.0)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+@add_start_docstrings(
+ "The bare Textnet model outputting raw features without any specific head on top.",
+ TEXTNET_START_DOCSTRING,
+)
+class TextNetModel(TextNetPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.stem = TextNetConvLayer(config)
+ self.encoder = TextNetEncoder(config)
+ self.pooler = nn.AdaptiveAvgPool2d((2, 2))
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(TEXTNET_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+ ) -> Union[Tuple[Any, List[Any]], Tuple[Any], BaseModelOutputWithPoolingAndNoAttention]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ hidden_state = self.stem(pixel_values)
+
+ encoder_outputs = self.encoder(
+ hidden_state, output_hidden_states=output_hidden_states, return_dict=return_dict
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = self.pooler(last_hidden_state)
+
+ if not return_dict:
+ output = (last_hidden_state, pooled_output)
+ return output + (encoder_outputs[1],) if output_hidden_states else output
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs[1] if output_hidden_states else None,
+ )
+
+
+@add_start_docstrings(
+ """
+ TextNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ TEXTNET_START_DOCSTRING,
+)
+class TextNetForImageClassification(TextNetPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.textnet = TextNetModel(config)
+ self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ self.flatten = nn.Flatten()
+ self.fc = nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ # classification head
+ self.classifier = nn.ModuleList([self.avg_pool, self.flatten])
+
+ # initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(TEXTNET_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=ImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> ImageClassifierOutputWithNoAttention:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> import torch
+ >>> import requests
+ >>> from transformers import TextNetForImageClassification, TextNetImageProcessor
+ >>> from PIL import Image
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = TextNetImageProcessor.from_pretrained("czczup/textnet-base")
+ >>> model = TextNetForImageClassification.from_pretrained("czczup/textnet-base")
+
+ >>> inputs = processor(images=image, return_tensors="pt", size={"height": 640, "width": 640})
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+ >>> outputs.logits.shape
+ torch.Size([1, 2])
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.textnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+ last_hidden_state = outputs[0]
+ for layer in self.classifier:
+ last_hidden_state = layer(last_hidden_state)
+ logits = self.fc(last_hidden_state)
+ loss = None
+
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return (loss,) + output if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
+
+
+@add_start_docstrings(
+ """
+ TextNet backbone, to be used with frameworks like DETR and MaskFormer.
+ """,
+ TEXTNET_START_DOCSTRING,
+)
+class TextNetBackbone(TextNetPreTrainedModel, BackboneMixin):
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.textnet = TextNetModel(config)
+ self.num_features = config.hidden_sizes
+
+ # initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(TEXTNET_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+ ) -> Union[Tuple[Tuple], BackboneOutput]:
+ """
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
+ >>> model = AutoBackbone.from_pretrained("czczup/textnet-base")
+
+ >>> inputs = processor(image, return_tensors="pt")
+ >>> with torch.no_grad():
+ >>> outputs = model(**inputs)
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.textnet(pixel_values, output_hidden_states=True, return_dict=return_dict)
+
+ hidden_states = outputs.hidden_states if return_dict else outputs[2]
+
+ feature_maps = ()
+ for idx, stage in enumerate(self.stage_names):
+ if stage in self.out_features:
+ feature_maps += (hidden_states[idx],)
+
+ if not return_dict:
+ output = (feature_maps,)
+ if output_hidden_states:
+ hidden_states = outputs.hidden_states if return_dict else outputs[2]
+ output += (hidden_states,)
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=None,
+ )
+
+
+__all__ = ["TextNetBackbone", "TextNetModel", "TextNetPreTrainedModel", "TextNetForImageClassification"]
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 92711c27b3..20d449755f 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -9158,6 +9158,34 @@ def load_tf_weights_in_tapas(*args, **kwargs):
requires_backends(load_tf_weights_in_tapas, ["torch"])
+class TextNetBackbone(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class TextNetForImageClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class TextNetModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class TextNetPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class TimeSeriesTransformerForPrediction(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index 3ebda4404a..c51feffc59 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -618,6 +618,13 @@ class Swin2SRImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
+class TextNetImageProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class TvpImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
diff --git a/tests/models/textnet/__init__.py b/tests/models/textnet/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/textnet/test_image_processing_textnet.py b/tests/models/textnet/test_image_processing_textnet.py
new file mode 100644
index 0000000000..4fcd93e872
--- /dev/null
+++ b/tests/models/textnet/test_image_processing_textnet.py
@@ -0,0 +1,126 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_vision_available
+
+from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
+
+
+if is_vision_available():
+ from transformers import TextNetImageProcessor
+
+
+class TextNetImageProcessingTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=None,
+ size_divisor=32,
+ do_center_crop=True,
+ crop_size=None,
+ do_normalize=True,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ do_convert_rgb=True,
+ ):
+ size = size if size is not None else {"shortest_edge": 20}
+ crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.size_divisor = size_divisor
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_convert_rgb = do_convert_rgb
+
+ def prepare_image_processor_dict(self):
+ return {
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "size_divisor": self.size_divisor,
+ "do_center_crop": self.do_center_crop,
+ "crop_size": self.crop_size,
+ "do_normalize": self.do_normalize,
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_convert_rgb": self.do_convert_rgb,
+ }
+
+ def expected_output_image_shape(self, images):
+ return self.num_channels, self.crop_size["height"], self.crop_size["width"]
+
+ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
+ return prepare_image_inputs(
+ batch_size=self.batch_size,
+ num_channels=self.num_channels,
+ min_resolution=self.min_resolution,
+ max_resolution=self.max_resolution,
+ equal_resolution=equal_resolution,
+ numpify=numpify,
+ torchify=torchify,
+ )
+
+
+@require_torch
+@require_vision
+class TextNetImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
+ image_processing_class = TextNetImageProcessor if is_vision_available() else None
+
+ def setUp(self):
+ super().setUp()
+ self.image_processor_tester = TextNetImageProcessingTester(self)
+
+ @property
+ def image_processor_dict(self):
+ return self.image_processor_tester.prepare_image_processor_dict()
+
+ def test_image_processor_properties(self):
+ image_processing = self.image_processing_class(**self.image_processor_dict)
+ self.assertTrue(hasattr(image_processing, "do_resize"))
+ self.assertTrue(hasattr(image_processing, "size"))
+ self.assertTrue(hasattr(image_processing, "size_divisor"))
+ self.assertTrue(hasattr(image_processing, "do_center_crop"))
+ self.assertTrue(hasattr(image_processing, "center_crop"))
+ self.assertTrue(hasattr(image_processing, "do_normalize"))
+ self.assertTrue(hasattr(image_processing, "image_mean"))
+ self.assertTrue(hasattr(image_processing, "image_std"))
+ self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
+
+ def test_image_processor_from_dict_with_kwargs(self):
+ image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
+ self.assertEqual(image_processor.size, {"shortest_edge": 20})
+ self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
+
+ image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
+ self.assertEqual(image_processor.size, {"shortest_edge": 42})
+ self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
diff --git a/tests/models/textnet/test_modeling_textnet.py b/tests/models/textnet/test_modeling_textnet.py
new file mode 100644
index 0000000000..cf5e48506e
--- /dev/null
+++ b/tests/models/textnet/test_modeling_textnet.py
@@ -0,0 +1,348 @@
+# coding=utf-8
+# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch TextNet model."""
+
+import unittest
+
+import requests
+from PIL import Image
+
+from transformers import TextNetConfig
+from transformers.models.textnet.image_processing_textnet import TextNetImageProcessor
+from transformers.testing_utils import (
+ require_torch,
+ require_vision,
+ slow,
+ torch_device,
+)
+from transformers.utils import is_torch_available
+
+from ...test_backbone_common import BackboneTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import TextNetBackbone, TextNetForImageClassification, TextNetModel
+
+
+class TextNetConfigTester(ConfigTester):
+ def create_and_test_config_common_properties(self):
+ config = self.config_class(**self.inputs_dict)
+ self.parent.assertTrue(hasattr(config, "hidden_sizes"))
+ self.parent.assertTrue(hasattr(config, "num_attention_heads"))
+ self.parent.assertTrue(hasattr(config, "num_encoder_blocks"))
+
+
+class TextNetModelTester:
+ def __init__(
+ self,
+ parent,
+ stem_kernel_size=3,
+ stem_stride=2,
+ stem_in_channels=3,
+ stem_out_channels=32,
+ stem_act_func="relu",
+ dropout_rate=0,
+ ops_order="weight_bn_act",
+ conv_layer_kernel_sizes=[
+ [[3, 3]],
+ [[3, 3]],
+ [[3, 3]],
+ [[3, 3]],
+ ],
+ conv_layer_strides=[
+ [2],
+ [2],
+ [2],
+ [2],
+ ],
+ out_features=["stage1", "stage2", "stage3", "stage4"],
+ out_indices=[1, 2, 3, 4],
+ batch_size=3,
+ num_channels=3,
+ image_size=[32, 32],
+ is_training=True,
+ use_labels=True,
+ num_labels=3,
+ hidden_sizes=[32, 32, 32, 32, 32],
+ ):
+ self.parent = parent
+ self.stem_kernel_size = stem_kernel_size
+ self.stem_stride = stem_stride
+ self.stem_in_channels = stem_in_channels
+ self.stem_out_channels = stem_out_channels
+ self.act_func = stem_act_func
+ self.dropout_rate = dropout_rate
+ self.ops_order = ops_order
+ self.conv_layer_kernel_sizes = conv_layer_kernel_sizes
+ self.conv_layer_strides = conv_layer_strides
+
+ self.out_features = out_features
+ self.out_indices = out_indices
+
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.num_labels = num_labels
+ self.hidden_sizes = hidden_sizes
+
+ self.num_stages = 5
+
+ def get_config(self):
+ return TextNetConfig(
+ stem_kernel_size=self.stem_kernel_size,
+ stem_stride=self.stem_stride,
+ stem_num_channels=self.stem_in_channels,
+ stem_out_channels=self.stem_out_channels,
+ act_func=self.act_func,
+ dropout_rate=self.dropout_rate,
+ ops_order=self.ops_order,
+ conv_layer_kernel_sizes=self.conv_layer_kernel_sizes,
+ conv_layer_strides=self.conv_layer_strides,
+ out_features=self.out_features,
+ out_indices=self.out_indices,
+ hidden_sizes=self.hidden_sizes,
+ image_size=self.image_size,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = TextNetModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ scale_h = self.image_size[0] // 32
+ scale_w = self.image_size[1] // 32
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.hidden_sizes[-1], scale_h, scale_w),
+ )
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = TextNetForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.num_labels)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def create_and_check_backbone(self, config, pixel_values, labels):
+ model = TextNetBackbone(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ # verify feature maps
+ self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
+ scale_h = self.image_size[0] // 32
+ scale_w = self.image_size[1] // 32
+ self.parent.assertListEqual(
+ list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 8 * scale_h, 8 * scale_w]
+ )
+
+ # verify channels
+ self.parent.assertEqual(len(model.channels), len(config.out_features))
+ self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
+
+ # verify backbone works with out_features=None
+ config.out_features = None
+ model = TextNetBackbone(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ # verify feature maps
+ self.parent.assertEqual(len(result.feature_maps), 1)
+ scale_h = self.image_size[0] // 32
+ scale_w = self.image_size[1] // 32
+ self.parent.assertListEqual(
+ list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[0], scale_h, scale_w]
+ )
+
+ # verify channels
+ self.parent.assertEqual(len(model.channels), 1)
+ self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class TextNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some tests of test_modeling_common.py, as TextNet does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (TextNetModel, TextNetForImageClassification, TextNetBackbone) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {"feature-extraction": TextNetModel, "image-classification": TextNetForImageClassification}
+ if is_torch_available()
+ else {}
+ )
+
+ fx_compatible = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = TextNetModelTester(self)
+ self.config_tester = TextNetConfigTester(self, config_class=TextNetConfig, has_text_modality=False)
+
+ @unittest.skip(reason="TextNet does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
+ @unittest.skip(reason="TextNet does not have input/output embeddings")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip(reason="TextNet does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="TextNet does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_backbone(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_backbone(*config_and_inputs)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=config)
+ for name, module in model.named_modules():
+ if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
+ self.assertTrue(
+ torch.all(module.weight == 1),
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ self.assertTrue(
+ torch.all(module.bias == 0),
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ self.assertEqual(len(hidden_states), self.model_tester.num_stages)
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [self.model_tester.image_size[0] // 2, self.model_tester.image_size[1] // 2],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ layers_type = ["preactivation", "bottleneck"]
+ for model_class in self.all_model_classes:
+ for layer_type in layers_type:
+ config.layer_type = layer_type
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ @unittest.skip(reason="TextNet does not use feedforward chunking")
+ def test_feed_forward_chunking(self):
+ pass
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ model_name = "czczup/textnet-base"
+ model = TextNetModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+@require_torch
+@require_vision
+class TextNetModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference_no_head(self):
+ processor = TextNetImageProcessor.from_pretrained("czczup/textnet-base")
+ model = TextNetModel.from_pretrained("czczup/textnet-base").to(torch_device)
+
+ # prepare image
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+ inputs = processor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ output = model(**inputs)
+
+ # verify logits
+ self.assertEqual(output.logits.shape, torch.Size([1, 2]))
+ expected_slice_backbone = torch.tensor(
+ [0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000],
+ device=torch_device,
+ )
+ self.assertTrue(torch.allclose(output.feature_maps[-1][0][10][12][:10], expected_slice_backbone, atol=1e-3))
+
+
+@require_torch
+# Copied from tests.models.bit.test_modeling_bit.BitBackboneTest with Bit->TextNet
+class TextNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
+ all_model_classes = (TextNetBackbone,) if is_torch_available() else ()
+ config_class = TextNetConfig
+
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = TextNetModelTester(self)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 365213f649..b5792eaea6 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -1020,6 +1020,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
"ResNetBackbone",
"SwinBackbone",
"Swinv2Backbone",
+ "TextNetBackbone",
"TimmBackbone",
"TimmBackboneConfig",
"VitDetBackbone",