Clean up semantic segmentation tests (#16801)
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -244,13 +244,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
# we don't test BeitForMaskedImageModeling
|
# we don't test BeitForMaskedImageModeling
|
||||||
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
|
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
|
||||||
continue
|
continue
|
||||||
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
|
||||||
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
|
||||||
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
|
||||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
|
||||||
inputs_dict["labels"] = torch.zeros(
|
|
||||||
[self.model_tester.batch_size, height, width], device=torch_device
|
|
||||||
).long()
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
|
|||||||
@@ -316,13 +316,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if model_class in get_values(MODEL_MAPPING):
|
if model_class in get_values(MODEL_MAPPING):
|
||||||
continue
|
continue
|
||||||
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
|
||||||
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
|
||||||
if model_class.__name__ == "SegformerForSemanticSegmentation":
|
|
||||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
|
||||||
inputs_dict["labels"] = torch.zeros(
|
|
||||||
[self.model_tester.batch_size, height, width], device=torch_device
|
|
||||||
).long()
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user