@@ -92,6 +92,7 @@ from transformers.testing_utils import (
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_accelerators,
|
||||
require_torchdynamo,
|
||||
require_vision,
|
||||
require_wandb,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -3812,6 +3813,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
reloaded_tokenizer(test_sentence, padding="max_length").input_ids,
|
||||
)
|
||||
|
||||
@require_vision
|
||||
def test_trainer_saves_image_processor(self):
|
||||
MODEL_ID = "openai/clip-vit-base-patch32"
|
||||
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
|
||||
@@ -3845,6 +3847,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
self.assertDictEqual(feature_extractor.to_dict(), reloaded_feature_extractor.to_dict())
|
||||
|
||||
@require_vision
|
||||
def test_trainer_saves_processor(self):
|
||||
MODEL_ID = "openai/clip-vit-base-patch32"
|
||||
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
|
||||
|
||||
Reference in New Issue
Block a user