[Vision] .to function for ImageProcessors (#20536)
* add v1 with tests * add checker * simplified version * update docstring * better version * fix docstring + change order * make style * tests + change conditions * final tests * modify docstring * Update src/transformers/feature_extraction_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * replace by `ValueError` * fix logic * apply suggestions * `dtype` is not needed * adapt suggestions * remove `_parse_args_to_device` Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -25,7 +25,15 @@ from pathlib import Path
|
||||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
check_json_file_has_correct_format,
|
||||
get_tests_dir,
|
||||
is_staging_test,
|
||||
require_torch,
|
||||
require_vision,
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
|
||||
@@ -134,6 +142,8 @@ def prepare_video_inputs(feature_extract_tester, equal_resolution=False, numpify
|
||||
|
||||
|
||||
class FeatureExtractionSavingTestMixin:
|
||||
test_cast_dtype = None
|
||||
|
||||
def test_feat_extract_to_json_string(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
obj = json.loads(feat_extract.to_json_string())
|
||||
@@ -164,6 +174,41 @@ class FeatureExtractionSavingTestMixin:
|
||||
feat_extract = self.feature_extraction_class()
|
||||
self.assertIsNotNone(feat_extract)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_cast_dtype_device(self):
|
||||
if self.test_cast_dtype is not None:
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
|
||||
|
||||
encoding = feature_extractor(image_inputs, return_tensors="pt")
|
||||
# for layoutLM compatiblity
|
||||
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||
self.assertEqual(encoding.pixel_values.dtype, torch.float32)
|
||||
|
||||
encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16)
|
||||
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
|
||||
|
||||
encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
|
||||
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||
self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
_ = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
|
||||
|
||||
# Try with text + image feature
|
||||
encoding = feature_extractor(image_inputs, return_tensors="pt")
|
||||
encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
|
||||
encoding = encoding.to(torch.float16)
|
||||
|
||||
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
|
||||
self.assertEqual(encoding.input_ids.dtype, torch.long)
|
||||
|
||||
|
||||
class FeatureExtractorUtilTester(unittest.TestCase):
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
|
||||
Reference in New Issue
Block a user