Update old existing feature extractor references (#24552)
* Update old existing feature extractor references * Typo * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * Address comments from review - update 'feature extractor' Co-authored by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -49,7 +49,7 @@ if is_vision_available():
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from transformers import BeitFeatureExtractor
|
||||
from transformers import BeitImageProcessor
|
||||
|
||||
|
||||
class BeitModelTester:
|
||||
@@ -342,18 +342,16 @@ def prepare_img():
|
||||
@require_vision
|
||||
class BeitModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return (
|
||||
BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None
|
||||
)
|
||||
def default_image_processor(self):
|
||||
return BeitImageProcessor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None
|
||||
|
||||
@slow
|
||||
def test_inference_masked_image_modeling_head(self):
|
||||
model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k").to(torch_device)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image_processor = self.default_image_processor
|
||||
image = prepare_img()
|
||||
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
|
||||
# prepare bool_masked_pos
|
||||
bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)
|
||||
@@ -377,9 +375,9 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_image_classification_head_imagenet_1k(self):
|
||||
model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224").to(torch_device)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image_processor = self.default_image_processor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
@@ -403,9 +401,9 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
torch_device
|
||||
)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image_processor = self.default_image_processor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
@@ -428,11 +426,11 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
|
||||
model = model.to(torch_device)
|
||||
|
||||
feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
|
||||
image_processor = BeitImageProcessor(do_resize=True, size=640, do_center_crop=False)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
image = Image.open(ds[0]["file"])
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
@@ -471,11 +469,11 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
|
||||
model = model.to(torch_device)
|
||||
|
||||
feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
|
||||
image_processor = BeitImageProcessor(do_resize=True, size=640, do_center_crop=False)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
image = Image.open(ds[0]["file"])
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
@@ -483,10 +481,10 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
outputs.logits = outputs.logits.detach().cpu()
|
||||
|
||||
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
|
||||
segmentation = image_processor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
|
||||
expected_shape = torch.Size((500, 300))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
|
||||
segmentation = image_processor.post_process_semantic_segmentation(outputs=outputs)
|
||||
expected_shape = torch.Size((160, 160))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user