[DETA] Improvement and Sync from DETA especially for training (#27990)
* [DETA] fix freeze/unfreeze function
* Update src/transformers/models/deta/modeling_deta.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/deta/modeling_deta.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* add freeze/unfreeze test case in DETA
* fix type
* fix typo 2
* fix : enable aux and enc loss in training pipeline
* Add unsynced variables from original DETA for training
* modification for passing CI test
* make style
* make fix
* manual make fix
* change deta_modeling_test of configuration 'two_stage' default to TRUE and minor change of dist checking
* remove print
* divide configuration in DetaModel and DetaForObjectDetection
* image smaller size than 224 will give topk error
* pred_boxes and logits should be equivalent to two_stage_num_proposals
* add missing part in DetaConfig
* Update src/transformers/models/deta/modeling_deta.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* add docstring in configure and prettify TO DO part
* change distribute related code to accelerate
* Update src/transformers/models/deta/configuration_deta.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/deta/test_modeling_deta.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* protect importing accelerate
* change variable name to specific value
* wrong import

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
57e9c83213
commit
899d8351f9
@@ -57,14 +57,17 @@ class DetaModelTester:
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
num_queries=12,
|
||||
two_stage_num_proposals=12,
|
||||
num_channels=3,
|
||||
image_size=196,
|
||||
image_size=224,
|
||||
n_targets=8,
|
||||
num_labels=91,
|
||||
num_feature_levels=4,
|
||||
encoder_n_points=2,
|
||||
decoder_n_points=6,
|
||||
two_stage=False,
|
||||
two_stage=True,
|
||||
assign_first_stage=True,
|
||||
assign_second_stage=True,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -78,6 +81,7 @@ class DetaModelTester:
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.num_queries = num_queries
|
||||
self.two_stage_num_proposals = two_stage_num_proposals
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.n_targets = n_targets
|
||||
@@ -86,6 +90,8 @@ class DetaModelTester:
|
||||
self.encoder_n_points = encoder_n_points
|
||||
self.decoder_n_points = decoder_n_points
|
||||
self.two_stage = two_stage
|
||||
self.assign_first_stage = assign_first_stage
|
||||
self.assign_second_stage = assign_second_stage
|
||||
|
||||
# we also set the expected seq length for both encoder and decoder
|
||||
self.encoder_seq_length = (
|
||||
@@ -96,7 +102,7 @@ class DetaModelTester:
|
||||
)
|
||||
self.decoder_seq_length = self.num_queries
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
def prepare_config_and_inputs(self, model_class_name):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
|
||||
@@ -114,10 +120,10 @@ class DetaModelTester:
|
||||
target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device)
|
||||
labels.append(target)
|
||||
|
||||
config = self.get_config()
|
||||
config = self.get_config(model_class_name)
|
||||
return config, pixel_values, pixel_mask, labels
|
||||
|
||||
def get_config(self):
|
||||
def get_config(self, model_class_name):
|
||||
resnet_config = ResNetConfig(
|
||||
num_channels=3,
|
||||
embeddings_size=10,
|
||||
@@ -128,6 +134,9 @@ class DetaModelTester:
|
||||
out_features=["stage2", "stage3", "stage4"],
|
||||
out_indices=[2, 3, 4],
|
||||
)
|
||||
two_stage = model_class_name == "DetaForObjectDetection"
|
||||
assign_first_stage = model_class_name == "DetaForObjectDetection"
|
||||
assign_second_stage = model_class_name == "DetaForObjectDetection"
|
||||
return DetaConfig(
|
||||
d_model=self.hidden_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
@@ -139,16 +148,19 @@ class DetaModelTester:
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
num_queries=self.num_queries,
|
||||
two_stage_num_proposals=self.two_stage_num_proposals,
|
||||
num_labels=self.num_labels,
|
||||
num_feature_levels=self.num_feature_levels,
|
||||
encoder_n_points=self.encoder_n_points,
|
||||
decoder_n_points=self.decoder_n_points,
|
||||
two_stage=self.two_stage,
|
||||
two_stage=two_stage,
|
||||
assign_first_stage=assign_first_stage,
|
||||
assign_second_stage=assign_second_stage,
|
||||
backbone_config=resnet_config,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
|
||||
def prepare_config_and_inputs_for_common(self, model_class_name="DetaModel"):
|
||||
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs(model_class_name)
|
||||
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -190,14 +202,14 @@ class DetaModelTester:
|
||||
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||
result = model(pixel_values)
|
||||
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels))
|
||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4))
|
||||
|
||||
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
|
||||
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels))
|
||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4))
|
||||
|
||||
|
||||
@require_torchvision
|
||||
@@ -267,19 +279,19 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
self.config_tester.check_config_can_be_init_without_params()
|
||||
|
||||
def test_deta_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
|
||||
self.model_tester.create_and_check_deta_model(*config_and_inputs)
|
||||
|
||||
def test_deta_freeze_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
|
||||
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)
|
||||
|
||||
def test_deta_unfreeze_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
|
||||
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)
|
||||
|
||||
def test_deta_object_detection_head_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaForObjectDetection")
|
||||
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="DETA does not use inputs_embeds")
|
||||
|
||||
Reference in New Issue
Block a user