[DETA] fix backbone freeze/unfreeze function (#27843)
* [DETA] fix freeze/unfreeze function * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add freeze/unfreeze test case in DETA * fix type * fix typo 2 --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
df5c5c62ae
commit
235be08569
@@ -1414,14 +1414,12 @@ class DetaModel(DetaPreTrainedModel):
|
|||||||
def get_decoder(self):
|
def get_decoder(self):
|
||||||
return self.decoder
|
return self.decoder
|
||||||
|
|
||||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.freeze_backbone
|
|
||||||
def freeze_backbone(self):
|
def freeze_backbone(self):
|
||||||
for name, param in self.backbone.conv_encoder.model.named_parameters():
|
for name, param in self.backbone.model.named_parameters():
|
||||||
param.requires_grad_(False)
|
param.requires_grad_(False)
|
||||||
|
|
||||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.unfreeze_backbone
|
|
||||||
def unfreeze_backbone(self):
|
def unfreeze_backbone(self):
|
||||||
for name, param in self.backbone.conv_encoder.model.named_parameters():
|
for name, param in self.backbone.model.named_parameters():
|
||||||
param.requires_grad_(True)
|
param.requires_grad_(True)
|
||||||
|
|
||||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio
|
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio
|
||||||
|
|||||||
@@ -162,6 +162,26 @@ class DetaModelTester:
|
|||||||
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_deta_freeze_backbone(self, config, pixel_values, pixel_mask, labels):
|
||||||
|
model = DetaModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model.freeze_backbone()
|
||||||
|
|
||||||
|
for _, param in model.backbone.model.named_parameters():
|
||||||
|
self.parent.assertEqual(False, param.requires_grad)
|
||||||
|
|
||||||
|
def create_and_check_deta_unfreeze_backbone(self, config, pixel_values, pixel_mask, labels):
|
||||||
|
model = DetaModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model.unfreeze_backbone()
|
||||||
|
|
||||||
|
for _, param in model.backbone.model.named_parameters():
|
||||||
|
self.parent.assertEqual(True, param.requires_grad)
|
||||||
|
|
||||||
def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
|
def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
|
||||||
model = DetaForObjectDetection(config=config)
|
model = DetaForObjectDetection(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -250,6 +270,14 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_deta_model(*config_and_inputs)
|
self.model_tester.create_and_check_deta_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_deta_freeze_backbone(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_deta_unfreeze_backbone(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)
|
||||||
|
|
||||||
def test_deta_object_detection_head_model(self):
|
def test_deta_object_detection_head_model(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
|
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user