Decorators for deprecation and named arguments validation (#30799)
* Fix do_reduce_labels for maskformer image processor * Deprecate reduce_labels in favor to do_reduce_labels * Deprecate reduce_labels in favor to do_reduce_labels (segformer) * Deprecate reduce_labels in favor to do_reduce_labels (oneformer) * Deprecate reduce_labels in favor to do_reduce_labels (maskformer) * Deprecate reduce_labels in favor to do_reduce_labels (mask2former) * Fix typo * Update mask2former test * fixup * Update segmentation examples * Update docs * Fixup * Imports fixup * Add deprecation decorator draft * Add deprecation decorator * Fixup * Add deprecate_kwarg decorator * Validate kwargs decorator * Kwargs validation (beit) * fixup * Kwargs validation (mask2former) * Kwargs validation (maskformer) * Kwargs validation (oneformer) * Kwargs validation (segformer) * Better message * Fix oneformer processor save-load test * Update src/transformers/utils/deprecation.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/deprecation.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/deprecation.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * Update src/transformers/utils/deprecation.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * Better handle classmethod warning * Fix typo, remove warn * Add header * Docs and `additional_message` * Move to filter decorator ot generic * Proper deprecation for semantic segm scripts * Add to __init__ and update import * Basic tests for filter decorator * Fix doc * Override `to_dict()` to pop depracated `_max_size` * Pop unused parameters * Fix trailing whitespace * Add test for deprecation * Add deprecation warning control parameter * Update generic test * Fixup deprecation tests * Introduce init service kwargs * Revert popping unused params * Revert oneformer test * Allow "metadata" to pass * Better docs * Fix test * Add notion in docstring * Fix notification for both names * Add func name to warning message * Fixup --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
4fa4dcb2be
commit
517df566f5
@@ -66,12 +66,12 @@ of the model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). T
|
|||||||
important preprocessing step is that images and segmentation maps are randomly cropped and padded to the same size,
|
important preprocessing step is that images and segmentation maps are randomly cropped and padded to the same size,
|
||||||
such as 512x512 or 640x640, after which they are normalized.
|
such as 512x512 or 640x640, after which they are normalized.
|
||||||
- One additional thing to keep in mind is that one can initialize [`SegformerImageProcessor`] with
|
- One additional thing to keep in mind is that one can initialize [`SegformerImageProcessor`] with
|
||||||
`reduce_labels` set to `True` or `False`. In some datasets (like ADE20k), the 0 index is used in the annotated
|
`do_reduce_labels` set to `True` or `False`. In some datasets (like ADE20k), the 0 index is used in the annotated
|
||||||
segmentation maps for background. However, ADE20k doesn't include the "background" class in its 150 labels.
|
segmentation maps for background. However, ADE20k doesn't include the "background" class in its 150 labels.
|
||||||
Therefore, `reduce_labels` is used to reduce all labels by 1, and to make sure no loss is computed for the
|
Therefore, `do_reduce_labels` is used to reduce all labels by 1, and to make sure no loss is computed for the
|
||||||
background class (i.e. it replaces 0 in the annotated maps by 255, which is the *ignore_index* of the loss function
|
background class (i.e. it replaces 0 in the annotated maps by 255, which is the *ignore_index* of the loss function
|
||||||
used by [`SegformerForSemanticSegmentation`]). However, other datasets use the 0 index as
|
used by [`SegformerForSemanticSegmentation`]). However, other datasets use the 0 index as
|
||||||
background class and include this class as part of all labels. In that case, `reduce_labels` should be set to
|
background class and include this class as part of all labels. In that case, `do_reduce_labels` should be set to
|
||||||
`False`, as loss should also be computed for the background class.
|
`False`, as loss should also be computed for the background class.
|
||||||
- As most models, SegFormer comes in different sizes, the details of which can be found in the table below
|
- As most models, SegFormer comes in different sizes, the details of which can be found in the table below
|
||||||
(taken from Table 7 of the [original paper](https://arxiv.org/abs/2105.15203)).
|
(taken from Table 7 of the [original paper](https://arxiv.org/abs/2105.15203)).
|
||||||
|
|||||||
@@ -310,13 +310,13 @@ As an example, take a look at this [example dataset](https://huggingface.co/data
|
|||||||
|
|
||||||
### Preprocess
|
### Preprocess
|
||||||
|
|
||||||
The next step is to load a SegFormer image processor to prepare the images and annotations for the model. Some datasets, like this one, use the zero-index as the background class. However, the background class isn't actually included in the 150 classes, so you'll need to set `reduce_labels=True` to subtract one from all the labels. The zero-index is replaced by `255` so it's ignored by SegFormer's loss function:
|
The next step is to load a SegFormer image processor to prepare the images and annotations for the model. Some datasets, like this one, use the zero-index as the background class. However, the background class isn't actually included in the 150 classes, so you'll need to set `do_reduce_labels=True` to subtract one from all the labels. The zero-index is replaced by `255` so it's ignored by SegFormer's loss function:
|
||||||
|
|
||||||
```py
|
```py
|
||||||
>>> from transformers import AutoImageProcessor
|
>>> from transformers import AutoImageProcessor
|
||||||
|
|
||||||
>>> checkpoint = "nvidia/mit-b0"
|
>>> checkpoint = "nvidia/mit-b0"
|
||||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)
|
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
<frameworkcontent>
|
<frameworkcontent>
|
||||||
|
|||||||
@@ -96,13 +96,13 @@ pip install -q datasets transformers evaluate
|
|||||||
|
|
||||||
## Preprocess
|
## Preprocess
|
||||||
|
|
||||||
次のステップでは、SegFormer 画像プロセッサをロードして、モデルの画像と注釈を準備します。このデータセットのような一部のデータセットは、バックグラウンド クラスとしてゼロインデックスを使用します。ただし、実際には背景クラスは 150 個のクラスに含まれていないため、`reduce_labels=True`を設定してすべてのラベルから 1 つを引く必要があります。ゼロインデックスは `255` に置き換えられるため、SegFormer の損失関数によって無視されます。
|
次のステップでは、SegFormer 画像プロセッサをロードして、モデルの画像と注釈を準備します。このデータセットのような一部のデータセットは、バックグラウンド クラスとしてゼロインデックスを使用します。ただし、実際には背景クラスは 150 個のクラスに含まれていないため、`do_reduce_labels=True`を設定してすべてのラベルから 1 つを引く必要があります。ゼロインデックスは `255` に置き換えられるため、SegFormer の損失関数によって無視されます。
|
||||||
|
|
||||||
```py
|
```py
|
||||||
>>> from transformers import AutoImageProcessor
|
>>> from transformers import AutoImageProcessor
|
||||||
|
|
||||||
>>> checkpoint = "nvidia/mit-b0"
|
>>> checkpoint = "nvidia/mit-b0"
|
||||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)
|
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
<frameworkcontent>
|
<frameworkcontent>
|
||||||
|
|||||||
@@ -96,13 +96,13 @@ pip install -q datasets transformers evaluate
|
|||||||
|
|
||||||
## Preprocess
|
## Preprocess
|
||||||
|
|
||||||
次のステップでは、SegFormer 画像プロセッサをロードして、モデルの画像と注釈を準備します。このデータセットのような一部のデータセットは、バックグラウンド クラスとしてゼロインデックスを使用します。ただし、実際には背景クラスは 150 個のクラスに含まれていないため、`reduce_labels=True`を設定してすべてのラベルから 1 つを引く必要があります。ゼロインデックスは `255` に置き換えられるため、SegFormer の損失関数によって無視されます。
|
次のステップでは、SegFormer 画像プロセッサをロードして、モデルの画像と注釈を準備します。このデータセットのような一部のデータセットは、バックグラウンド クラスとしてゼロインデックスを使用します。ただし、実際には背景クラスは 150 個のクラスに含まれていないため、`do_reduce_labels=True`を設定してすべてのラベルから 1 つを引く必要があります。ゼロインデックスは `255` に置き換えられるため、SegFormer の損失関数によって無視されます。
|
||||||
|
|
||||||
```py
|
```py
|
||||||
>>> from transformers import AutoImageProcessor
|
>>> from transformers import AutoImageProcessor
|
||||||
|
|
||||||
>>> checkpoint = "nvidia/mit-b0"
|
>>> checkpoint = "nvidia/mit-b0"
|
||||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)
|
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
<frameworkcontent>
|
<frameworkcontent>
|
||||||
|
|||||||
@@ -95,13 +95,13 @@ pip install -q datasets transformers evaluate
|
|||||||
|
|
||||||
## 전처리하기[[preprocess]
|
## 전처리하기[[preprocess]
|
||||||
|
|
||||||
다음 단계는 모델에 사용할 이미지와 주석을 준비하기 위해 SegFormer 이미지 프로세서를 불러오는 것입니다. 우리가 사용하는 데이터 세트와 같은 일부 데이터 세트는 배경 클래스로 제로 인덱스를 사용합니다. 하지만 배경 클래스는 150개의 클래스에 실제로는 포함되지 않기 때문에 `reduce_labels=True` 를 설정해 모든 레이블에서 배경 클래스를 제거해야 합니다. 제로 인덱스는 `255`로 대체되므로 SegFormer의 손실 함수에서 무시됩니다:
|
다음 단계는 모델에 사용할 이미지와 주석을 준비하기 위해 SegFormer 이미지 프로세서를 불러오는 것입니다. 우리가 사용하는 데이터 세트와 같은 일부 데이터 세트는 배경 클래스로 제로 인덱스를 사용합니다. 하지만 배경 클래스는 150개의 클래스에 실제로는 포함되지 않기 때문에 `do_reduce_labels=True` 를 설정해 모든 레이블에서 배경 클래스를 제거해야 합니다. 제로 인덱스는 `255`로 대체되므로 SegFormer의 손실 함수에서 무시됩니다:
|
||||||
|
|
||||||
```py
|
```py
|
||||||
>>> from transformers import AutoImageProcessor
|
>>> from transformers import AutoImageProcessor
|
||||||
|
|
||||||
>>> checkpoint = "nvidia/mit-b0"
|
>>> checkpoint = "nvidia/mit-b0"
|
||||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)
|
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
<frameworkcontent>
|
<frameworkcontent>
|
||||||
|
|||||||
@@ -204,4 +204,4 @@ For visualization of the segmentation maps, we refer to the [example notebook](h
|
|||||||
|
|
||||||
Some datasets, like [`scene_parse_150`](https://huggingface.co/datasets/scene_parse_150), contain a "background" label that is not part of the classes. The Scene Parse 150 dataset for instance contains labels between 0 and 150, with 0 being the background class, and 1 to 150 being actual class names (like "tree", "person", etc.). For these kind of datasets, one replaces the background label (0) by 255, which is the `ignore_index` of the PyTorch model's loss function, and reduces all labels by 1. This way, the `labels` are PyTorch tensors containing values between 0 and 149, and 255 for all background/padding.
|
Some datasets, like [`scene_parse_150`](https://huggingface.co/datasets/scene_parse_150), contain a "background" label that is not part of the classes. The Scene Parse 150 dataset for instance contains labels between 0 and 150, with 0 being the background class, and 1 to 150 being actual class names (like "tree", "person", etc.). For these kind of datasets, one replaces the background label (0) by 255, which is the `ignore_index` of the PyTorch model's loss function, and reduces all labels by 1. This way, the `labels` are PyTorch tensors containing values between 0 and 149, and 255 for all background/padding.
|
||||||
|
|
||||||
In case you're training on such a dataset, make sure to set the ``reduce_labels`` flag, which will take care of this.
|
In case you're training on such a dataset, make sure to set the ``do_reduce_labels`` flag, which will take care of this.
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -108,6 +109,10 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
do_reduce_labels: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to reduce all labels by 1 and replace background by 255."},
|
||||||
|
)
|
||||||
reduce_labels: Optional[bool] = field(
|
reduce_labels: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to reduce all labels by 1 and replace background by 255."},
|
metadata={"help": "Whether or not to reduce all labels by 1 and replace background by 255."},
|
||||||
@@ -118,6 +123,12 @@ class DataTrainingArguments:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You must specify either a dataset name from the hub or a train and/or validation directory."
|
"You must specify either a dataset name from the hub or a train and/or validation directory."
|
||||||
)
|
)
|
||||||
|
if self.reduce_labels:
|
||||||
|
self.do_reduce_labels = self.reduce_labels
|
||||||
|
warnings.warn(
|
||||||
|
"The `reduce_labels` argument is deprecated and will be removed in v4.45. Please use `do_reduce_labels` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -303,14 +314,12 @@ def main():
|
|||||||
)
|
)
|
||||||
image_processor = AutoImageProcessor.from_pretrained(
|
image_processor = AutoImageProcessor.from_pretrained(
|
||||||
model_args.image_processor_name or model_args.model_name_or_path,
|
model_args.image_processor_name or model_args.model_name_or_path,
|
||||||
|
do_reduce_labels=data_args.do_reduce_labels,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
revision=model_args.model_revision,
|
revision=model_args.model_revision,
|
||||||
token=model_args.token,
|
token=model_args.token,
|
||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
)
|
)
|
||||||
# `reduce_labels` is a property of dataset labels, in case we use image_processor
|
|
||||||
# pretrained on another dataset we should override the default setting
|
|
||||||
image_processor.do_reduce_labels = data_args.reduce_labels
|
|
||||||
|
|
||||||
# Define transforms to be applied to each image and target.
|
# Define transforms to be applied to each image and target.
|
||||||
if "shortest_edge" in image_processor.size:
|
if "shortest_edge" in image_processor.size:
|
||||||
@@ -322,7 +331,7 @@ def main():
|
|||||||
[
|
[
|
||||||
A.Lambda(
|
A.Lambda(
|
||||||
name="reduce_labels",
|
name="reduce_labels",
|
||||||
mask=reduce_labels_transform if data_args.reduce_labels else None,
|
mask=reduce_labels_transform if data_args.do_reduce_labels else None,
|
||||||
p=1.0,
|
p=1.0,
|
||||||
),
|
),
|
||||||
# pad image with 255, because it is ignored by loss
|
# pad image with 255, because it is ignored by loss
|
||||||
@@ -337,7 +346,7 @@ def main():
|
|||||||
[
|
[
|
||||||
A.Lambda(
|
A.Lambda(
|
||||||
name="reduce_labels",
|
name="reduce_labels",
|
||||||
mask=reduce_labels_transform if data_args.reduce_labels else None,
|
mask=reduce_labels_transform if data_args.do_reduce_labels else None,
|
||||||
p=1.0,
|
p=1.0,
|
||||||
),
|
),
|
||||||
A.Resize(height=height, width=width, p=1.0),
|
A.Resize(height=height, width=width, p=1.0),
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -85,6 +86,11 @@ def parse_args():
|
|||||||
help="Name of the dataset on the hub.",
|
help="Name of the dataset on the hub.",
|
||||||
default="segments/sidewalk-semantic",
|
default="segments/sidewalk-semantic",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--do_reduce_labels",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether or not to reduce all labels by 1 and replace background by 255.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--reduce_labels",
|
"--reduce_labels",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -219,6 +225,14 @@ def parse_args():
|
|||||||
"Need an `output_dir` to create a repo when `--push_to_hub` or `with_tracking` is specified."
|
"Need an `output_dir` to create a repo when `--push_to_hub` or `with_tracking` is specified."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Deprecation
|
||||||
|
if args.reduce_labels:
|
||||||
|
args.do_reduce_labels = args.reduce_labels
|
||||||
|
warnings.warn(
|
||||||
|
"The `reduce_labels` argument is deprecated and will be removed in v4.45. Please use `do_reduce_labels` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -315,11 +329,11 @@ def main():
|
|||||||
args.model_name_or_path, trust_remote_code=args.trust_remote_code
|
args.model_name_or_path, trust_remote_code=args.trust_remote_code
|
||||||
)
|
)
|
||||||
model = AutoModelForSemanticSegmentation.from_pretrained(
|
model = AutoModelForSemanticSegmentation.from_pretrained(
|
||||||
args.model_name_or_path, config=config, trust_remote_code=args.trust_remote_code
|
args.model_name_or_path,
|
||||||
|
config=config,
|
||||||
|
trust_remote_code=args.trust_remote_code,
|
||||||
|
do_reduce_labels=args.do_reduce_labels,
|
||||||
)
|
)
|
||||||
# `reduce_labels` is a property of dataset labels, in case we use image_processor
|
|
||||||
# pretrained on another dataset we should override the default setting
|
|
||||||
image_processor.do_reduce_labels = args.reduce_labels
|
|
||||||
|
|
||||||
# Define transforms to be applied to each image and target.
|
# Define transforms to be applied to each image and target.
|
||||||
if "shortest_edge" in image_processor.size:
|
if "shortest_edge" in image_processor.size:
|
||||||
@@ -329,7 +343,7 @@ def main():
|
|||||||
height, width = image_processor.size["height"], image_processor.size["width"]
|
height, width = image_processor.size["height"], image_processor.size["width"]
|
||||||
train_transforms = A.Compose(
|
train_transforms = A.Compose(
|
||||||
[
|
[
|
||||||
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
|
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.do_reduce_labels else None, p=1.0),
|
||||||
# pad image with 255, because it is ignored by loss
|
# pad image with 255, because it is ignored by loss
|
||||||
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
|
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
|
||||||
A.RandomCrop(height=height, width=width, p=1.0),
|
A.RandomCrop(height=height, width=width, p=1.0),
|
||||||
@@ -340,7 +354,7 @@ def main():
|
|||||||
)
|
)
|
||||||
val_transforms = A.Compose(
|
val_transforms = A.Compose(
|
||||||
[
|
[
|
||||||
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
|
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.do_reduce_labels else None, p=1.0),
|
||||||
A.Resize(height=height, width=width, p=1.0),
|
A.Resize(height=height, width=width, p=1.0),
|
||||||
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||||
ToTensorV2(),
|
ToTensorV2(),
|
||||||
|
|||||||
@@ -48,6 +48,12 @@ if is_vision_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
INIT_SERVICE_KWARGS = [
|
||||||
|
"processor_class",
|
||||||
|
"image_processor_type",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils
|
# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils
|
||||||
# We override the class string here, but logic is the same.
|
# We override the class string here, but logic is the same.
|
||||||
class BatchFeature(BaseBatchFeature):
|
class BatchFeature(BaseBatchFeature):
|
||||||
|
|||||||
@@ -14,12 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for Beit."""
|
"""Image processor class for Beit."""
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import resize, to_channel_dimension_format
|
from ...image_transforms import resize, to_channel_dimension_format
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
IMAGENET_STANDARD_MEAN,
|
IMAGENET_STANDARD_MEAN,
|
||||||
@@ -32,10 +31,17 @@ from ...image_utils import (
|
|||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
validate_kwargs,
|
|
||||||
validate_preprocess_arguments,
|
validate_preprocess_arguments,
|
||||||
)
|
)
|
||||||
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging
|
from ...utils import (
|
||||||
|
TensorType,
|
||||||
|
filter_out_non_signature_kwargs,
|
||||||
|
is_torch_available,
|
||||||
|
is_torch_tensor,
|
||||||
|
is_vision_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
from ...utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -93,6 +99,8 @@ class BeitImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
model_input_names = ["pixel_values"]
|
model_input_names = ["pixel_values"]
|
||||||
|
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
|
||||||
|
@filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS)
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
do_resize: bool = True,
|
do_resize: bool = True,
|
||||||
@@ -108,13 +116,6 @@ class BeitImageProcessor(BaseImageProcessor):
|
|||||||
do_reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
if "reduce_labels" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use"
|
|
||||||
" `do_reduce_labels` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
do_reduce_labels = kwargs.pop("reduce_labels")
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
size = size if size is not None else {"height": 256, "width": 256}
|
size = size if size is not None else {"height": 256, "width": 256}
|
||||||
size = get_size_dict(size)
|
size = get_size_dict(size)
|
||||||
@@ -131,34 +132,15 @@ class BeitImageProcessor(BaseImageProcessor):
|
|||||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||||
self.do_reduce_labels = do_reduce_labels
|
self.do_reduce_labels = do_reduce_labels
|
||||||
self._valid_processor_keys = [
|
|
||||||
"images",
|
|
||||||
"segmentation_maps",
|
|
||||||
"do_resize",
|
|
||||||
"size",
|
|
||||||
"resample",
|
|
||||||
"do_center_crop",
|
|
||||||
"crop_size",
|
|
||||||
"do_rescale",
|
|
||||||
"rescale_factor",
|
|
||||||
"do_normalize",
|
|
||||||
"image_mean",
|
|
||||||
"image_std",
|
|
||||||
"do_reduce_labels",
|
|
||||||
"return_tensors",
|
|
||||||
"data_format",
|
|
||||||
"input_data_format",
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
||||||
"""
|
"""
|
||||||
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
|
Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs
|
||||||
is created using from_dict and kwargs e.g. `BeitImageProcessor.from_pretrained(checkpoint, reduce_labels=True)`
|
|
||||||
"""
|
"""
|
||||||
image_processor_dict = image_processor_dict.copy()
|
image_processor_dict = image_processor_dict.copy()
|
||||||
if "reduce_labels" in kwargs:
|
if "reduce_labels" in image_processor_dict:
|
||||||
image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
|
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
|
||||||
return super().from_dict(image_processor_dict, **kwargs)
|
return super().from_dict(image_processor_dict, **kwargs)
|
||||||
|
|
||||||
def resize(
|
def resize(
|
||||||
@@ -329,6 +311,8 @@ class BeitImageProcessor(BaseImageProcessor):
|
|||||||
# be passed in as positional arguments.
|
# be passed in as positional arguments.
|
||||||
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
||||||
|
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
images: ImageInput,
|
images: ImageInput,
|
||||||
@@ -347,7 +331,6 @@ class BeitImageProcessor(BaseImageProcessor):
|
|||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
**kwargs,
|
|
||||||
) -> PIL.Image.Image:
|
) -> PIL.Image.Image:
|
||||||
"""
|
"""
|
||||||
Preprocess an image or batch of images.
|
Preprocess an image or batch of images.
|
||||||
@@ -418,8 +401,6 @@ class BeitImageProcessor(BaseImageProcessor):
|
|||||||
image_std = image_std if image_std is not None else self.image_std
|
image_std = image_std if image_std is not None else self.image_std
|
||||||
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
||||||
|
|
||||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
|
||||||
|
|
||||||
images = make_list_of_images(images)
|
images = make_list_of_images(images)
|
||||||
|
|
||||||
if segmentation_maps is not None:
|
if segmentation_maps is not None:
|
||||||
|
|||||||
@@ -15,12 +15,11 @@
|
|||||||
"""Image processor class for Mask2Former."""
|
"""Image processor class for Mask2Former."""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import (
|
from ...image_transforms import (
|
||||||
PaddingMode,
|
PaddingMode,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
@@ -39,17 +38,18 @@ from ...image_utils import (
|
|||||||
is_scaled_image,
|
is_scaled_image,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
validate_kwargs,
|
|
||||||
validate_preprocess_arguments,
|
validate_preprocess_arguments,
|
||||||
)
|
)
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
IMAGENET_DEFAULT_MEAN,
|
IMAGENET_DEFAULT_MEAN,
|
||||||
IMAGENET_DEFAULT_STD,
|
IMAGENET_DEFAULT_STD,
|
||||||
TensorType,
|
TensorType,
|
||||||
|
filter_out_non_signature_kwargs,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
from ...utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -266,12 +266,12 @@ def convert_segmentation_map_to_binary_masks(
|
|||||||
segmentation_map: "np.ndarray",
|
segmentation_map: "np.ndarray",
|
||||||
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
):
|
):
|
||||||
if reduce_labels and ignore_index is None:
|
if do_reduce_labels and ignore_index is None:
|
||||||
raise ValueError("If `reduce_labels` is True, `ignore_index` must be provided.")
|
raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.")
|
||||||
|
|
||||||
if reduce_labels:
|
if do_reduce_labels:
|
||||||
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
|
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
|
||||||
|
|
||||||
# Get unique ids (class or instance ids based on input)
|
# Get unique ids (class or instance ids based on input)
|
||||||
@@ -295,8 +295,8 @@ def convert_segmentation_map_to_binary_masks(
|
|||||||
labels = np.zeros(all_labels.shape[0])
|
labels = np.zeros(all_labels.shape[0])
|
||||||
|
|
||||||
for label in all_labels:
|
for label in all_labels:
|
||||||
class_id = instance_id_to_semantic_id[label + 1 if reduce_labels else label]
|
class_id = instance_id_to_semantic_id[label + 1 if do_reduce_labels else label]
|
||||||
labels[all_labels == label] = class_id - 1 if reduce_labels else class_id
|
labels[all_labels == label] = class_id - 1 if do_reduce_labels else class_id
|
||||||
else:
|
else:
|
||||||
labels = all_labels
|
labels = all_labels
|
||||||
|
|
||||||
@@ -387,15 +387,20 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
ignore_index (`int`, *optional*):
|
ignore_index (`int`, *optional*):
|
||||||
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
|
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
|
||||||
denoted with 0 (background) will be replaced with `ignore_index`.
|
denoted with 0 (background) will be replaced with `ignore_index`.
|
||||||
reduce_labels (`bool`, *optional*, defaults to `False`):
|
do_reduce_labels (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
|
Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
|
||||||
is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
|
is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
|
||||||
The background label will be replaced by `ignore_index`.
|
The background label will be replaced by `ignore_index`.
|
||||||
|
num_labels (`int`, *optional*):
|
||||||
|
The number of labels in the segmentation map.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_input_names = ["pixel_values", "pixel_mask"]
|
model_input_names = ["pixel_values", "pixel_mask"]
|
||||||
|
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0")
|
||||||
|
@deprecate_kwarg("size_divisibility", new_name="size_divisor", version="4.41.0")
|
||||||
|
@deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
|
||||||
|
@filter_out_non_signature_kwargs(extra=["max_size", *INIT_SERVICE_KWARGS])
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
do_resize: bool = True,
|
do_resize: bool = True,
|
||||||
@@ -408,32 +413,19 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
image_mean: Union[float, List[float]] = None,
|
image_mean: Union[float, List[float]] = None,
|
||||||
image_std: Union[float, List[float]] = None,
|
image_std: Union[float, List[float]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
|
num_labels: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if "size_divisibility" in kwargs:
|
super().__init__(**kwargs)
|
||||||
warnings.warn(
|
|
||||||
"The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use "
|
# We make max_size a private attribute so we can pass it as a default value in the preprocess method whilst
|
||||||
"`size_divisor` instead.",
|
# `size` can still be pass in as an int
|
||||||
FutureWarning,
|
self._max_size = kwargs.pop("max_size", 1333)
|
||||||
)
|
|
||||||
size_divisor = kwargs.pop("size_divisibility")
|
|
||||||
if "max_size" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `max_size` argument is deprecated and will be removed in v4.27. Please use size['longest_edge']"
|
|
||||||
" instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
# We make max_size a private attribute so we can pass it as a default value in the preprocess method whilst
|
|
||||||
# `size` can still be pass in as an int
|
|
||||||
self._max_size = kwargs.pop("max_size")
|
|
||||||
else:
|
|
||||||
self._max_size = 1333
|
|
||||||
|
|
||||||
size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
|
size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
|
||||||
size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
|
size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.do_resize = do_resize
|
self.do_resize = do_resize
|
||||||
self.size = size
|
self.size = size
|
||||||
self.resample = resample
|
self.resample = resample
|
||||||
@@ -444,26 +436,8 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
||||||
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
self.reduce_labels = reduce_labels
|
self.do_reduce_labels = do_reduce_labels
|
||||||
self._valid_processor_keys = [
|
self.num_labels = num_labels
|
||||||
"images",
|
|
||||||
"segmentation_maps",
|
|
||||||
"instance_id_to_semantic_id",
|
|
||||||
"do_resize",
|
|
||||||
"size",
|
|
||||||
"size_divisor",
|
|
||||||
"resample",
|
|
||||||
"do_rescale",
|
|
||||||
"rescale_factor",
|
|
||||||
"do_normalize",
|
|
||||||
"image_mean",
|
|
||||||
"image_std",
|
|
||||||
"ignore_index",
|
|
||||||
"reduce_labels",
|
|
||||||
"return_tensors",
|
|
||||||
"data_format",
|
|
||||||
"input_data_format",
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
||||||
@@ -475,9 +449,22 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
if "max_size" in kwargs:
|
if "max_size" in kwargs:
|
||||||
image_processor_dict["max_size"] = kwargs.pop("max_size")
|
image_processor_dict["max_size"] = kwargs.pop("max_size")
|
||||||
if "size_divisibility" in kwargs:
|
if "size_divisibility" in kwargs:
|
||||||
image_processor_dict["size_divisibility"] = kwargs.pop("size_divisibility")
|
image_processor_dict["size_divisor"] = kwargs.pop("size_divisibility")
|
||||||
|
if "reduce_labels" in image_processor_dict:
|
||||||
|
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
|
||||||
return super().from_dict(image_processor_dict, **kwargs)
|
return super().from_dict(image_processor_dict, **kwargs)
|
||||||
|
|
||||||
|
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.to_dict
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Serializes this instance to a Python dictionary. This method calls the superclass method and then removes the
|
||||||
|
`_max_size` attribute from the dictionary.
|
||||||
|
"""
|
||||||
|
image_processor_dict = super().to_dict()
|
||||||
|
image_processor_dict.pop("_max_size", None)
|
||||||
|
return image_processor_dict
|
||||||
|
|
||||||
|
@deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
|
||||||
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.resize with get_maskformer_resize_output_image_size->get_mask2former_resize_output_image_size
|
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.resize with get_maskformer_resize_output_image_size->get_mask2former_resize_output_image_size
|
||||||
def resize(
|
def resize(
|
||||||
self,
|
self,
|
||||||
@@ -508,15 +495,10 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||||
"""
|
"""
|
||||||
if "max_size" in kwargs:
|
|
||||||
warnings.warn(
|
# Deprecated, backward compatibility
|
||||||
"The `max_size` parameter is deprecated and will be removed in v4.27. "
|
max_size = kwargs.pop("max_size", None)
|
||||||
"Please specify in `size['longest_edge'] instead`.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
max_size = kwargs.pop("max_size")
|
|
||||||
else:
|
|
||||||
max_size = None
|
|
||||||
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
||||||
if "shortest_edge" in size and "longest_edge" in size:
|
if "shortest_edge" in size and "longest_edge" in size:
|
||||||
size, max_size = size["shortest_edge"], size["longest_edge"]
|
size, max_size = size["shortest_edge"], size["longest_edge"]
|
||||||
@@ -576,15 +558,15 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_map: "np.ndarray",
|
segmentation_map: "np.ndarray",
|
||||||
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
):
|
):
|
||||||
reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels
|
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
||||||
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
||||||
return convert_segmentation_map_to_binary_masks(
|
return convert_segmentation_map_to_binary_masks(
|
||||||
segmentation_map=segmentation_map,
|
segmentation_map=segmentation_map,
|
||||||
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
||||||
ignore_index=ignore_index,
|
ignore_index=ignore_index,
|
||||||
reduce_labels=reduce_labels,
|
do_reduce_labels=do_reduce_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, images, segmentation_maps=None, **kwargs) -> BatchFeature:
|
def __call__(self, images, segmentation_maps=None, **kwargs) -> BatchFeature:
|
||||||
@@ -693,6 +675,8 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_map = segmentation_map.squeeze(0)
|
segmentation_map = segmentation_map.squeeze(0)
|
||||||
return segmentation_map
|
return segmentation_map
|
||||||
|
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0")
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
images: ImageInput,
|
images: ImageInput,
|
||||||
@@ -708,18 +692,11 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
image_mean: Optional[Union[float, List[float]]] = None,
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
image_std: Optional[Union[float, List[float]]] = None,
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: Optional[bool] = None,
|
do_reduce_labels: Optional[bool] = None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
**kwargs,
|
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
if "pad_and_return_pixel_mask" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||||
size = size if size is not None else self.size
|
size = size if size is not None else self.size
|
||||||
size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
|
size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
|
||||||
@@ -731,9 +708,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||||
image_std = image_std if image_std is not None else self.image_std
|
image_std = image_std if image_std is not None else self.image_std
|
||||||
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
||||||
reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels
|
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
||||||
|
|
||||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
|
||||||
|
|
||||||
if not valid_images(images):
|
if not valid_images(images):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -795,7 +770,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_maps,
|
segmentation_maps,
|
||||||
instance_id_to_semantic_id,
|
instance_id_to_semantic_id,
|
||||||
ignore_index,
|
ignore_index,
|
||||||
reduce_labels,
|
do_reduce_labels,
|
||||||
return_tensors,
|
return_tensors,
|
||||||
input_data_format=input_data_format,
|
input_data_format=input_data_format,
|
||||||
)
|
)
|
||||||
@@ -891,7 +866,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_maps: ImageInput = None,
|
segmentation_maps: ImageInput = None,
|
||||||
instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
|
instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
):
|
):
|
||||||
@@ -946,7 +921,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
`mask_labels[i][j]` if `class_labels[i][j]`.
|
`mask_labels[i][j]` if `class_labels[i][j]`.
|
||||||
"""
|
"""
|
||||||
ignore_index = self.ignore_index if ignore_index is None else ignore_index
|
ignore_index = self.ignore_index if ignore_index is None else ignore_index
|
||||||
reduce_labels = self.reduce_labels if reduce_labels is None else reduce_labels
|
do_reduce_labels = self.do_reduce_labels if do_reduce_labels is None else do_reduce_labels
|
||||||
|
|
||||||
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
||||||
|
|
||||||
@@ -970,7 +945,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
instance_id = instance_id_to_semantic_id
|
instance_id = instance_id_to_semantic_id
|
||||||
# Use instance2class_id mapping per image
|
# Use instance2class_id mapping per image
|
||||||
masks, classes = self.convert_segmentation_map_to_binary_masks(
|
masks, classes = self.convert_segmentation_map_to_binary_masks(
|
||||||
segmentation_map, instance_id, ignore_index=ignore_index, reduce_labels=reduce_labels
|
segmentation_map, instance_id, ignore_index=ignore_index, do_reduce_labels=do_reduce_labels
|
||||||
)
|
)
|
||||||
# We add an axis to make them compatible with the transformations library
|
# We add an axis to make them compatible with the transformations library
|
||||||
# this will be removed in the future
|
# this will be removed in the future
|
||||||
|
|||||||
@@ -295,8 +295,8 @@ def convert_maskformer_checkpoint(
|
|||||||
ignore_index = 65535
|
ignore_index = 65535
|
||||||
else:
|
else:
|
||||||
ignore_index = 255
|
ignore_index = 255
|
||||||
reduce_labels = True if "ade" in model_name else False
|
do_reduce_labels = True if "ade" in model_name else False
|
||||||
image_processor = MaskFormerImageProcessor(ignore_index=ignore_index, reduce_labels=reduce_labels)
|
image_processor = MaskFormerImageProcessor(ignore_index=ignore_index, do_reduce_labels=do_reduce_labels)
|
||||||
|
|
||||||
inputs = image_processor(image, return_tensors="pt")
|
inputs = image_processor(image, return_tensors="pt")
|
||||||
|
|
||||||
|
|||||||
@@ -276,8 +276,8 @@ def convert_maskformer_checkpoint(
|
|||||||
ignore_index = 65535
|
ignore_index = 65535
|
||||||
else:
|
else:
|
||||||
ignore_index = 255
|
ignore_index = 255
|
||||||
reduce_labels = True if "ade" in model_name else False
|
do_reduce_labels = True if "ade" in model_name else False
|
||||||
image_processor = MaskFormerImageProcessor(ignore_index=ignore_index, reduce_labels=reduce_labels)
|
image_processor = MaskFormerImageProcessor(ignore_index=ignore_index, do_reduce_labels=do_reduce_labels)
|
||||||
|
|
||||||
inputs = image_processor(image, return_tensors="pt")
|
inputs = image_processor(image, return_tensors="pt")
|
||||||
|
|
||||||
|
|||||||
@@ -15,12 +15,11 @@
|
|||||||
"""Image processor class for MaskFormer."""
|
"""Image processor class for MaskFormer."""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import (
|
from ...image_transforms import (
|
||||||
PaddingMode,
|
PaddingMode,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
@@ -39,17 +38,18 @@ from ...image_utils import (
|
|||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
validate_kwargs,
|
|
||||||
validate_preprocess_arguments,
|
validate_preprocess_arguments,
|
||||||
)
|
)
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
IMAGENET_DEFAULT_MEAN,
|
IMAGENET_DEFAULT_MEAN,
|
||||||
IMAGENET_DEFAULT_STD,
|
IMAGENET_DEFAULT_STD,
|
||||||
TensorType,
|
TensorType,
|
||||||
|
filter_out_non_signature_kwargs,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
from ...utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -269,12 +269,12 @@ def convert_segmentation_map_to_binary_masks(
|
|||||||
segmentation_map: "np.ndarray",
|
segmentation_map: "np.ndarray",
|
||||||
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
):
|
):
|
||||||
if reduce_labels and ignore_index is None:
|
if do_reduce_labels and ignore_index is None:
|
||||||
raise ValueError("If `reduce_labels` is True, `ignore_index` must be provided.")
|
raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.")
|
||||||
|
|
||||||
if reduce_labels:
|
if do_reduce_labels:
|
||||||
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
|
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
|
||||||
|
|
||||||
# Get unique ids (class or instance ids based on input)
|
# Get unique ids (class or instance ids based on input)
|
||||||
@@ -298,8 +298,8 @@ def convert_segmentation_map_to_binary_masks(
|
|||||||
labels = np.zeros(all_labels.shape[0])
|
labels = np.zeros(all_labels.shape[0])
|
||||||
|
|
||||||
for label in all_labels:
|
for label in all_labels:
|
||||||
class_id = instance_id_to_semantic_id[label + 1 if reduce_labels else label]
|
class_id = instance_id_to_semantic_id[label + 1 if do_reduce_labels else label]
|
||||||
labels[all_labels == label] = class_id - 1 if reduce_labels else class_id
|
labels[all_labels == label] = class_id - 1 if do_reduce_labels else class_id
|
||||||
else:
|
else:
|
||||||
labels = all_labels
|
labels = all_labels
|
||||||
|
|
||||||
@@ -393,11 +393,17 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
|
Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
|
||||||
is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
|
is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
|
||||||
The background label will be replaced by `ignore_index`.
|
The background label will be replaced by `ignore_index`.
|
||||||
|
num_labels (`int`, *optional*):
|
||||||
|
The number of labels in the segmentation map.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_input_names = ["pixel_values", "pixel_mask"]
|
model_input_names = ["pixel_values", "pixel_mask"]
|
||||||
|
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0")
|
||||||
|
@deprecate_kwarg("size_divisibility", new_name="size_divisor", version="4.41.0")
|
||||||
|
@deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
|
||||||
|
@filter_out_non_signature_kwargs(extra=["max_size", *INIT_SERVICE_KWARGS])
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
do_resize: bool = True,
|
do_resize: bool = True,
|
||||||
@@ -411,38 +417,18 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
image_std: Union[float, List[float]] = None,
|
image_std: Union[float, List[float]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
do_reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
|
num_labels: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if "size_divisibility" in kwargs:
|
super().__init__(**kwargs)
|
||||||
warnings.warn(
|
|
||||||
"The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use "
|
# We make max_size a private attribute so we can pass it as a default value in the preprocess method whilst
|
||||||
"`size_divisor` instead.",
|
# `size` can still be pass in as an int
|
||||||
FutureWarning,
|
self._max_size = kwargs.pop("max_size", 1333)
|
||||||
)
|
|
||||||
size_divisor = kwargs.pop("size_divisibility")
|
|
||||||
if "max_size" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `max_size` argument is deprecated and will be removed in v4.27. Please use size['longest_edge']"
|
|
||||||
" instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
# We make max_size a private attribute so we can pass it as a default value in the preprocess method whilst
|
|
||||||
# `size` can still be pass in as an int
|
|
||||||
self._max_size = kwargs.pop("max_size")
|
|
||||||
else:
|
|
||||||
self._max_size = 1333
|
|
||||||
if "reduce_labels" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `reduce_labels` argument is deprecated and will be removed in v4.27. Please use "
|
|
||||||
"`do_reduce_labels` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
do_reduce_labels = kwargs.pop("reduce_labels")
|
|
||||||
|
|
||||||
size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
|
size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
|
||||||
size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
|
size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.do_resize = do_resize
|
self.do_resize = do_resize
|
||||||
self.size = size
|
self.size = size
|
||||||
self.resample = resample
|
self.resample = resample
|
||||||
@@ -454,25 +440,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
self.do_reduce_labels = do_reduce_labels
|
self.do_reduce_labels = do_reduce_labels
|
||||||
self._valid_processor_keys = [
|
self.num_labels = num_labels
|
||||||
"images",
|
|
||||||
"segmentation_maps",
|
|
||||||
"instance_id_to_semantic_id",
|
|
||||||
"do_resize",
|
|
||||||
"size",
|
|
||||||
"size_divisor",
|
|
||||||
"resample",
|
|
||||||
"do_rescale",
|
|
||||||
"rescale_factor",
|
|
||||||
"do_normalize",
|
|
||||||
"image_mean",
|
|
||||||
"image_std",
|
|
||||||
"ignore_index",
|
|
||||||
"do_reduce_labels",
|
|
||||||
"return_tensors",
|
|
||||||
"data_format",
|
|
||||||
"input_data_format",
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
||||||
@@ -484,9 +452,21 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
if "max_size" in kwargs:
|
if "max_size" in kwargs:
|
||||||
image_processor_dict["max_size"] = kwargs.pop("max_size")
|
image_processor_dict["max_size"] = kwargs.pop("max_size")
|
||||||
if "size_divisibility" in kwargs:
|
if "size_divisibility" in kwargs:
|
||||||
image_processor_dict["size_divisibility"] = kwargs.pop("size_divisibility")
|
image_processor_dict["size_divisor"] = kwargs.pop("size_divisibility")
|
||||||
|
if "reduce_labels" in image_processor_dict:
|
||||||
|
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
|
||||||
return super().from_dict(image_processor_dict, **kwargs)
|
return super().from_dict(image_processor_dict, **kwargs)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Serializes this instance to a Python dictionary. This method calls the superclass method and then removes the
|
||||||
|
`_max_size` attribute from the dictionary.
|
||||||
|
"""
|
||||||
|
image_processor_dict = super().to_dict()
|
||||||
|
image_processor_dict.pop("_max_size", None)
|
||||||
|
return image_processor_dict
|
||||||
|
|
||||||
|
@deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
|
||||||
def resize(
|
def resize(
|
||||||
self,
|
self,
|
||||||
image: np.ndarray,
|
image: np.ndarray,
|
||||||
@@ -516,15 +496,10 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||||
"""
|
"""
|
||||||
if "max_size" in kwargs:
|
|
||||||
warnings.warn(
|
# Deprecated, backward compatibility
|
||||||
"The `max_size` parameter is deprecated and will be removed in v4.27. "
|
max_size = kwargs.pop("max_size", None)
|
||||||
"Please specify in `size['longest_edge'] instead`.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
max_size = kwargs.pop("max_size")
|
|
||||||
else:
|
|
||||||
max_size = None
|
|
||||||
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
||||||
if "shortest_edge" in size and "longest_edge" in size:
|
if "shortest_edge" in size and "longest_edge" in size:
|
||||||
size, max_size = size["shortest_edge"], size["longest_edge"]
|
size, max_size = size["shortest_edge"], size["longest_edge"]
|
||||||
@@ -583,15 +558,15 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_map: "np.ndarray",
|
segmentation_map: "np.ndarray",
|
||||||
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
):
|
):
|
||||||
reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels
|
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
||||||
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
||||||
return convert_segmentation_map_to_binary_masks(
|
return convert_segmentation_map_to_binary_masks(
|
||||||
segmentation_map=segmentation_map,
|
segmentation_map=segmentation_map,
|
||||||
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
||||||
ignore_index=ignore_index,
|
ignore_index=ignore_index,
|
||||||
reduce_labels=reduce_labels,
|
do_reduce_labels=do_reduce_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, images, segmentation_maps=None, **kwargs) -> BatchFeature:
|
def __call__(self, images, segmentation_maps=None, **kwargs) -> BatchFeature:
|
||||||
@@ -700,6 +675,8 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_map = segmentation_map.squeeze(0)
|
segmentation_map = segmentation_map.squeeze(0)
|
||||||
return segmentation_map
|
return segmentation_map
|
||||||
|
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0")
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
images: ImageInput,
|
images: ImageInput,
|
||||||
@@ -719,24 +696,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
**kwargs,
|
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
if "pad_and_return_pixel_mask" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `pad_and_return_pixel_mask` argument is deprecated and will be removed in v4.27",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
if "reduce_labels" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `reduce_labels` argument is deprecated and will be removed in v4.27. Please use"
|
|
||||||
" `do_reduce_labels` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
if do_reduce_labels is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot use both `reduce_labels` and `do_reduce_labels`. Please use `do_reduce_labels` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||||
size = size if size is not None else self.size
|
size = size if size is not None else self.size
|
||||||
size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
|
size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
|
||||||
@@ -755,7 +715,6 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
)
|
)
|
||||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
|
||||||
|
|
||||||
validate_preprocess_arguments(
|
validate_preprocess_arguments(
|
||||||
do_rescale=do_rescale,
|
do_rescale=do_rescale,
|
||||||
@@ -907,7 +866,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_maps: ImageInput = None,
|
segmentation_maps: ImageInput = None,
|
||||||
instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
|
instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
):
|
):
|
||||||
@@ -959,7 +918,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
`mask_labels[i][j]` if `class_labels[i][j]`.
|
`mask_labels[i][j]` if `class_labels[i][j]`.
|
||||||
"""
|
"""
|
||||||
ignore_index = self.ignore_index if ignore_index is None else ignore_index
|
ignore_index = self.ignore_index if ignore_index is None else ignore_index
|
||||||
reduce_labels = self.do_reduce_labels if reduce_labels is None else reduce_labels
|
do_reduce_labels = self.do_reduce_labels if do_reduce_labels is None else do_reduce_labels
|
||||||
|
|
||||||
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
||||||
|
|
||||||
@@ -983,7 +942,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
instance_id = instance_id_to_semantic_id
|
instance_id = instance_id_to_semantic_id
|
||||||
# Use instance2class_id mapping per image
|
# Use instance2class_id mapping per image
|
||||||
masks, classes = self.convert_segmentation_map_to_binary_masks(
|
masks, classes = self.convert_segmentation_map_to_binary_masks(
|
||||||
segmentation_map, instance_id, ignore_index=ignore_index, reduce_labels=reduce_labels
|
segmentation_map, instance_id, ignore_index=ignore_index, do_reduce_labels=do_reduce_labels
|
||||||
)
|
)
|
||||||
# We add an axis to make them compatible with the transformations library
|
# We add an axis to make them compatible with the transformations library
|
||||||
# this will be removed in the future
|
# this will be removed in the future
|
||||||
|
|||||||
@@ -16,14 +16,13 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import (
|
from ...image_transforms import (
|
||||||
PaddingMode,
|
PaddingMode,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
@@ -42,17 +41,18 @@ from ...image_utils import (
|
|||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
validate_kwargs,
|
|
||||||
validate_preprocess_arguments,
|
validate_preprocess_arguments,
|
||||||
)
|
)
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
IMAGENET_DEFAULT_MEAN,
|
IMAGENET_DEFAULT_MEAN,
|
||||||
IMAGENET_DEFAULT_STD,
|
IMAGENET_DEFAULT_STD,
|
||||||
TensorType,
|
TensorType,
|
||||||
|
filter_out_non_signature_kwargs,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
from ...utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -268,12 +268,12 @@ def convert_segmentation_map_to_binary_masks(
|
|||||||
segmentation_map: "np.ndarray",
|
segmentation_map: "np.ndarray",
|
||||||
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
):
|
):
|
||||||
if reduce_labels and ignore_index is None:
|
if do_reduce_labels and ignore_index is None:
|
||||||
raise ValueError("If `reduce_labels` is True, `ignore_index` must be provided.")
|
raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.")
|
||||||
|
|
||||||
if reduce_labels:
|
if do_reduce_labels:
|
||||||
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
|
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
|
||||||
|
|
||||||
# Get unique ids (class or instance ids based on input)
|
# Get unique ids (class or instance ids based on input)
|
||||||
@@ -297,8 +297,8 @@ def convert_segmentation_map_to_binary_masks(
|
|||||||
labels = np.zeros(all_labels.shape[0])
|
labels = np.zeros(all_labels.shape[0])
|
||||||
|
|
||||||
for label in all_labels:
|
for label in all_labels:
|
||||||
class_id = instance_id_to_semantic_id[label + 1 if reduce_labels else label]
|
class_id = instance_id_to_semantic_id[label + 1 if do_reduce_labels else label]
|
||||||
labels[all_labels == label] = class_id - 1 if reduce_labels else class_id
|
labels[all_labels == label] = class_id - 1 if do_reduce_labels else class_id
|
||||||
else:
|
else:
|
||||||
labels = all_labels
|
labels = all_labels
|
||||||
|
|
||||||
@@ -418,10 +418,15 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
JSON file containing class information for the dataset. See `shi-labs/oneformer_demo/cityscapes_panoptic.json` for an example.
|
JSON file containing class information for the dataset. See `shi-labs/oneformer_demo/cityscapes_panoptic.json` for an example.
|
||||||
num_text (`int`, *optional*):
|
num_text (`int`, *optional*):
|
||||||
Number of text entries in the text input list.
|
Number of text entries in the text input list.
|
||||||
|
num_labels (`int`, *optional*):
|
||||||
|
The number of labels in the segmentation map.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_input_names = ["pixel_values", "pixel_mask", "task_inputs"]
|
model_input_names = ["pixel_values", "pixel_mask", "task_inputs"]
|
||||||
|
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0")
|
||||||
|
@deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
|
||||||
|
@filter_out_non_signature_kwargs(extra=["max_size", "metadata", *INIT_SERVICE_KWARGS])
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
do_resize: bool = True,
|
do_resize: bool = True,
|
||||||
@@ -437,28 +442,20 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
repo_path: Optional[str] = "shi-labs/oneformer_demo",
|
repo_path: Optional[str] = "shi-labs/oneformer_demo",
|
||||||
class_info_file: str = None,
|
class_info_file: str = None,
|
||||||
num_text: Optional[int] = None,
|
num_text: Optional[int] = None,
|
||||||
|
num_labels: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if "max_size" in kwargs:
|
super().__init__(**kwargs)
|
||||||
self._max_size = kwargs.pop("max_size")
|
|
||||||
else:
|
# Deprecated, backward compatibility
|
||||||
self._max_size = 1333
|
self._max_size = kwargs.pop("max_size", 1333)
|
||||||
|
|
||||||
size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
|
size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
|
||||||
size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
|
size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
|
||||||
|
|
||||||
if "reduce_labels" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `reduce_labels` argument is deprecated and will be removed in v4.27. "
|
|
||||||
"Please use `do_reduce_labels` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
do_reduce_labels = kwargs.pop("reduce_labels")
|
|
||||||
|
|
||||||
if class_info_file is None:
|
if class_info_file is None:
|
||||||
raise ValueError("You must provide a `class_info_file`")
|
raise ValueError("You must provide a `class_info_file`")
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.do_resize = do_resize
|
self.do_resize = do_resize
|
||||||
self.size = size
|
self.size = size
|
||||||
self.resample = resample
|
self.resample = resample
|
||||||
@@ -473,26 +470,30 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
self.repo_path = repo_path
|
self.repo_path = repo_path
|
||||||
self.metadata = prepare_metadata(load_metadata(repo_path, class_info_file))
|
self.metadata = prepare_metadata(load_metadata(repo_path, class_info_file))
|
||||||
self.num_text = num_text
|
self.num_text = num_text
|
||||||
self._valid_processor_keys = [
|
self.num_labels = num_labels
|
||||||
"images",
|
|
||||||
"task_inputs",
|
|
||||||
"segmentation_maps",
|
|
||||||
"instance_id_to_semantic_id",
|
|
||||||
"do_resize",
|
|
||||||
"size",
|
|
||||||
"resample",
|
|
||||||
"do_rescale",
|
|
||||||
"rescale_factor",
|
|
||||||
"do_normalize",
|
|
||||||
"image_mean",
|
|
||||||
"image_std",
|
|
||||||
"ignore_index",
|
|
||||||
"do_reduce_labels",
|
|
||||||
"return_tensors",
|
|
||||||
"data_format",
|
|
||||||
"input_data_format",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
||||||
|
"""
|
||||||
|
Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs
|
||||||
|
"""
|
||||||
|
image_processor_dict = image_processor_dict.copy()
|
||||||
|
if "reduce_labels" in image_processor_dict:
|
||||||
|
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
|
||||||
|
return super().from_dict(image_processor_dict, **kwargs)
|
||||||
|
|
||||||
|
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.to_dict
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Serializes this instance to a Python dictionary. This method calls the superclass method and then removes the
|
||||||
|
`_max_size` attribute from the dictionary.
|
||||||
|
"""
|
||||||
|
image_processor_dict = super().to_dict()
|
||||||
|
image_processor_dict.pop("_max_size", None)
|
||||||
|
return image_processor_dict
|
||||||
|
|
||||||
|
@deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
|
||||||
|
@filter_out_non_signature_kwargs(extra=["max_size"])
|
||||||
def resize(
|
def resize(
|
||||||
self,
|
self,
|
||||||
image: np.ndarray,
|
image: np.ndarray,
|
||||||
@@ -506,15 +507,10 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
|
Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
|
||||||
int, smaller edge of the image will be matched to this number.
|
int, smaller edge of the image will be matched to this number.
|
||||||
"""
|
"""
|
||||||
if "max_size" in kwargs:
|
|
||||||
warnings.warn(
|
# Deprecated, backward compatibility
|
||||||
"The `max_size` parameter is deprecated and will be removed in v4.27. "
|
max_size = kwargs.pop("max_size", None)
|
||||||
"Please specify in `size['longest_edge'] instead`.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
max_size = kwargs.pop("max_size")
|
|
||||||
else:
|
|
||||||
max_size = None
|
|
||||||
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
||||||
if "shortest_edge" in size and "longest_edge" in size:
|
if "shortest_edge" in size and "longest_edge" in size:
|
||||||
size, max_size = size["shortest_edge"], size["longest_edge"]
|
size, max_size = size["shortest_edge"], size["longest_edge"]
|
||||||
@@ -569,15 +565,15 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_map: "np.ndarray",
|
segmentation_map: "np.ndarray",
|
||||||
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
):
|
):
|
||||||
reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels
|
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
||||||
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
||||||
return convert_segmentation_map_to_binary_masks(
|
return convert_segmentation_map_to_binary_masks(
|
||||||
segmentation_map=segmentation_map,
|
segmentation_map=segmentation_map,
|
||||||
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
||||||
ignore_index=ignore_index,
|
ignore_index=ignore_index,
|
||||||
reduce_labels=reduce_labels,
|
do_reduce_labels=do_reduce_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, images, task_inputs=None, segmentation_maps=None, **kwargs) -> BatchFeature:
|
def __call__(self, images, task_inputs=None, segmentation_maps=None, **kwargs) -> BatchFeature:
|
||||||
@@ -679,6 +675,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_map = segmentation_map.squeeze(0)
|
segmentation_map = segmentation_map.squeeze(0)
|
||||||
return segmentation_map
|
return segmentation_map
|
||||||
|
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
images: ImageInput,
|
images: ImageInput,
|
||||||
@@ -698,26 +695,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
**kwargs,
|
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
if "pad_and_return_pixel_mask" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `pad_and_return_pixel_mask` argument is deprecated and will be removed in v4.27",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
if "reduce_labels" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `reduce_labels` argument is deprecated and will be removed in a v4.27. Please use"
|
|
||||||
" `do_reduce_labels` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
if do_reduce_labels is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot use both `reduce_labels` and `do_reduce_labels` arguments. Please use"
|
|
||||||
" `do_reduce_labels` instead."
|
|
||||||
)
|
|
||||||
do_reduce_labels = kwargs.pop("reduce_labels")
|
|
||||||
|
|
||||||
if task_inputs is None:
|
if task_inputs is None:
|
||||||
# Default value
|
# Default value
|
||||||
task_inputs = ["panoptic"]
|
task_inputs = ["panoptic"]
|
||||||
@@ -740,8 +718,6 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
)
|
)
|
||||||
|
|
||||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
|
||||||
|
|
||||||
validate_preprocess_arguments(
|
validate_preprocess_arguments(
|
||||||
do_rescale=do_rescale,
|
do_rescale=do_rescale,
|
||||||
rescale_factor=rescale_factor,
|
rescale_factor=rescale_factor,
|
||||||
@@ -988,7 +964,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_maps: ImageInput = None,
|
segmentation_maps: ImageInput = None,
|
||||||
instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
|
instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
|
||||||
ignore_index: Optional[int] = None,
|
ignore_index: Optional[int] = None,
|
||||||
reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
):
|
):
|
||||||
@@ -1049,7 +1025,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
provided). They identify the binary masks present in the image.
|
provided). They identify the binary masks present in the image.
|
||||||
"""
|
"""
|
||||||
ignore_index = self.ignore_index if ignore_index is None else ignore_index
|
ignore_index = self.ignore_index if ignore_index is None else ignore_index
|
||||||
reduce_labels = self.do_reduce_labels if reduce_labels is None else reduce_labels
|
do_reduce_labels = self.do_reduce_labels if do_reduce_labels is None else do_reduce_labels
|
||||||
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
||||||
|
|
||||||
if input_data_format is None:
|
if input_data_format is None:
|
||||||
@@ -1072,7 +1048,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
|||||||
instance_id = instance_id_to_semantic_id
|
instance_id = instance_id_to_semantic_id
|
||||||
# Use instance2class_id mapping per image
|
# Use instance2class_id mapping per image
|
||||||
masks, classes = self.convert_segmentation_map_to_binary_masks(
|
masks, classes = self.convert_segmentation_map_to_binary_masks(
|
||||||
segmentation_map, instance_id, ignore_index=ignore_index, reduce_labels=reduce_labels
|
segmentation_map, instance_id, ignore_index=ignore_index, do_reduce_labels=do_reduce_labels
|
||||||
)
|
)
|
||||||
annotations.append({"masks": masks, "classes": classes})
|
annotations.append({"masks": masks, "classes": classes})
|
||||||
|
|
||||||
|
|||||||
@@ -14,12 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for Segformer."""
|
"""Image processor class for Segformer."""
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import resize, to_channel_dimension_format
|
from ...image_transforms import resize, to_channel_dimension_format
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
IMAGENET_DEFAULT_MEAN,
|
IMAGENET_DEFAULT_MEAN,
|
||||||
@@ -32,10 +31,17 @@ from ...image_utils import (
|
|||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
validate_kwargs,
|
|
||||||
validate_preprocess_arguments,
|
validate_preprocess_arguments,
|
||||||
)
|
)
|
||||||
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging
|
from ...utils import (
|
||||||
|
TensorType,
|
||||||
|
filter_out_non_signature_kwargs,
|
||||||
|
is_torch_available,
|
||||||
|
is_torch_tensor,
|
||||||
|
is_vision_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
from ...utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -86,6 +92,8 @@ class SegformerImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
model_input_names = ["pixel_values"]
|
model_input_names = ["pixel_values"]
|
||||||
|
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
|
||||||
|
@filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS)
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
do_resize: bool = True,
|
do_resize: bool = True,
|
||||||
@@ -99,14 +107,6 @@ class SegformerImageProcessor(BaseImageProcessor):
|
|||||||
do_reduce_labels: bool = False,
|
do_reduce_labels: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
if "reduce_labels" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
|
|
||||||
"`do_reduce_labels` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
do_reduce_labels = kwargs.pop("reduce_labels")
|
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
size = size if size is not None else {"height": 512, "width": 512}
|
size = size if size is not None else {"height": 512, "width": 512}
|
||||||
size = get_size_dict(size)
|
size = get_size_dict(size)
|
||||||
@@ -119,33 +119,15 @@ class SegformerImageProcessor(BaseImageProcessor):
|
|||||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
||||||
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
||||||
self.do_reduce_labels = do_reduce_labels
|
self.do_reduce_labels = do_reduce_labels
|
||||||
self._valid_processor_keys = [
|
|
||||||
"images",
|
|
||||||
"segmentation_maps",
|
|
||||||
"do_resize",
|
|
||||||
"size",
|
|
||||||
"resample",
|
|
||||||
"do_rescale",
|
|
||||||
"rescale_factor",
|
|
||||||
"do_normalize",
|
|
||||||
"image_mean",
|
|
||||||
"image_std",
|
|
||||||
"do_reduce_labels",
|
|
||||||
"return_tensors",
|
|
||||||
"data_format",
|
|
||||||
"input_data_format",
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
||||||
"""
|
"""
|
||||||
Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image
|
Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs
|
||||||
processor is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
|
|
||||||
reduce_labels=True)`
|
|
||||||
"""
|
"""
|
||||||
image_processor_dict = image_processor_dict.copy()
|
image_processor_dict = image_processor_dict.copy()
|
||||||
if "reduce_labels" in kwargs:
|
if "reduce_labels" in image_processor_dict:
|
||||||
image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
|
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
|
||||||
return super().from_dict(image_processor_dict, **kwargs)
|
return super().from_dict(image_processor_dict, **kwargs)
|
||||||
|
|
||||||
# Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
|
# Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
|
||||||
@@ -320,6 +302,8 @@ class SegformerImageProcessor(BaseImageProcessor):
|
|||||||
"""
|
"""
|
||||||
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
||||||
|
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
images: ImageInput,
|
images: ImageInput,
|
||||||
@@ -336,7 +320,6 @@ class SegformerImageProcessor(BaseImageProcessor):
|
|||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
**kwargs,
|
|
||||||
) -> PIL.Image.Image:
|
) -> PIL.Image.Image:
|
||||||
"""
|
"""
|
||||||
Preprocess an image or batch of images.
|
Preprocess an image or batch of images.
|
||||||
@@ -398,8 +381,6 @@ class SegformerImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
images = make_list_of_images(images)
|
images = make_list_of_images(images)
|
||||||
|
|
||||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
|
||||||
|
|
||||||
if segmentation_maps is not None:
|
if segmentation_maps is not None:
|
||||||
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
||||||
|
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ class ViltImageProcessor(BaseImageProcessor):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
||||||
"""
|
"""
|
||||||
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
|
Overrides the `from_dict` method from the base class to make sure `pad_and_return_pixel_mask` is updated if image processor
|
||||||
is created using from_dict and kwargs e.g. `ViltImageProcessor.from_pretrained(checkpoint,
|
is created using from_dict and kwargs e.g. `ViltImageProcessor.from_pretrained(checkpoint,
|
||||||
pad_and_return_pixel_mask=False)`
|
pad_and_return_pixel_mask=False)`
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from .generic import (
|
|||||||
cached_property,
|
cached_property,
|
||||||
can_return_loss,
|
can_return_loss,
|
||||||
expand_dims,
|
expand_dims,
|
||||||
|
filter_out_non_signature_kwargs,
|
||||||
find_labels,
|
find_labels,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
infer_framework,
|
infer_framework,
|
||||||
|
|||||||
169
src/transformers/utils/deprecation.py
Normal file
169
src/transformers/utils/deprecation.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
|
import warnings
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import packaging.version
|
||||||
|
|
||||||
|
from .. import __version__
|
||||||
|
from . import ExplicitEnum
|
||||||
|
|
||||||
|
|
||||||
|
class Action(ExplicitEnum):
|
||||||
|
NONE = "none"
|
||||||
|
NOTIFY = "notify"
|
||||||
|
NOTIFY_ALWAYS = "notify_always"
|
||||||
|
RAISE = "raise"
|
||||||
|
|
||||||
|
|
||||||
|
def deprecate_kwarg(
|
||||||
|
old_name: str,
|
||||||
|
version: str,
|
||||||
|
new_name: Optional[str] = None,
|
||||||
|
warn_if_greater_or_equal_version: bool = False,
|
||||||
|
raise_if_greater_or_equal_version: bool = False,
|
||||||
|
raise_if_both_names: bool = False,
|
||||||
|
additional_message: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Function or method decorator to notify users about deprecated keyword arguments, replacing them with a new name if specified.
|
||||||
|
|
||||||
|
This decorator allows you to:
|
||||||
|
- Notify users when a keyword argument is deprecated.
|
||||||
|
- Automatically replace deprecated keyword arguments with new ones.
|
||||||
|
- Raise an error if deprecated arguments are used, depending on the specified conditions.
|
||||||
|
|
||||||
|
By default, the decorator notifies the user about the deprecated argument while the `transformers.__version__` < specified `version`
|
||||||
|
in the decorator. To keep notifications with any version `warn_if_greater_or_equal_version=True` can be set.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
old_name (`str`):
|
||||||
|
Name of the deprecated keyword argument.
|
||||||
|
version (`str`):
|
||||||
|
The version in which the keyword argument was (or will be) deprecated.
|
||||||
|
new_name (`Optional[str]`, *optional*):
|
||||||
|
The new name for the deprecated keyword argument. If specified, the deprecated keyword argument will be replaced with this new name.
|
||||||
|
warn_if_greater_or_equal_version (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to show warning if current `transformers` version is greater or equal to the deprecated version.
|
||||||
|
raise_if_greater_or_equal_version (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to raise `ValueError` if current `transformers` version is greater or equal to the deprecated version.
|
||||||
|
raise_if_both_names (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to raise `ValueError` if both deprecated and new keyword arguments are set.
|
||||||
|
additional_message (`Optional[str]`, *optional*):
|
||||||
|
An additional message to append to the default deprecation message.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError:
|
||||||
|
If raise_if_greater_or_equal_version is True and the current version is greater than or equal to the deprecated version, or if raise_if_both_names is True and both old and new keyword arguments are provided.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable:
|
||||||
|
A wrapped function that handles the deprecated keyword arguments according to the specified parameters.
|
||||||
|
|
||||||
|
Example usage with renaming argument:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="6.0.0")
|
||||||
|
def my_function(do_reduce_labels):
|
||||||
|
print(do_reduce_labels)
|
||||||
|
|
||||||
|
my_function(reduce_labels=True) # Will show a deprecation warning and use do_reduce_labels=True
|
||||||
|
```
|
||||||
|
|
||||||
|
Example usage without renaming argument:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@deprecate_kwarg("max_size", version="6.0.0")
|
||||||
|
def my_function(max_size):
|
||||||
|
print(max_size)
|
||||||
|
|
||||||
|
my_function(max_size=1333) # Will show a deprecation warning
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
deprecated_version = packaging.version.parse(version)
|
||||||
|
current_version = packaging.version.parse(__version__)
|
||||||
|
is_greater_or_equal_version = current_version >= deprecated_version
|
||||||
|
|
||||||
|
if is_greater_or_equal_version:
|
||||||
|
version_message = f"and removed starting from version {version}"
|
||||||
|
else:
|
||||||
|
version_message = f"and will be removed in version {version}"
|
||||||
|
|
||||||
|
def wrapper(func):
|
||||||
|
# Required for better warning message
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
function_named_args = set(sig.parameters.keys())
|
||||||
|
is_instance_method = "self" in function_named_args
|
||||||
|
is_class_method = "cls" in function_named_args
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapped_func(*args, **kwargs):
|
||||||
|
# Get class + function name (just for better warning message)
|
||||||
|
func_name = func.__name__
|
||||||
|
if is_instance_method:
|
||||||
|
func_name = f"{args[0].__class__.__name__}.{func_name}"
|
||||||
|
elif is_class_method:
|
||||||
|
func_name = f"{args[0].__name__}.{func_name}"
|
||||||
|
|
||||||
|
minimum_action = Action.NONE
|
||||||
|
message = None
|
||||||
|
|
||||||
|
# deprecated kwarg and its new version are set for function call -> replace it with new name
|
||||||
|
if old_name in kwargs and new_name in kwargs:
|
||||||
|
minimum_action = Action.RAISE if raise_if_both_names else Action.NOTIFY_ALWAYS
|
||||||
|
message = f"Both `{old_name}` and `{new_name}` are set for `{func_name}`. Using `{new_name}={kwargs[new_name]}` and ignoring deprecated `{old_name}={kwargs[old_name]}`."
|
||||||
|
kwargs.pop(old_name)
|
||||||
|
|
||||||
|
# only deprecated kwarg is set for function call -> replace it with new name
|
||||||
|
elif old_name in kwargs and new_name is not None and new_name not in kwargs:
|
||||||
|
minimum_action = Action.NOTIFY
|
||||||
|
message = f"`{old_name}` is deprecated {version_message} for `{func_name}`. Use `{new_name}` instead."
|
||||||
|
kwargs[new_name] = kwargs.pop(old_name)
|
||||||
|
|
||||||
|
# deprecated kwarg is not set for function call and new name is not specified -> just notify
|
||||||
|
elif old_name in kwargs:
|
||||||
|
minimum_action = Action.NOTIFY
|
||||||
|
message = f"`{old_name}` is deprecated {version_message} for `{func_name}`."
|
||||||
|
|
||||||
|
if message is not None and additional_message is not None:
|
||||||
|
message = f"{message} {additional_message}"
|
||||||
|
|
||||||
|
# update minimum_action if argument is ALREADY deprecated (current version >= deprecated version)
|
||||||
|
if is_greater_or_equal_version:
|
||||||
|
# change to (NOTIFY, NOTIFY_ALWAYS) -> RAISE if specified
|
||||||
|
# in case we want to raise error for already deprecated arguments
|
||||||
|
if raise_if_greater_or_equal_version and minimum_action != Action.NONE:
|
||||||
|
minimum_action = Action.RAISE
|
||||||
|
|
||||||
|
# change to NOTIFY -> NONE if specified (NOTIFY_ALWAYS can't be changed to NONE)
|
||||||
|
# in case we want to ignore notifications for already deprecated arguments
|
||||||
|
elif not warn_if_greater_or_equal_version and minimum_action == Action.NOTIFY:
|
||||||
|
minimum_action = Action.NONE
|
||||||
|
|
||||||
|
# raise error or notify user
|
||||||
|
if minimum_action == Action.RAISE:
|
||||||
|
raise ValueError(message)
|
||||||
|
elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS):
|
||||||
|
# DeprecationWarning is ignored by default, so we use FutureWarning instead
|
||||||
|
warnings.warn(message, FutureWarning, stacklevel=2)
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped_func
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -17,13 +17,14 @@ Generic utilities
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import warnings
|
||||||
from collections import OrderedDict, UserDict
|
from collections import OrderedDict, UserDict
|
||||||
from collections.abc import MutableMapping
|
from collections.abc import MutableMapping
|
||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from dataclasses import fields, is_dataclass
|
from dataclasses import fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial, wraps
|
||||||
from typing import Any, ContextManager, Iterable, List, Tuple
|
from typing import Any, ContextManager, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -750,3 +751,79 @@ def infer_framework(model_class):
|
|||||||
return "flax"
|
return "flax"
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Could not infer framework from class {model_class}.")
|
raise TypeError(f"Could not infer framework from class {model_class}.")
|
||||||
|
|
||||||
|
|
||||||
|
def filter_out_non_signature_kwargs(extra: Optional[list] = None):
|
||||||
|
"""
|
||||||
|
Decorator to filter out named arguments that are not in the function signature.
|
||||||
|
|
||||||
|
This decorator ensures that only the keyword arguments that match the function's signature, or are specified in the
|
||||||
|
`extra` list, are passed to the function. Any additional keyword arguments are filtered out and a warning is issued.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
extra (`Optional[list]`, *optional*):
|
||||||
|
A list of extra keyword argument names that are allowed even if they are not in the function's signature.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable:
|
||||||
|
A decorator that wraps the function and filters out invalid keyword arguments.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@filter_out_non_signature_kwargs(extra=["allowed_extra_arg"])
|
||||||
|
def my_function(arg1, arg2, **kwargs):
|
||||||
|
print(arg1, arg2, kwargs)
|
||||||
|
|
||||||
|
my_function(arg1=1, arg2=2, allowed_extra_arg=3, invalid_arg=4)
|
||||||
|
# This will print: 1 2 {"allowed_extra_arg": 3}
|
||||||
|
# And issue a warning: "The following named arguments are not valid for `my_function` and were ignored: 'invalid_arg'"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
extra = extra or []
|
||||||
|
extra_params_to_pass = set(extra)
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
function_named_args = set(sig.parameters.keys())
|
||||||
|
valid_kwargs_to_pass = function_named_args.union(extra_params_to_pass)
|
||||||
|
|
||||||
|
# Required for better warning message
|
||||||
|
is_instance_method = "self" in function_named_args
|
||||||
|
is_class_method = "cls" in function_named_args
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
valid_kwargs = {}
|
||||||
|
invalid_kwargs = {}
|
||||||
|
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if k in valid_kwargs_to_pass:
|
||||||
|
valid_kwargs[k] = v
|
||||||
|
else:
|
||||||
|
invalid_kwargs[k] = v
|
||||||
|
|
||||||
|
if invalid_kwargs:
|
||||||
|
invalid_kwargs_names = [f"'{k}'" for k in invalid_kwargs.keys()]
|
||||||
|
invalid_kwargs_names = ", ".join(invalid_kwargs_names)
|
||||||
|
|
||||||
|
# Get the class name for better warning message
|
||||||
|
if is_instance_method:
|
||||||
|
cls_prefix = args[0].__class__.__name__ + "."
|
||||||
|
elif is_class_method:
|
||||||
|
cls_prefix = args[0].__name__ + "."
|
||||||
|
else:
|
||||||
|
cls_prefix = ""
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
f"The following named arguments are not valid for `{cls_prefix}{func.__name__}`"
|
||||||
|
f" and were ignored: {invalid_kwargs_names}",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return func(*args, **valid_kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|||||||
@@ -136,6 +136,7 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_reduce_labels"))
|
||||||
|
|
||||||
def test_image_processor_from_dict_with_kwargs(self):
|
def test_image_processor_from_dict_with_kwargs(self):
|
||||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||||
@@ -144,7 +145,7 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
self.assertEqual(image_processor.do_reduce_labels, False)
|
self.assertEqual(image_processor.do_reduce_labels, False)
|
||||||
|
|
||||||
image_processor = self.image_processing_class.from_dict(
|
image_processor = self.image_processing_class.from_dict(
|
||||||
self.image_processor_dict, size=42, crop_size=84, reduce_labels=True
|
self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True
|
||||||
)
|
)
|
||||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||||
@@ -270,3 +271,16 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
encoding = image_processing(image, map, return_tensors="pt")
|
encoding = image_processing(image, map, return_tensors="pt")
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
|
def test_removed_deprecated_kwargs(self):
|
||||||
|
image_processor_dict = dict(self.image_processor_dict)
|
||||||
|
image_processor_dict.pop("do_reduce_labels", None)
|
||||||
|
image_processor_dict["reduce_labels"] = True
|
||||||
|
|
||||||
|
# test we are able to create the image processor with the deprecated kwargs
|
||||||
|
image_processor = self.image_processing_class(**image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|
||||||
|
# test we still support reduce_labels with config
|
||||||
|
image_processor = self.image_processing_class.from_dict(image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|||||||
@@ -274,7 +274,7 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
|
|||||||
instance_seg2, inst2class2 = get_instance_segmentation_and_mapping(annotation2)
|
instance_seg2, inst2class2 = get_instance_segmentation_and_mapping(annotation2)
|
||||||
|
|
||||||
# create a image processor
|
# create a image processor
|
||||||
image_processing = Mask2FormerImageProcessor(reduce_labels=True, ignore_index=255, size=(512, 512))
|
image_processing = Mask2FormerImageProcessor(do_reduce_labels=True, ignore_index=255, size=(512, 512))
|
||||||
|
|
||||||
# prepare the images and annotations
|
# prepare the images and annotations
|
||||||
inputs = image_processing(
|
inputs = image_processing(
|
||||||
@@ -317,7 +317,7 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
|
|||||||
)
|
)
|
||||||
|
|
||||||
# create a image processor
|
# create a image processor
|
||||||
image_processing = Mask2FormerImageProcessor(reduce_labels=True, ignore_index=255, size=(512, 512))
|
image_processing = Mask2FormerImageProcessor(do_reduce_labels=True, ignore_index=255, size=(512, 512))
|
||||||
|
|
||||||
# prepare the images and annotations
|
# prepare the images and annotations
|
||||||
inputs = image_processing(
|
inputs = image_processing(
|
||||||
@@ -490,3 +490,16 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
|
|||||||
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
|
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
|
||||||
num_segments_fused = max([el["id"] for el in el_fused])
|
num_segments_fused = max([el["id"] for el in el_fused])
|
||||||
self.assertEqual(num_segments_fused, expected_num_segments)
|
self.assertEqual(num_segments_fused, expected_num_segments)
|
||||||
|
|
||||||
|
def test_removed_deprecated_kwargs(self):
|
||||||
|
image_processor_dict = dict(self.image_processor_dict)
|
||||||
|
image_processor_dict.pop("do_reduce_labels", None)
|
||||||
|
image_processor_dict["reduce_labels"] = True
|
||||||
|
|
||||||
|
# test we are able to create the image processor with the deprecated kwargs
|
||||||
|
image_processor = self.image_processing_class(**image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|
||||||
|
# test we still support reduce_labels with config
|
||||||
|
image_processor = self.image_processing_class.from_dict(image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|||||||
@@ -274,7 +274,7 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
|||||||
instance_seg2, inst2class2 = get_instance_segmentation_and_mapping(annotation2)
|
instance_seg2, inst2class2 = get_instance_segmentation_and_mapping(annotation2)
|
||||||
|
|
||||||
# create a image processor
|
# create a image processor
|
||||||
image_processing = MaskFormerImageProcessor(reduce_labels=True, ignore_index=255, size=(512, 512))
|
image_processing = MaskFormerImageProcessor(do_reduce_labels=True, ignore_index=255, size=(512, 512))
|
||||||
|
|
||||||
# prepare the images and annotations
|
# prepare the images and annotations
|
||||||
inputs = image_processing(
|
inputs = image_processing(
|
||||||
@@ -317,7 +317,7 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
|||||||
)
|
)
|
||||||
|
|
||||||
# create a image processor
|
# create a image processor
|
||||||
image_processing = MaskFormerImageProcessor(reduce_labels=True, ignore_index=255, size=(512, 512))
|
image_processing = MaskFormerImageProcessor(do_reduce_labels=True, ignore_index=255, size=(512, 512))
|
||||||
|
|
||||||
# prepare the images and annotations
|
# prepare the images and annotations
|
||||||
inputs = image_processing(
|
inputs = image_processing(
|
||||||
@@ -525,3 +525,16 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
|||||||
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
|
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
|
||||||
num_segments_fused = max([el["id"] for el in el_fused])
|
num_segments_fused = max([el["id"] for el in el_fused])
|
||||||
self.assertEqual(num_segments_fused, expected_num_segments)
|
self.assertEqual(num_segments_fused, expected_num_segments)
|
||||||
|
|
||||||
|
def test_removed_deprecated_kwargs(self):
|
||||||
|
image_processor_dict = dict(self.image_processor_dict)
|
||||||
|
image_processor_dict.pop("do_reduce_labels", None)
|
||||||
|
image_processor_dict["reduce_labels"] = True
|
||||||
|
|
||||||
|
# test we are able to create the image processor with the deprecated kwargs
|
||||||
|
image_processor = self.image_processing_class(**image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|
||||||
|
# test we still support reduce_labels with config
|
||||||
|
image_processor = self.image_processing_class.from_dict(image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|||||||
@@ -349,3 +349,16 @@ class OneFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
image_processor = self.image_processing_class(**config_dict)
|
image_processor = self.image_processing_class(**config_dict)
|
||||||
|
|
||||||
self.assertEqual(image_processor.metadata, metadata)
|
self.assertEqual(image_processor.metadata, metadata)
|
||||||
|
|
||||||
|
def test_removed_deprecated_kwargs(self):
|
||||||
|
image_processor_dict = dict(self.image_processor_dict)
|
||||||
|
image_processor_dict.pop("do_reduce_labels", None)
|
||||||
|
image_processor_dict["reduce_labels"] = True
|
||||||
|
|
||||||
|
# test we are able to create the image processor with the deprecated kwargs
|
||||||
|
image_processor = self.image_processing_class(**image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|
||||||
|
# test we still support reduce_labels with config
|
||||||
|
image_processor = self.image_processing_class.from_dict(image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class OneFormerProcessorTester(unittest.TestCase):
|
|||||||
image_mean=[0.5, 0.5, 0.5],
|
image_mean=[0.5, 0.5, 0.5],
|
||||||
image_std=[0.5, 0.5, 0.5],
|
image_std=[0.5, 0.5, 0.5],
|
||||||
num_labels=10,
|
num_labels=10,
|
||||||
reduce_labels=False,
|
do_reduce_labels=False,
|
||||||
ignore_index=255,
|
ignore_index=255,
|
||||||
max_seq_length=77,
|
max_seq_length=77,
|
||||||
task_seq_length=77,
|
task_seq_length=77,
|
||||||
@@ -105,7 +105,7 @@ class OneFormerProcessorTester(unittest.TestCase):
|
|||||||
self.height = 3
|
self.height = 3
|
||||||
self.width = 4
|
self.width = 4
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
self.reduce_labels = reduce_labels
|
self.do_reduce_labels = do_reduce_labels
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
|
|
||||||
def prepare_processor_dict(self):
|
def prepare_processor_dict(self):
|
||||||
@@ -116,7 +116,7 @@ class OneFormerProcessorTester(unittest.TestCase):
|
|||||||
"image_mean": self.image_mean,
|
"image_mean": self.image_mean,
|
||||||
"image_std": self.image_std,
|
"image_std": self.image_std,
|
||||||
"num_labels": self.num_labels,
|
"num_labels": self.num_labels,
|
||||||
"reduce_labels": self.reduce_labels,
|
"do_reduce_labels": self.do_reduce_labels,
|
||||||
"ignore_index": self.ignore_index,
|
"ignore_index": self.ignore_index,
|
||||||
"class_info_file": self.class_info_file,
|
"class_info_file": self.class_info_file,
|
||||||
"metadata": self.metadata,
|
"metadata": self.metadata,
|
||||||
@@ -465,7 +465,7 @@ class OneFormerProcessingTest(unittest.TestCase):
|
|||||||
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
|
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
|
||||||
|
|
||||||
image_processor = OneFormerImageProcessor(
|
image_processor = OneFormerImageProcessor(
|
||||||
reduce_labels=True,
|
do_reduce_labels=True,
|
||||||
ignore_index=0,
|
ignore_index=0,
|
||||||
size=(512, 512),
|
size=(512, 512),
|
||||||
class_info_file="ade20k_panoptic.json",
|
class_info_file="ade20k_panoptic.json",
|
||||||
@@ -553,7 +553,7 @@ class OneFormerProcessingTest(unittest.TestCase):
|
|||||||
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
|
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
|
||||||
|
|
||||||
image_processor = OneFormerImageProcessor(
|
image_processor = OneFormerImageProcessor(
|
||||||
reduce_labels=True,
|
do_reduce_labels=True,
|
||||||
ignore_index=0,
|
ignore_index=0,
|
||||||
size=(512, 512),
|
size=(512, 512),
|
||||||
class_info_file="ade20k_panoptic.json",
|
class_info_file="ade20k_panoptic.json",
|
||||||
@@ -641,7 +641,7 @@ class OneFormerProcessingTest(unittest.TestCase):
|
|||||||
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
|
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
|
||||||
|
|
||||||
image_processor = OneFormerImageProcessor(
|
image_processor = OneFormerImageProcessor(
|
||||||
reduce_labels=True,
|
do_reduce_labels=True,
|
||||||
ignore_index=0,
|
ignore_index=0,
|
||||||
size=(512, 512),
|
size=(512, 512),
|
||||||
class_info_file="ade20k_panoptic.json",
|
class_info_file="ade20k_panoptic.json",
|
||||||
@@ -710,7 +710,7 @@ class OneFormerProcessingTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_post_process_semantic_segmentation(self):
|
def test_post_process_semantic_segmentation(self):
|
||||||
image_processor = OneFormerImageProcessor(
|
image_processor = OneFormerImageProcessor(
|
||||||
reduce_labels=True,
|
do_reduce_labels=True,
|
||||||
ignore_index=0,
|
ignore_index=0,
|
||||||
size=(512, 512),
|
size=(512, 512),
|
||||||
class_info_file="ade20k_panoptic.json",
|
class_info_file="ade20k_panoptic.json",
|
||||||
@@ -744,7 +744,7 @@ class OneFormerProcessingTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_post_process_instance_segmentation(self):
|
def test_post_process_instance_segmentation(self):
|
||||||
image_processor = OneFormerImageProcessor(
|
image_processor = OneFormerImageProcessor(
|
||||||
reduce_labels=True,
|
do_reduce_labels=True,
|
||||||
ignore_index=0,
|
ignore_index=0,
|
||||||
size=(512, 512),
|
size=(512, 512),
|
||||||
class_info_file="ade20k_panoptic.json",
|
class_info_file="ade20k_panoptic.json",
|
||||||
@@ -770,7 +770,7 @@ class OneFormerProcessingTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_post_process_panoptic_segmentation(self):
|
def test_post_process_panoptic_segmentation(self):
|
||||||
image_processor = OneFormerImageProcessor(
|
image_processor = OneFormerImageProcessor(
|
||||||
reduce_labels=True,
|
do_reduce_labels=True,
|
||||||
ignore_index=0,
|
ignore_index=0,
|
||||||
size=(512, 512),
|
size=(512, 512),
|
||||||
class_info_file="ade20k_panoptic.json",
|
class_info_file="ade20k_panoptic.json",
|
||||||
|
|||||||
@@ -132,7 +132,9 @@ class SegformerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
self.assertEqual(image_processor.size, {"height": 30, "width": 30})
|
self.assertEqual(image_processor.size, {"height": 30, "width": 30})
|
||||||
self.assertEqual(image_processor.do_reduce_labels, False)
|
self.assertEqual(image_processor.do_reduce_labels, False)
|
||||||
|
|
||||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, reduce_labels=True)
|
image_processor = self.image_processing_class.from_dict(
|
||||||
|
self.image_processor_dict, size=42, do_reduce_labels=True
|
||||||
|
)
|
||||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||||
self.assertEqual(image_processor.do_reduce_labels, True)
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|
||||||
@@ -256,3 +258,16 @@ class SegformerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
encoding = image_processing(image, map, return_tensors="pt")
|
encoding = image_processing(image, map, return_tensors="pt")
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
|
def test_removed_deprecated_kwargs(self):
|
||||||
|
image_processor_dict = dict(self.image_processor_dict)
|
||||||
|
image_processor_dict.pop("do_reduce_labels", None)
|
||||||
|
image_processor_dict["reduce_labels"] = True
|
||||||
|
|
||||||
|
# test we are able to create the image processor with the deprecated kwargs
|
||||||
|
image_processor = self.image_processing_class(**image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|
||||||
|
# test we still support reduce_labels with config
|
||||||
|
image_processor = self.image_processing_class.from_dict(image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||||
|
|||||||
170
tests/utils/test_deprecation.py
Normal file
170
tests/utils/test_deprecation.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import __version__
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
|
INFINITE_VERSION = "9999.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
class DeprecationDecoratorTester(unittest.TestCase):
|
||||||
|
def test_rename_kwarg(self):
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
|
||||||
|
@deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION)
|
||||||
|
def dummy_function(new_name=None, other_name=None):
|
||||||
|
return new_name, other_name
|
||||||
|
|
||||||
|
# Test keyword argument is renamed
|
||||||
|
value, other_value = dummy_function(deprecated_name="old_value")
|
||||||
|
self.assertEqual(value, "old_value")
|
||||||
|
self.assertIsNone(other_value)
|
||||||
|
|
||||||
|
# Test deprecated keyword argument not passed
|
||||||
|
value, other_value = dummy_function(new_name="new_value")
|
||||||
|
self.assertEqual(value, "new_value")
|
||||||
|
self.assertIsNone(other_value)
|
||||||
|
|
||||||
|
# Test other keyword argument
|
||||||
|
value, other_value = dummy_function(other_name="other_value")
|
||||||
|
self.assertIsNone(value)
|
||||||
|
self.assertEqual(other_value, "other_value")
|
||||||
|
|
||||||
|
# Test deprecated and new args are passed, the new one should be returned
|
||||||
|
value, other_value = dummy_function(deprecated_name="old_value", new_name="new_value")
|
||||||
|
self.assertEqual(value, "new_value")
|
||||||
|
self.assertIsNone(other_value)
|
||||||
|
|
||||||
|
def test_rename_multiple_kwargs(self):
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
|
||||||
|
@deprecate_kwarg("deprecated_name1", new_name="new_name1", version=INFINITE_VERSION)
|
||||||
|
@deprecate_kwarg("deprecated_name2", new_name="new_name2", version=INFINITE_VERSION)
|
||||||
|
def dummy_function(new_name1=None, new_name2=None, other_name=None):
|
||||||
|
return new_name1, new_name2, other_name
|
||||||
|
|
||||||
|
# Test keyword argument is renamed
|
||||||
|
value1, value2, other_value = dummy_function(deprecated_name1="old_value1", deprecated_name2="old_value2")
|
||||||
|
self.assertEqual(value1, "old_value1")
|
||||||
|
self.assertEqual(value2, "old_value2")
|
||||||
|
self.assertIsNone(other_value)
|
||||||
|
|
||||||
|
# Test deprecated keyword argument is not passed
|
||||||
|
value1, value2, other_value = dummy_function(new_name1="new_value1", new_name2="new_value2")
|
||||||
|
self.assertEqual(value1, "new_value1")
|
||||||
|
self.assertEqual(value2, "new_value2")
|
||||||
|
self.assertIsNone(other_value)
|
||||||
|
|
||||||
|
# Test other keyword argument is passed and correctly returned
|
||||||
|
value1, value2, other_value = dummy_function(other_name="other_value")
|
||||||
|
self.assertIsNone(value1)
|
||||||
|
self.assertIsNone(value2)
|
||||||
|
self.assertEqual(other_value, "other_value")
|
||||||
|
|
||||||
|
def test_warnings(self):
|
||||||
|
# Test warning is raised for future version
|
||||||
|
@deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION)
|
||||||
|
def dummy_function(new_name=None, other_name=None):
|
||||||
|
return new_name, other_name
|
||||||
|
|
||||||
|
with self.assertWarns(FutureWarning):
|
||||||
|
dummy_function(deprecated_name="old_value")
|
||||||
|
|
||||||
|
# Test warning is not raised for past version, but arg is still renamed
|
||||||
|
@deprecate_kwarg("deprecated_name", new_name="new_name", version="0.0.0")
|
||||||
|
def dummy_function(new_name=None, other_name=None):
|
||||||
|
return new_name, other_name
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
|
||||||
|
value, other_value = dummy_function(deprecated_name="old_value")
|
||||||
|
|
||||||
|
self.assertEqual(value, "old_value")
|
||||||
|
self.assertIsNone(other_value)
|
||||||
|
self.assertEqual(len(raised_warnings), 0, f"Warning raised: {[w.message for w in raised_warnings]}")
|
||||||
|
|
||||||
|
# Test warning is raised for future version if warn_if_greater_or_equal_version is set
|
||||||
|
@deprecate_kwarg("deprecated_name", version="0.0.0", warn_if_greater_or_equal_version=True)
|
||||||
|
def dummy_function(deprecated_name=None):
|
||||||
|
return deprecated_name
|
||||||
|
|
||||||
|
with self.assertWarns(FutureWarning):
|
||||||
|
value = dummy_function(deprecated_name="deprecated_value")
|
||||||
|
self.assertEqual(value, "deprecated_value")
|
||||||
|
|
||||||
|
# Test arg is not renamed if new_name is not specified, but warning is raised
|
||||||
|
@deprecate_kwarg("deprecated_name", version=INFINITE_VERSION)
|
||||||
|
def dummy_function(deprecated_name=None):
|
||||||
|
return deprecated_name
|
||||||
|
|
||||||
|
with self.assertWarns(FutureWarning):
|
||||||
|
value = dummy_function(deprecated_name="deprecated_value")
|
||||||
|
self.assertEqual(value, "deprecated_value")
|
||||||
|
|
||||||
|
def test_raises(self):
|
||||||
|
# Test if deprecated name and new name are both passed and raise_if_both_names is set -> raise error
|
||||||
|
@deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION, raise_if_both_names=True)
|
||||||
|
def dummy_function(new_name=None, other_name=None):
|
||||||
|
return new_name, other_name
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
dummy_function(deprecated_name="old_value", new_name="new_value")
|
||||||
|
|
||||||
|
# Test for current version == deprecation version
|
||||||
|
@deprecate_kwarg("deprecated_name", version=__version__, raise_if_greater_or_equal_version=True)
|
||||||
|
def dummy_function(deprecated_name=None):
|
||||||
|
return deprecated_name
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
dummy_function(deprecated_name="old_value")
|
||||||
|
|
||||||
|
# Test for current version > deprecation version
|
||||||
|
@deprecate_kwarg("deprecated_name", version="0.0.0", raise_if_greater_or_equal_version=True)
|
||||||
|
def dummy_function(deprecated_name=None):
|
||||||
|
return deprecated_name
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
dummy_function(deprecated_name="old_value")
|
||||||
|
|
||||||
|
def test_additional_message(self):
|
||||||
|
# Test additional message is added to the warning
|
||||||
|
@deprecate_kwarg("deprecated_name", version=INFINITE_VERSION, additional_message="Additional message")
|
||||||
|
def dummy_function(deprecated_name=None):
|
||||||
|
return deprecated_name
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
dummy_function(deprecated_name="old_value")
|
||||||
|
|
||||||
|
self.assertTrue("Additional message" in str(raised_warnings[0].message))
|
||||||
|
|
||||||
|
@parameterized.expand(["0.0.0", __version__, INFINITE_VERSION])
|
||||||
|
def test_warning_for_both_names(self, version):
|
||||||
|
# We should raise warning if both names are passed for any specified version
|
||||||
|
@deprecate_kwarg("deprecated_name", new_name="new_name", version=version)
|
||||||
|
def dummy_function(new_name=None, **kwargs):
|
||||||
|
return new_name
|
||||||
|
|
||||||
|
with self.assertWarns(FutureWarning):
|
||||||
|
result = dummy_function(deprecated_name="old_value", new_name="new_value")
|
||||||
|
self.assertEqual(result, "new_value")
|
||||||
@@ -14,12 +14,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.testing_utils import require_flax, require_tf, require_torch
|
from transformers.testing_utils import require_flax, require_tf, require_torch
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
expand_dims,
|
expand_dims,
|
||||||
|
filter_out_non_signature_kwargs,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
@@ -198,3 +200,74 @@ class GenericTester(unittest.TestCase):
|
|||||||
x = np.random.randn(3, 4)
|
x = np.random.randn(3, 4)
|
||||||
t = jnp.array(x)
|
t = jnp.array(x)
|
||||||
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
|
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationDecoratorTester(unittest.TestCase):
|
||||||
|
def test_cases_no_warning(self):
|
||||||
|
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
|
||||||
|
# basic test
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
|
def func1(a):
|
||||||
|
return a
|
||||||
|
|
||||||
|
result = func1(1)
|
||||||
|
self.assertEqual(result, 1)
|
||||||
|
|
||||||
|
# include extra kwarg
|
||||||
|
@filter_out_non_signature_kwargs(extra=["extra_arg"])
|
||||||
|
def func2(a, **kwargs):
|
||||||
|
return a, kwargs
|
||||||
|
|
||||||
|
a, kwargs = func2(1)
|
||||||
|
self.assertEqual(a, 1)
|
||||||
|
self.assertEqual(kwargs, {})
|
||||||
|
|
||||||
|
a, kwargs = func2(1, extra_arg=2)
|
||||||
|
self.assertEqual(a, 1)
|
||||||
|
self.assertEqual(kwargs, {"extra_arg": 2})
|
||||||
|
|
||||||
|
# multiple extra kwargs
|
||||||
|
@filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
|
||||||
|
def func3(a, **kwargs):
|
||||||
|
return a, kwargs
|
||||||
|
|
||||||
|
a, kwargs = func3(2)
|
||||||
|
self.assertEqual(a, 2)
|
||||||
|
self.assertEqual(kwargs, {})
|
||||||
|
|
||||||
|
a, kwargs = func3(3, extra_arg2=3)
|
||||||
|
self.assertEqual(a, 3)
|
||||||
|
self.assertEqual(kwargs, {"extra_arg2": 3})
|
||||||
|
|
||||||
|
a, kwargs = func3(1, extra_arg=2, extra_arg2=3)
|
||||||
|
self.assertEqual(a, 1)
|
||||||
|
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
|
||||||
|
|
||||||
|
# Check that no warnings were raised
|
||||||
|
self.assertEqual(len(raised_warnings), 0, f"Warning raised: {[w.message for w in raised_warnings]}")
|
||||||
|
|
||||||
|
def test_cases_with_warnings(self):
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
|
def func1(a):
|
||||||
|
return a
|
||||||
|
|
||||||
|
with self.assertWarns(UserWarning):
|
||||||
|
func1(1, extra_arg=2)
|
||||||
|
|
||||||
|
@filter_out_non_signature_kwargs(extra=["extra_arg"])
|
||||||
|
def func2(a, **kwargs):
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
with self.assertWarns(UserWarning):
|
||||||
|
kwargs = func2(1, extra_arg=2, extra_arg2=3)
|
||||||
|
self.assertEqual(kwargs, {"extra_arg": 2})
|
||||||
|
|
||||||
|
@filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
|
||||||
|
def func3(a, **kwargs):
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
with self.assertWarns(UserWarning):
|
||||||
|
kwargs = func3(1, extra_arg=2, extra_arg2=3, extra_arg3=4)
|
||||||
|
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
|
||||||
|
|||||||
Reference in New Issue
Block a user