Return input_ids in ImageGPT feature extractor (#16872)
This commit is contained in:
@@ -161,17 +161,17 @@ class ImageGPTFeatureExtractorIntegrationTest(unittest.TestCase):
|
||||
# test non-batched
|
||||
encoding = feature_extractor(images[0], return_tensors="pt")
|
||||
|
||||
self.assertIsInstance(encoding.pixel_values, torch.LongTensor)
|
||||
self.assertEqual(encoding.pixel_values.shape, (1, 1024))
|
||||
self.assertIsInstance(encoding.input_ids, torch.LongTensor)
|
||||
self.assertEqual(encoding.input_ids.shape, (1, 1024))
|
||||
|
||||
expected_slice = [306, 191, 191]
|
||||
self.assertEqual(encoding.pixel_values[0, :3].tolist(), expected_slice)
|
||||
self.assertEqual(encoding.input_ids[0, :3].tolist(), expected_slice)
|
||||
|
||||
# test batched
|
||||
encoding = feature_extractor(images, return_tensors="pt")
|
||||
|
||||
self.assertIsInstance(encoding.pixel_values, torch.LongTensor)
|
||||
self.assertEqual(encoding.pixel_values.shape, (2, 1024))
|
||||
self.assertIsInstance(encoding.input_ids, torch.LongTensor)
|
||||
self.assertEqual(encoding.input_ids.shape, (2, 1024))
|
||||
|
||||
expected_slice = [303, 13, 13]
|
||||
self.assertEqual(encoding.pixel_values[1, -3:].tolist(), expected_slice)
|
||||
self.assertEqual(encoding.input_ids[1, -3:].tolist(), expected_slice)
|
||||
|
||||
Reference in New Issue
Block a user