[DETA] Improvement and Sync from DETA especially for training (#27990)
* [DETA] fix freeze/unfreeze function * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add freeze/unfreeze test case in DETA * fix type * fix typo 2 * fix : enable aux and enc loss in training pipeline * Add unsynced variables from original DETA for training * modification for passing CI test * make style * make fix * manual make fix * change deta_modeling_test of configuration 'two_stage' default to TRUE and minor change of dist checking * remove print * divide configuration in DetaModel and DetaForObjectDetection * image smaller size than 224 will give topk error * pred_boxes and logits should be equivalent to two_stage_num_proposals * add missing part in DetaConfig * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add docstring in configure and prettify TO DO part * change distribute related code to accelerate * Update src/transformers/models/deta/configuration_deta.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/deta/test_modeling_deta.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * protect importing accelerate * change variable name to specific value * wrong import --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
57e9c83213
commit
899d8351f9
@@ -109,6 +109,10 @@ class DetaConfig(PretrainedConfig):
|
|||||||
based on the predictions from the previous layer.
|
based on the predictions from the previous layer.
|
||||||
focal_alpha (`float`, *optional*, defaults to 0.25):
|
focal_alpha (`float`, *optional*, defaults to 0.25):
|
||||||
Alpha parameter in the focal loss.
|
Alpha parameter in the focal loss.
|
||||||
|
assign_first_stage (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to assign each prediction i to the highest overlapping ground truth object if the overlap is larger than a threshold 0.7.
|
||||||
|
assign_second_stage (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to assign second assignment procedure in the second stage closely follows the first stage assignment procedure.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -161,6 +165,7 @@ class DetaConfig(PretrainedConfig):
|
|||||||
two_stage_num_proposals=300,
|
two_stage_num_proposals=300,
|
||||||
with_box_refine=True,
|
with_box_refine=True,
|
||||||
assign_first_stage=True,
|
assign_first_stage=True,
|
||||||
|
assign_second_stage=True,
|
||||||
class_cost=1,
|
class_cost=1,
|
||||||
bbox_cost=5,
|
bbox_cost=5,
|
||||||
giou_cost=2,
|
giou_cost=2,
|
||||||
@@ -208,6 +213,7 @@ class DetaConfig(PretrainedConfig):
|
|||||||
self.two_stage_num_proposals = two_stage_num_proposals
|
self.two_stage_num_proposals = two_stage_num_proposals
|
||||||
self.with_box_refine = with_box_refine
|
self.with_box_refine = with_box_refine
|
||||||
self.assign_first_stage = assign_first_stage
|
self.assign_first_stage = assign_first_stage
|
||||||
|
self.assign_second_stage = assign_second_stage
|
||||||
if two_stage is True and with_box_refine is False:
|
if two_stage is True and with_box_refine is False:
|
||||||
raise ValueError("If two_stage is True, with_box_refine must be True.")
|
raise ValueError("If two_stage is True, with_box_refine must be True.")
|
||||||
# Hungarian matcher
|
# Hungarian matcher
|
||||||
|
|||||||
@@ -1052,7 +1052,7 @@ class DetaImageProcessor(BaseImageProcessor):
|
|||||||
score = all_scores[b]
|
score = all_scores[b]
|
||||||
lbls = all_labels[b]
|
lbls = all_labels[b]
|
||||||
|
|
||||||
pre_topk = score.topk(min(10000, len(score))).indices
|
pre_topk = score.topk(min(10000, num_queries * num_labels)).indices
|
||||||
box = box[pre_topk]
|
box = box[pre_topk]
|
||||||
score = score[pre_topk]
|
score = score[pre_topk]
|
||||||
lbls = lbls[pre_topk]
|
lbls = lbls[pre_topk]
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|||||||
from ...modeling_outputs import BaseModelOutput
|
from ...modeling_outputs import BaseModelOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import meshgrid
|
from ...pytorch_utils import meshgrid
|
||||||
from ...utils import is_torchvision_available, logging, requires_backends
|
from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends
|
||||||
from ..auto import AutoBackbone
|
from ..auto import AutoBackbone
|
||||||
from .configuration_deta import DetaConfig
|
from .configuration_deta import DetaConfig
|
||||||
|
|
||||||
@@ -46,6 +46,10 @@ from .configuration_deta import DetaConfig
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if is_accelerate_available():
|
||||||
|
from accelerate import PartialState
|
||||||
|
from accelerate.utils import reduce
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from transformers.image_transforms import center_to_corners_format
|
from transformers.image_transforms import center_to_corners_format
|
||||||
|
|
||||||
@@ -105,7 +109,6 @@ class DetaDecoderOutput(ModelOutput):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModelOutput with DeformableDetr->Deta,Deformable DETR->DETA
|
|
||||||
class DetaModelOutput(ModelOutput):
|
class DetaModelOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for outputs of the Deformable DETR encoder-decoder model.
|
Base class for outputs of the Deformable DETR encoder-decoder model.
|
||||||
@@ -147,6 +150,8 @@ class DetaModelOutput(ModelOutput):
|
|||||||
foreground and background).
|
foreground and background).
|
||||||
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
|
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
|
||||||
Logits of predicted bounding boxes coordinates in the first stage.
|
Logits of predicted bounding boxes coordinates in the first stage.
|
||||||
|
output_proposals (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
|
||||||
|
Logits of proposal bounding boxes coordinates in the gen_encoder_output_proposals.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
init_reference_points: torch.FloatTensor = None
|
init_reference_points: torch.FloatTensor = None
|
||||||
@@ -161,10 +166,10 @@ class DetaModelOutput(ModelOutput):
|
|||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
enc_outputs_class: Optional[torch.FloatTensor] = None
|
enc_outputs_class: Optional[torch.FloatTensor] = None
|
||||||
enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
|
enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
|
||||||
|
output_proposals: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrObjectDetectionOutput with DeformableDetr->Deta
|
|
||||||
class DetaObjectDetectionOutput(ModelOutput):
|
class DetaObjectDetectionOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Output type of [`DetaForObjectDetection`].
|
Output type of [`DetaForObjectDetection`].
|
||||||
@@ -223,6 +228,8 @@ class DetaObjectDetectionOutput(ModelOutput):
|
|||||||
foreground and background).
|
foreground and background).
|
||||||
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
|
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
|
||||||
Logits of predicted bounding boxes coordinates in the first stage.
|
Logits of predicted bounding boxes coordinates in the first stage.
|
||||||
|
output_proposals (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
|
||||||
|
Logits of proposal bounding boxes coordinates in the gen_encoder_output_proposals.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
@@ -242,6 +249,7 @@ class DetaObjectDetectionOutput(ModelOutput):
|
|||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
enc_outputs_class: Optional = None
|
enc_outputs_class: Optional = None
|
||||||
enc_outputs_coord_logits: Optional = None
|
enc_outputs_coord_logits: Optional = None
|
||||||
|
output_proposals: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
def _get_clones(module, N):
|
def _get_clones(module, N):
|
||||||
@@ -1632,6 +1640,7 @@ class DetaModel(DetaPreTrainedModel):
|
|||||||
batch_size, _, num_channels = encoder_outputs[0].shape
|
batch_size, _, num_channels = encoder_outputs[0].shape
|
||||||
enc_outputs_class = None
|
enc_outputs_class = None
|
||||||
enc_outputs_coord_logits = None
|
enc_outputs_coord_logits = None
|
||||||
|
output_proposals = None
|
||||||
if self.config.two_stage:
|
if self.config.two_stage:
|
||||||
object_query_embedding, output_proposals, level_ids = self.gen_encoder_output_proposals(
|
object_query_embedding, output_proposals, level_ids = self.gen_encoder_output_proposals(
|
||||||
encoder_outputs[0], ~mask_flatten, spatial_shapes
|
encoder_outputs[0], ~mask_flatten, spatial_shapes
|
||||||
@@ -1746,6 +1755,7 @@ class DetaModel(DetaPreTrainedModel):
|
|||||||
encoder_attentions=encoder_outputs.attentions,
|
encoder_attentions=encoder_outputs.attentions,
|
||||||
enc_outputs_class=enc_outputs_class,
|
enc_outputs_class=enc_outputs_class,
|
||||||
enc_outputs_coord_logits=enc_outputs_coord_logits,
|
enc_outputs_coord_logits=enc_outputs_coord_logits,
|
||||||
|
output_proposals=output_proposals,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1804,12 +1814,15 @@ class DetaForObjectDetection(DetaPreTrainedModel):
|
|||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@torch.jit.unused
|
@torch.jit.unused
|
||||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection._set_aux_loss
|
|
||||||
def _set_aux_loss(self, outputs_class, outputs_coord):
|
def _set_aux_loss(self, outputs_class, outputs_coord):
|
||||||
# this is a workaround to make torchscript happy, as torchscript
|
# this is a workaround to make torchscript happy, as torchscript
|
||||||
# doesn't support dictionary with non-homogeneous values, such
|
# doesn't support dictionary with non-homogeneous values, such
|
||||||
# as a dict having both a Tensor and a list.
|
# as a dict having both a Tensor and a list.
|
||||||
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
aux_loss = [
|
||||||
|
{"logits": logits, "pred_boxes": pred_boxes}
|
||||||
|
for logits, pred_boxes in zip(outputs_class.transpose(0, 1)[:-1], outputs_coord.transpose(0, 1)[:-1])
|
||||||
|
]
|
||||||
|
return aux_loss
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(DETA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(DETA_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=DetaObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=DetaObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
@@ -1929,21 +1942,25 @@ class DetaForObjectDetection(DetaPreTrainedModel):
|
|||||||
focal_alpha=self.config.focal_alpha,
|
focal_alpha=self.config.focal_alpha,
|
||||||
losses=losses,
|
losses=losses,
|
||||||
num_queries=self.config.num_queries,
|
num_queries=self.config.num_queries,
|
||||||
|
assign_first_stage=self.config.assign_first_stage,
|
||||||
|
assign_second_stage=self.config.assign_second_stage,
|
||||||
)
|
)
|
||||||
criterion.to(logits.device)
|
criterion.to(logits.device)
|
||||||
# Third: compute the losses, based on outputs and labels
|
# Third: compute the losses, based on outputs and labels
|
||||||
outputs_loss = {}
|
outputs_loss = {}
|
||||||
outputs_loss["logits"] = logits
|
outputs_loss["logits"] = logits
|
||||||
outputs_loss["pred_boxes"] = pred_boxes
|
outputs_loss["pred_boxes"] = pred_boxes
|
||||||
|
outputs_loss["init_reference"] = init_reference
|
||||||
if self.config.auxiliary_loss:
|
if self.config.auxiliary_loss:
|
||||||
intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
|
|
||||||
outputs_class = self.class_embed(intermediate)
|
|
||||||
outputs_coord = self.bbox_embed(intermediate).sigmoid()
|
|
||||||
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
||||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||||
if self.config.two_stage:
|
if self.config.two_stage:
|
||||||
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
|
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
|
||||||
outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
|
outputs_loss["enc_outputs"] = {
|
||||||
|
"logits": outputs.enc_outputs_class,
|
||||||
|
"pred_boxes": enc_outputs_coord,
|
||||||
|
"anchors": outputs.output_proposals.sigmoid(),
|
||||||
|
}
|
||||||
|
|
||||||
loss_dict = criterion(outputs_loss, labels)
|
loss_dict = criterion(outputs_loss, labels)
|
||||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||||
@@ -1953,6 +1970,7 @@ class DetaForObjectDetection(DetaPreTrainedModel):
|
|||||||
aux_weight_dict = {}
|
aux_weight_dict = {}
|
||||||
for i in range(self.config.decoder_layers - 1):
|
for i in range(self.config.decoder_layers - 1):
|
||||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||||
|
aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()})
|
||||||
weight_dict.update(aux_weight_dict)
|
weight_dict.update(aux_weight_dict)
|
||||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||||
|
|
||||||
@@ -1983,6 +2001,7 @@ class DetaForObjectDetection(DetaPreTrainedModel):
|
|||||||
init_reference_points=outputs.init_reference_points,
|
init_reference_points=outputs.init_reference_points,
|
||||||
enc_outputs_class=outputs.enc_outputs_class,
|
enc_outputs_class=outputs.enc_outputs_class,
|
||||||
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
|
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
|
||||||
|
output_proposals=outputs.output_proposals,
|
||||||
)
|
)
|
||||||
|
|
||||||
return dict_outputs
|
return dict_outputs
|
||||||
@@ -2192,7 +2211,7 @@ class DetaLoss(nn.Module):
|
|||||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||||
losses applied, see each loss' doc.
|
losses applied, see each loss' doc.
|
||||||
"""
|
"""
|
||||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
|
outputs_without_aux = {k: v for k, v in outputs.items() if k not in ("auxiliary_outputs", "enc_outputs")}
|
||||||
|
|
||||||
# Retrieve the matching between the outputs of the last layer and the targets
|
# Retrieve the matching between the outputs of the last layer and the targets
|
||||||
if self.assign_second_stage:
|
if self.assign_second_stage:
|
||||||
@@ -2203,11 +2222,12 @@ class DetaLoss(nn.Module):
|
|||||||
# Compute the average number of target boxes accross all nodes, for normalization purposes
|
# Compute the average number of target boxes accross all nodes, for normalization purposes
|
||||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||||
# (Niels): comment out function below, distributed training to be added
|
# Check that we have initialized the distributed state
|
||||||
# if is_dist_avail_and_initialized():
|
world_size = 1
|
||||||
# torch.distributed.all_reduce(num_boxes)
|
if PartialState._shared_state != {}:
|
||||||
# (Niels) in original implementation, num_boxes is divided by get_world_size()
|
num_boxes = reduce(num_boxes)
|
||||||
num_boxes = torch.clamp(num_boxes, min=1).item()
|
world_size = PartialState().num_processes
|
||||||
|
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||||
|
|
||||||
# Compute all the requested losses
|
# Compute all the requested losses
|
||||||
losses = {}
|
losses = {}
|
||||||
@@ -2228,17 +2248,13 @@ class DetaLoss(nn.Module):
|
|||||||
enc_outputs = outputs["enc_outputs"]
|
enc_outputs = outputs["enc_outputs"]
|
||||||
bin_targets = copy.deepcopy(targets)
|
bin_targets = copy.deepcopy(targets)
|
||||||
for bt in bin_targets:
|
for bt in bin_targets:
|
||||||
bt["labels"] = torch.zeros_like(bt["labels"])
|
bt["class_labels"] = torch.zeros_like(bt["class_labels"])
|
||||||
if self.assign_first_stage:
|
if self.assign_first_stage:
|
||||||
indices = self.stg1_assigner(enc_outputs, bin_targets)
|
indices = self.stg1_assigner(enc_outputs, bin_targets)
|
||||||
else:
|
else:
|
||||||
indices = self.matcher(enc_outputs, bin_targets)
|
indices = self.matcher(enc_outputs, bin_targets)
|
||||||
for loss in self.losses:
|
for loss in self.losses:
|
||||||
kwargs = {}
|
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes)
|
||||||
if loss == "labels":
|
|
||||||
# Logging is enabled only for the last layer
|
|
||||||
kwargs["log"] = False
|
|
||||||
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
|
|
||||||
l_dict = {k + "_enc": v for k, v in l_dict.items()}
|
l_dict = {k + "_enc": v for k, v in l_dict.items()}
|
||||||
losses.update(l_dict)
|
losses.update(l_dict)
|
||||||
|
|
||||||
@@ -2662,7 +2678,7 @@ class DetaStage2Assigner(nn.Module):
|
|||||||
sampled_idxs,
|
sampled_idxs,
|
||||||
sampled_gt_classes,
|
sampled_gt_classes,
|
||||||
) = self._sample_proposals( # list of sampled proposal_ids, sampled_id -> [0, num_classes)+[bg_label]
|
) = self._sample_proposals( # list of sampled proposal_ids, sampled_id -> [0, num_classes)+[bg_label]
|
||||||
matched_idxs, matched_labels, targets[b]["labels"]
|
matched_idxs, matched_labels, targets[b]["class_labels"]
|
||||||
)
|
)
|
||||||
pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label]
|
pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label]
|
||||||
pos_gt_inds = matched_idxs[pos_pr_inds]
|
pos_gt_inds = matched_idxs[pos_pr_inds]
|
||||||
@@ -2727,7 +2743,7 @@ class DetaStage1Assigner(nn.Module):
|
|||||||
) # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.7, 0 if iou < 0.3, -1 ow]
|
) # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.7, 0 if iou < 0.3, -1 ow]
|
||||||
matched_labels = self._subsample_labels(matched_labels)
|
matched_labels = self._subsample_labels(matched_labels)
|
||||||
|
|
||||||
all_pr_inds = torch.arange(len(anchors))
|
all_pr_inds = torch.arange(len(anchors), device=matched_labels.device)
|
||||||
pos_pr_inds = all_pr_inds[matched_labels == 1]
|
pos_pr_inds = all_pr_inds[matched_labels == 1]
|
||||||
pos_gt_inds = matched_idxs[pos_pr_inds]
|
pos_gt_inds = matched_idxs[pos_pr_inds]
|
||||||
pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
|
pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
|
||||||
|
|||||||
@@ -57,14 +57,17 @@ class DetaModelTester:
|
|||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
num_queries=12,
|
num_queries=12,
|
||||||
|
two_stage_num_proposals=12,
|
||||||
num_channels=3,
|
num_channels=3,
|
||||||
image_size=196,
|
image_size=224,
|
||||||
n_targets=8,
|
n_targets=8,
|
||||||
num_labels=91,
|
num_labels=91,
|
||||||
num_feature_levels=4,
|
num_feature_levels=4,
|
||||||
encoder_n_points=2,
|
encoder_n_points=2,
|
||||||
decoder_n_points=6,
|
decoder_n_points=6,
|
||||||
two_stage=False,
|
two_stage=True,
|
||||||
|
assign_first_stage=True,
|
||||||
|
assign_second_stage=True,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -78,6 +81,7 @@ class DetaModelTester:
|
|||||||
self.hidden_dropout_prob = hidden_dropout_prob
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
self.num_queries = num_queries
|
self.num_queries = num_queries
|
||||||
|
self.two_stage_num_proposals = two_stage_num_proposals
|
||||||
self.num_channels = num_channels
|
self.num_channels = num_channels
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.n_targets = n_targets
|
self.n_targets = n_targets
|
||||||
@@ -86,6 +90,8 @@ class DetaModelTester:
|
|||||||
self.encoder_n_points = encoder_n_points
|
self.encoder_n_points = encoder_n_points
|
||||||
self.decoder_n_points = decoder_n_points
|
self.decoder_n_points = decoder_n_points
|
||||||
self.two_stage = two_stage
|
self.two_stage = two_stage
|
||||||
|
self.assign_first_stage = assign_first_stage
|
||||||
|
self.assign_second_stage = assign_second_stage
|
||||||
|
|
||||||
# we also set the expected seq length for both encoder and decoder
|
# we also set the expected seq length for both encoder and decoder
|
||||||
self.encoder_seq_length = (
|
self.encoder_seq_length = (
|
||||||
@@ -96,7 +102,7 @@ class DetaModelTester:
|
|||||||
)
|
)
|
||||||
self.decoder_seq_length = self.num_queries
|
self.decoder_seq_length = self.num_queries
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self, model_class_name):
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
|
||||||
pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
|
pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
|
||||||
@@ -114,10 +120,10 @@ class DetaModelTester:
|
|||||||
target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device)
|
target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device)
|
||||||
labels.append(target)
|
labels.append(target)
|
||||||
|
|
||||||
config = self.get_config()
|
config = self.get_config(model_class_name)
|
||||||
return config, pixel_values, pixel_mask, labels
|
return config, pixel_values, pixel_mask, labels
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self, model_class_name):
|
||||||
resnet_config = ResNetConfig(
|
resnet_config = ResNetConfig(
|
||||||
num_channels=3,
|
num_channels=3,
|
||||||
embeddings_size=10,
|
embeddings_size=10,
|
||||||
@@ -128,6 +134,9 @@ class DetaModelTester:
|
|||||||
out_features=["stage2", "stage3", "stage4"],
|
out_features=["stage2", "stage3", "stage4"],
|
||||||
out_indices=[2, 3, 4],
|
out_indices=[2, 3, 4],
|
||||||
)
|
)
|
||||||
|
two_stage = model_class_name == "DetaForObjectDetection"
|
||||||
|
assign_first_stage = model_class_name == "DetaForObjectDetection"
|
||||||
|
assign_second_stage = model_class_name == "DetaForObjectDetection"
|
||||||
return DetaConfig(
|
return DetaConfig(
|
||||||
d_model=self.hidden_size,
|
d_model=self.hidden_size,
|
||||||
encoder_layers=self.num_hidden_layers,
|
encoder_layers=self.num_hidden_layers,
|
||||||
@@ -139,16 +148,19 @@ class DetaModelTester:
|
|||||||
dropout=self.hidden_dropout_prob,
|
dropout=self.hidden_dropout_prob,
|
||||||
attention_dropout=self.attention_probs_dropout_prob,
|
attention_dropout=self.attention_probs_dropout_prob,
|
||||||
num_queries=self.num_queries,
|
num_queries=self.num_queries,
|
||||||
|
two_stage_num_proposals=self.two_stage_num_proposals,
|
||||||
num_labels=self.num_labels,
|
num_labels=self.num_labels,
|
||||||
num_feature_levels=self.num_feature_levels,
|
num_feature_levels=self.num_feature_levels,
|
||||||
encoder_n_points=self.encoder_n_points,
|
encoder_n_points=self.encoder_n_points,
|
||||||
decoder_n_points=self.decoder_n_points,
|
decoder_n_points=self.decoder_n_points,
|
||||||
two_stage=self.two_stage,
|
two_stage=two_stage,
|
||||||
|
assign_first_stage=assign_first_stage,
|
||||||
|
assign_second_stage=assign_second_stage,
|
||||||
backbone_config=resnet_config,
|
backbone_config=resnet_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self, model_class_name="DetaModel"):
|
||||||
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
|
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs(model_class_name)
|
||||||
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
|
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -190,14 +202,14 @@ class DetaModelTester:
|
|||||||
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
|
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
|
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels))
|
||||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4))
|
||||||
|
|
||||||
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
|
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
|
||||||
|
|
||||||
self.parent.assertEqual(result.loss.shape, ())
|
self.parent.assertEqual(result.loss.shape, ())
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels))
|
||||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4))
|
||||||
|
|
||||||
|
|
||||||
@require_torchvision
|
@require_torchvision
|
||||||
@@ -267,19 +279,19 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
self.config_tester.check_config_can_be_init_without_params()
|
self.config_tester.check_config_can_be_init_without_params()
|
||||||
|
|
||||||
def test_deta_model(self):
|
def test_deta_model(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
|
||||||
self.model_tester.create_and_check_deta_model(*config_and_inputs)
|
self.model_tester.create_and_check_deta_model(*config_and_inputs)
|
||||||
|
|
||||||
def test_deta_freeze_backbone(self):
|
def test_deta_freeze_backbone(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
|
||||||
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)
|
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)
|
||||||
|
|
||||||
def test_deta_unfreeze_backbone(self):
|
def test_deta_unfreeze_backbone(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
|
||||||
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)
|
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)
|
||||||
|
|
||||||
def test_deta_object_detection_head_model(self):
|
def test_deta_object_detection_head_model(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaForObjectDetection")
|
||||||
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
|
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
|
||||||
|
|
||||||
@unittest.skip(reason="DETA does not use inputs_embeds")
|
@unittest.skip(reason="DETA does not use inputs_embeds")
|
||||||
|
|||||||
Reference in New Issue
Block a user