Return input_ids in ImageGPT feature extractor (#16872)
This commit is contained in:
@@ -68,7 +68,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
|
|||||||
Whether or not to normalize the input to the range between -1 and +1.
|
Whether or not to normalize the input to the range between -1 and +1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_input_names = ["pixel_values"]
|
model_input_names = ["input_ids"]
|
||||||
|
|
||||||
def __init__(self, clusters, do_resize=True, size=32, resample=Image.BILINEAR, do_normalize=True, **kwargs):
|
def __init__(self, clusters, do_resize=True, size=32, resample=Image.BILINEAR, do_normalize=True, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -128,8 +128,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
|
|||||||
Returns:
|
Returns:
|
||||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||||
|
|
||||||
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
|
- **input_ids** -- Input IDs to be fed to a model, of shape `(batch_size, height * width)`.
|
||||||
width).
|
|
||||||
"""
|
"""
|
||||||
# Input type checking for clearer error
|
# Input type checking for clearer error
|
||||||
valid_images = False
|
valid_images = False
|
||||||
@@ -171,7 +170,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
|
|||||||
images = images.reshape(batch_size, -1)
|
images = images.reshape(batch_size, -1)
|
||||||
|
|
||||||
# return as BatchFeature
|
# return as BatchFeature
|
||||||
data = {"pixel_values": images}
|
data = {"input_ids": images}
|
||||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
return encoded_inputs
|
return encoded_inputs
|
||||||
|
|||||||
@@ -161,17 +161,17 @@ class ImageGPTFeatureExtractorIntegrationTest(unittest.TestCase):
|
|||||||
# test non-batched
|
# test non-batched
|
||||||
encoding = feature_extractor(images[0], return_tensors="pt")
|
encoding = feature_extractor(images[0], return_tensors="pt")
|
||||||
|
|
||||||
self.assertIsInstance(encoding.pixel_values, torch.LongTensor)
|
self.assertIsInstance(encoding.input_ids, torch.LongTensor)
|
||||||
self.assertEqual(encoding.pixel_values.shape, (1, 1024))
|
self.assertEqual(encoding.input_ids.shape, (1, 1024))
|
||||||
|
|
||||||
expected_slice = [306, 191, 191]
|
expected_slice = [306, 191, 191]
|
||||||
self.assertEqual(encoding.pixel_values[0, :3].tolist(), expected_slice)
|
self.assertEqual(encoding.input_ids[0, :3].tolist(), expected_slice)
|
||||||
|
|
||||||
# test batched
|
# test batched
|
||||||
encoding = feature_extractor(images, return_tensors="pt")
|
encoding = feature_extractor(images, return_tensors="pt")
|
||||||
|
|
||||||
self.assertIsInstance(encoding.pixel_values, torch.LongTensor)
|
self.assertIsInstance(encoding.input_ids, torch.LongTensor)
|
||||||
self.assertEqual(encoding.pixel_values.shape, (2, 1024))
|
self.assertEqual(encoding.input_ids.shape, (2, 1024))
|
||||||
|
|
||||||
expected_slice = [303, 13, 13]
|
expected_slice = [303, 13, 13]
|
||||||
self.assertEqual(encoding.pixel_values[1, -3:].tolist(), expected_slice)
|
self.assertEqual(encoding.input_ids[1, -3:].tolist(), expected_slice)
|
||||||
|
|||||||
Reference in New Issue
Block a user