[Vision] .to function for ImageProcessors (#20536)
* add v1 with tests
* add checker
* simplified version
* update docstring
* better version
* fix docstring + change order
* make style
* tests + change conditions
* final tests
* modify docstring
* Update src/transformers/feature_extraction_utils.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* replace by `ValueError`
* fix logic
* apply suggestions
* `dtype` is not needed
* adapt suggestions
* remove `_parse_args_to_device`

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,7 @@ from .utils import (
|
|||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_device,
|
is_torch_device,
|
||||||
|
is_torch_dtype,
|
||||||
logging,
|
logging,
|
||||||
torch_required,
|
torch_required,
|
||||||
)
|
)
|
||||||
@@ -47,7 +48,7 @@ from .utils import (
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch # noqa
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -138,7 +139,7 @@ class BatchFeature(UserDict):
|
|||||||
elif tensor_type == TensorType.PYTORCH:
|
elif tensor_type == TensorType.PYTORCH:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||||
import torch
|
import torch # noqa
|
||||||
|
|
||||||
def as_tensor(value):
|
def as_tensor(value):
|
||||||
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
|
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
|
||||||
@@ -175,25 +176,47 @@ class BatchFeature(UserDict):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@torch_required
|
@torch_required
|
||||||
# Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchFeature
|
def to(self, *args, **kwargs) -> "BatchFeature":
|
||||||
def to(self, device: Union[str, "torch.device"]) -> "BatchFeature":
|
|
||||||
"""
|
"""
|
||||||
Send all values to device by calling `v.to(device)` (PyTorch only).
|
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
|
||||||
|
different `dtypes` and sending the `BatchFeature` to a different `device`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device (`str` or `torch.device`): The device to put the tensors on.
|
args (`Tuple`):
|
||||||
|
Will be passed to the `to(...)` function of the tensors.
|
||||||
|
kwargs (`Dict`, *optional*):
|
||||||
|
Will be passed to the `to(...)` function of the tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`BatchFeature`]: The same instance after modification.
|
[`BatchFeature`]: The same instance after modification.
|
||||||
"""
|
"""
|
||||||
|
import torch # noqa
|
||||||
|
|
||||||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
new_data = {}
|
||||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
device = kwargs.get("device")
|
||||||
# into a HalfTensor
|
# Check if the args are a device or a dtype
|
||||||
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
if device is None and len(args) > 0:
|
||||||
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
# device should be always the first argument
|
||||||
|
arg = args[0]
|
||||||
|
if is_torch_dtype(arg):
|
||||||
|
# The first argument is a dtype
|
||||||
|
pass
|
||||||
|
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
|
||||||
|
device = arg
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")
|
# it's something else
|
||||||
|
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||||
|
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||||
|
for k, v in self.items():
|
||||||
|
# check if v is a floating point
|
||||||
|
if torch.is_floating_point(v):
|
||||||
|
# cast and send to device
|
||||||
|
new_data[k] = v.to(*args, **kwargs)
|
||||||
|
elif device is not None:
|
||||||
|
new_data[k] = v.to(device=device)
|
||||||
|
else:
|
||||||
|
new_data[k] = v
|
||||||
|
self.data = new_data
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ from .generic import (
|
|||||||
is_tensor,
|
is_tensor,
|
||||||
is_tf_tensor,
|
is_tf_tensor,
|
||||||
is_torch_device,
|
is_torch_device,
|
||||||
|
is_torch_dtype,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
reshape,
|
reshape,
|
||||||
squeeze,
|
squeeze,
|
||||||
|
|||||||
@@ -123,6 +123,24 @@ def is_torch_device(x):
|
|||||||
return False if not is_torch_available() else _is_torch_device(x)
|
return False if not is_torch_available() else _is_torch_device(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_torch_dtype(x):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if isinstance(x, str):
|
||||||
|
if hasattr(torch, x):
|
||||||
|
x = getattr(torch, x)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return isinstance(x, torch.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_dtype(x):
|
||||||
|
"""
|
||||||
|
Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
|
||||||
|
"""
|
||||||
|
return False if not is_torch_available() else _is_torch_dtype(x)
|
||||||
|
|
||||||
|
|
||||||
def _is_tensorflow(x):
|
def _is_tensorflow(x):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ class DeiTFeatureExtractionTester(unittest.TestCase):
|
|||||||
class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
||||||
|
|
||||||
feature_extraction_class = DeiTFeatureExtractor if is_vision_available() else None
|
feature_extraction_class = DeiTFeatureExtractor if is_vision_available() else None
|
||||||
|
test_cast_dtype = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.feature_extract_tester = DeiTFeatureExtractionTester(self)
|
self.feature_extract_tester = DeiTFeatureExtractionTester(self)
|
||||||
|
|||||||
@@ -25,7 +25,15 @@ from pathlib import Path
|
|||||||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||||
from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
|
from transformers.testing_utils import (
|
||||||
|
TOKEN,
|
||||||
|
USER,
|
||||||
|
check_json_file_has_correct_format,
|
||||||
|
get_tests_dir,
|
||||||
|
is_staging_test,
|
||||||
|
require_torch,
|
||||||
|
require_vision,
|
||||||
|
)
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
@@ -134,6 +142,8 @@ def prepare_video_inputs(feature_extract_tester, equal_resolution=False, numpify
|
|||||||
|
|
||||||
|
|
||||||
class FeatureExtractionSavingTestMixin:
|
class FeatureExtractionSavingTestMixin:
|
||||||
|
test_cast_dtype = None
|
||||||
|
|
||||||
def test_feat_extract_to_json_string(self):
|
def test_feat_extract_to_json_string(self):
|
||||||
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
|
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
|
||||||
obj = json.loads(feat_extract.to_json_string())
|
obj = json.loads(feat_extract.to_json_string())
|
||||||
@@ -164,6 +174,41 @@ class FeatureExtractionSavingTestMixin:
|
|||||||
feat_extract = self.feature_extraction_class()
|
feat_extract = self.feature_extraction_class()
|
||||||
self.assertIsNotNone(feat_extract)
|
self.assertIsNotNone(feat_extract)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
def test_cast_dtype_device(self):
|
||||||
|
if self.test_cast_dtype is not None:
|
||||||
|
# Initialize feature_extractor
|
||||||
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||||
|
|
||||||
|
# create random PyTorch tensors
|
||||||
|
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
|
||||||
|
|
||||||
|
encoding = feature_extractor(image_inputs, return_tensors="pt")
|
||||||
|
# for layoutLM compatibility
|
||||||
|
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||||
|
self.assertEqual(encoding.pixel_values.dtype, torch.float32)
|
||||||
|
|
||||||
|
encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16)
|
||||||
|
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||||
|
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
|
||||||
|
|
||||||
|
encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
|
||||||
|
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||||
|
self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16)
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
_ = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
|
||||||
|
|
||||||
|
# Try with text + image feature
|
||||||
|
encoding = feature_extractor(image_inputs, return_tensors="pt")
|
||||||
|
encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
|
||||||
|
encoding = encoding.to(torch.float16)
|
||||||
|
|
||||||
|
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||||
|
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
|
||||||
|
self.assertEqual(encoding.input_ids.dtype, torch.long)
|
||||||
|
|
||||||
|
|
||||||
class FeatureExtractorUtilTester(unittest.TestCase):
|
class FeatureExtractorUtilTester(unittest.TestCase):
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user