[Vision] .to function for ImageProcessors (#20536)

* add v1 with tests

* add checker

* simplified version

* update docstring

* better version

* fix docstring + change order

* make style

* tests + change conditions

* final tests

* modify docstring

* Update src/transformers/feature_extraction_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* replace by `ValueError`

* fix logic

* apply suggestions

* `dtype` is not needed

* adapt suggestions

* remove `_parse_args_to_device`

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2022-12-05 19:10:54 +01:00
committed by GitHub
parent 67d32f4649
commit ef0f85cd57
5 changed files with 102 additions and 14 deletions

View File

@@ -25,7 +25,15 @@ from pathlib import Path
from huggingface_hub import HfFolder, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
from transformers.testing_utils import (
TOKEN,
USER,
check_json_file_has_correct_format,
get_tests_dir,
is_staging_test,
require_torch,
require_vision,
)
from transformers.utils import is_torch_available, is_vision_available
@@ -134,6 +142,8 @@ def prepare_video_inputs(feature_extract_tester, equal_resolution=False, numpify
class FeatureExtractionSavingTestMixin:
test_cast_dtype = None
def test_feat_extract_to_json_string(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
obj = json.loads(feat_extract.to_json_string())
@@ -164,6 +174,41 @@ class FeatureExtractionSavingTestMixin:
feat_extract = self.feature_extraction_class()
self.assertIsNotNone(feat_extract)
@require_torch
@require_vision
def test_cast_dtype_device(self):
if self.test_cast_dtype is not None:
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
encoding = feature_extractor(image_inputs, return_tensors="pt")
# for layoutLM compatiblity
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
self.assertEqual(encoding.pixel_values.dtype, torch.float32)
encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16)
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16)
with self.assertRaises(TypeError):
_ = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
# Try with text + image feature
encoding = feature_extractor(image_inputs, return_tensors="pt")
encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
encoding = encoding.to(torch.float16)
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
self.assertEqual(encoding.input_ids.dtype, torch.long)
class FeatureExtractorUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):