Restructure DETR post-processing, return prediction scores (#19262)
* Restructure DetrFeatureExtractor post-processing methods * Update post_process_instance_segmentation and post_process_panoptic_segmentation methods to return prediction scores * Update DETR models docs
This commit is contained in:
@@ -171,9 +171,9 @@ mean Average Precision (mAP) and Panoptic Quality (PQ). The latter objects are i
|
|||||||
[[autodoc]] DetrFeatureExtractor
|
[[autodoc]] DetrFeatureExtractor
|
||||||
- __call__
|
- __call__
|
||||||
- pad_and_create_pixel_mask
|
- pad_and_create_pixel_mask
|
||||||
- post_process
|
- post_process_semantic_segmentation
|
||||||
- post_process_segmentation
|
- post_process_instance_segmentation
|
||||||
- post_process_panoptic
|
- post_process_panoptic_segmentation
|
||||||
|
|
||||||
## DetrModel
|
## DetrModel
|
||||||
|
|
||||||
|
|||||||
@@ -141,11 +141,33 @@ def binary_mask_to_rle(mask):
|
|||||||
return [x for x in runs]
|
return [x for x in runs]
|
||||||
|
|
||||||
|
|
||||||
|
def convert_segmentation_to_rle(segmentation):
    """
    Converts given segmentation map of shape (height, width) to the run-length encoding (RLE) format.

    Args:
        segmentation (`torch.Tensor` or `numpy.array`):
            A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
    Returns:
        `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id. The
        encodings are ordered by ascending id, as returned by `torch.unique`.
    """
    # `torch.unique` / `torch.where` only accept tensors; normalising here makes the documented
    # `numpy.array` input actually work. For tensor input this is a no-op (no copy is made).
    segmentation = torch.as_tensor(segmentation)

    # One binary mask (1 where the pixel belongs to the id, 0 elsewhere) is encoded per distinct id.
    return [
        binary_mask_to_rle(torch.where(segmentation == segment_id, 1, 0))
        for segment_id in torch.unique(segmentation)
    ]
|
||||||
|
|
||||||
|
|
||||||
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
|
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
|
||||||
"""
|
"""
|
||||||
|
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
|
||||||
|
`labels`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores`
|
|
||||||
and `labels`.
|
|
||||||
masks (`torch.Tensor`):
|
masks (`torch.Tensor`):
|
||||||
A tensor of shape `(num_queries, height, width)`.
|
A tensor of shape `(num_queries, height, width)`.
|
||||||
scores (`torch.Tensor`):
|
scores (`torch.Tensor`):
|
||||||
@@ -168,6 +190,81 @@ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_
|
|||||||
return masks[to_keep], scores[to_keep], labels[to_keep]
|
return masks[to_keep], scores[to_keep], labels[to_keep]
|
||||||
|
|
||||||
|
|
||||||
|
def check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold=0.8, mask_threshold=0.5):
    """
    Decides whether query `k` yields a segment that should be kept in the final segmentation map.

    Args:
        mask_labels (`torch.Tensor`):
            Per-pixel index of the winning query (the callers in this file pass `mask_probs.argmax(0)`).
        mask_probs (`torch.Tensor`):
            Per-query mask probabilities; `mask_probs[k]` is the mask of query `k`.
        k (`int`):
            Index of the query to check.
        overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
            The segment is kept only if the area won by query `k` divided by the area of its own binarized mask
            is strictly greater than this value.
        mask_threshold (`float`, *optional*, defaults to 0.5):
            Cutoff used to binarize `mask_probs[k]` when measuring the area of the query's own mask.

    Returns:
        `Tuple[bool, torch.Tensor]`: Whether the segment is valid (always a plain Python `bool`), and the boolean
        mask of the pixels assigned to query `k`.
    """
    # Pixels for which query k is the winning query
    mask_k = mask_labels == k
    mask_k_area = mask_k.sum()

    # Area of query k's own mask, regardless of which query won each pixel
    original_area = (mask_probs[k] >= mask_threshold).sum()

    # Nothing to keep if the query won no pixel or its own mask is empty (also avoids dividing by zero)
    if mask_k_area.item() == 0 or original_area.item() == 0:
        return False, mask_k

    # Eliminate disconnected tiny segments: the query must retain a large enough share of its own mask
    area_ratio = mask_k_area / original_area
    mask_exists = area_ratio.item() > overlap_mask_area_threshold

    return mask_exists, mask_k
|
||||||
|
|
||||||
|
|
||||||
|
def compute_segments(
    mask_probs,
    pred_scores,
    pred_labels,
    overlap_mask_area_threshold: float = 0.8,
    label_ids_to_fuse: Optional[Set[int]] = None,
    target_size: Optional[Tuple[int, int]] = None,
):
    """
    Builds a segmentation map and the per-segment information from per-query mask probabilities.

    Args:
        mask_probs (`torch.Tensor`):
            Mask probabilities indexed as `(query, height, width)`. The tensor is not modified.
        pred_scores (`torch.Tensor`):
            One prediction score per query, used to weigh the corresponding mask.
        pred_labels (`torch.Tensor`):
            One predicted label / class id per query.
        overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
            Forwarded to `check_segment_validity` to discard segments that retain too little of their mask.
        label_ids_to_fuse (`Set[int]`, *optional*):
            Label ids whose instances all share a single segment id. If unset, no label is fused.
        target_size (`Tuple[int, int]`, *optional*):
            `(height, width)` to which the masks are bilinearly resized before computing segments. If unset, the
            spatial size of `mask_probs` is used.

    Returns:
        `Tuple[torch.Tensor, List[Dict]]`: An int32 segmentation map of shape `(height, width)` where `0` marks
        pixels not assigned to any kept segment, and a list with one dictionary per kept query containing the keys
        `id`, `label_id`, `was_fused` and `score`.
    """
    # The default `None` previously crashed on the membership test below; treat it as "fuse nothing"
    if label_ids_to_fuse is None:
        label_ids_to_fuse = set()

    height = mask_probs.shape[1] if target_size is None else target_size[0]
    width = mask_probs.shape[2] if target_size is None else target_size[1]

    segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
    segments: List[Dict] = []

    if target_size is not None:
        mask_probs = nn.functional.interpolate(
            mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
        )[0]

    # Weigh each mask by its prediction score. Done out of place so the caller's tensor is left untouched
    # (an in-place `*=` would modify it whenever no resizing happened above).
    mask_probs = mask_probs * pred_scores.view(-1, 1, 1)
    mask_labels = mask_probs.argmax(0)  # [height, width]

    # Highest id handed out so far. Kept separate from the id of the segment currently being written so that
    # revisiting an already fused class can never cause a later segment to reuse an id that is already taken.
    last_assigned_id = 0

    # Keep track of instances of each class: maps a fused label id to its shared segment id
    stuff_memory_list: Dict[int, int] = {}
    for k in range(pred_labels.shape[0]):
        pred_class = pred_labels[k].item()
        should_fuse = pred_class in label_ids_to_fuse

        # Check if mask exists and large enough to be a segment
        mask_exists, mask_k = check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold)
        if not mask_exists:
            continue

        if pred_class in stuff_memory_list:
            current_segment_id = stuff_memory_list[pred_class]
        else:
            last_assigned_id += 1
            current_segment_id = last_assigned_id

        # Add current object segment to final segmentation map
        segmentation[mask_k] = current_segment_id
        segment_score = round(pred_scores[k].item(), 6)
        segments.append(
            {
                "id": current_segment_id,
                "label_id": pred_class,
                "was_fused": should_fuse,
                "score": segment_score,
            }
        )
        if should_fuse:
            stuff_memory_list[pred_class] = current_segment_id

    return segmentation, segments
|
||||||
|
|
||||||
|
|
||||||
class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||||
r"""
|
r"""
|
||||||
Constructs a DETR feature extractor.
|
Constructs a DETR feature extractor.
|
||||||
@@ -1098,7 +1195,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
|
|
||||||
semantic_segmentation = []
|
semantic_segmentation = []
|
||||||
for idx in range(batch_size):
|
for idx in range(batch_size):
|
||||||
resized_logits = torch.nn.functional.interpolate(
|
resized_logits = nn.functional.interpolate(
|
||||||
segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
||||||
)
|
)
|
||||||
semantic_map = resized_logits[0].argmax(dim=0)
|
semantic_map = resized_logits[0].argmax(dim=0)
|
||||||
@@ -1114,31 +1211,34 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
outputs,
|
outputs,
|
||||||
threshold: float = 0.5,
|
threshold: float = 0.5,
|
||||||
overlap_mask_area_threshold: float = 0.8,
|
overlap_mask_area_threshold: float = 0.8,
|
||||||
target_sizes: List[Tuple] = None,
|
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||||
return_coco_annotation: Optional[bool] = False,
|
return_coco_annotation: Optional[bool] = False,
|
||||||
):
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
|
Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
|
||||||
outputs ([`DetrForSegmentation`]):
|
outputs ([`DetrForSegmentation`]):
|
||||||
Raw outputs of the model.
|
Raw outputs of the model.
|
||||||
threshold (`float`, *optional*):
|
threshold (`float`, *optional*, defaults to 0.5):
|
||||||
The probability score threshold to keep predicted instance masks, defaults to 0.5.
|
The probability score threshold to keep predicted instance masks.
|
||||||
overlap_mask_area_threshold (`float`, *optional*):
|
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||||
instance mask, defaults to 0.8.
|
instance mask.
|
||||||
target_sizes (`List[Tuple]`, *optional*, defaults to `None`):
|
target_sizes (`List[Tuple]`, *optional*):
|
||||||
List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
|
List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
|
||||||
final size (height, width) of each prediction. If left to None, predictions will not be resized.
|
final size (height, width) of each prediction. If left to None, predictions will not be resized.
|
||||||
return_coco_annotation (`bool`, *optional*, defaults to `False`):
|
return_coco_annotation (`bool`, *optional*):
|
||||||
If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.
|
Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
|
||||||
|
format.
|
||||||
Returns:
|
Returns:
|
||||||
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
||||||
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
|
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
|
||||||
`List[List]` run-length encoding (RLE) of the segmentation map if return_coco_format is set to `True`.
|
`List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
|
||||||
- **segment_ids** -- A dictionary that maps segment ids to semantic class ids.
|
`True`. Set to `None` if no mask is found above `threshold`.
|
||||||
|
- **segments_info** -- A dictionary that contains additional information on each segment.
|
||||||
- **id** -- An integer representing the `segment_id`.
|
- **id** -- An integer representing the `segment_id`.
|
||||||
- **label_id** -- An integer representing the segment's label / semantic class id.
|
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
||||||
|
- **score** -- Prediction score of segment with `segment_id`.
|
||||||
"""
|
"""
|
||||||
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
|
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
|
||||||
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
|
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
|
||||||
@@ -1159,76 +1259,27 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
||||||
)
|
)
|
||||||
|
|
||||||
height, width = target_sizes[i][0], target_sizes[i][1]
|
# No mask found
|
||||||
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device)
|
if mask_probs_item.shape[0] <= 0:
|
||||||
|
segmentation = None
|
||||||
segments: List[Dict] = []
|
segments: List[Dict] = []
|
||||||
|
|
||||||
object_detected = mask_probs_item.shape[0] > 0
|
|
||||||
|
|
||||||
if object_detected:
|
|
||||||
# Resize mask to corresponding target_size
|
|
||||||
if target_sizes is not None:
|
|
||||||
mask_probs_item = torch.nn.functional.interpolate(
|
|
||||||
mask_probs_item.unsqueeze(0),
|
|
||||||
size=target_sizes[i],
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=False,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
current_segment_id = 0
|
|
||||||
|
|
||||||
# Weigh each mask by its prediction score
|
|
||||||
mask_probs_item *= pred_scores_item.view(-1, 1, 1)
|
|
||||||
mask_labels_item = mask_probs_item.argmax(0) # [height, width]
|
|
||||||
|
|
||||||
# Keep track of instances of each class
|
|
||||||
stuff_memory_list: Dict[str, int] = {}
|
|
||||||
for k in range(pred_labels_item.shape[0]):
|
|
||||||
# Get the mask associated with the k class
|
|
||||||
pred_class = pred_labels_item[k].item()
|
|
||||||
mask_k = mask_labels_item == k
|
|
||||||
mask_k_area = mask_k.sum()
|
|
||||||
|
|
||||||
# Compute the area of all the stuff in query k
|
|
||||||
original_area = (mask_probs_item[k] >= 0.5).sum()
|
|
||||||
mask_exists = mask_k_area > 0 and original_area > 0
|
|
||||||
|
|
||||||
if mask_exists:
|
|
||||||
# Eliminate segments with mask area below threshold
|
|
||||||
area_ratio = mask_k_area / original_area
|
|
||||||
if not area_ratio.item() > overlap_mask_area_threshold:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Add corresponding class id
|
# Get segmentation map and segment information of batch item
|
||||||
if pred_class in stuff_memory_list:
|
target_size = target_sizes[i] if target_sizes is not None else None
|
||||||
current_segment_id = stuff_memory_list[pred_class]
|
segmentation, segments = compute_segments(
|
||||||
else:
|
mask_probs_item,
|
||||||
current_segment_id += 1
|
pred_scores_item,
|
||||||
|
pred_labels_item,
|
||||||
# Add current object segment to final segmentation map
|
overlap_mask_area_threshold,
|
||||||
segmentation[mask_k] = current_segment_id
|
target_size,
|
||||||
segments.append(
|
|
||||||
{
|
|
||||||
"id": current_segment_id,
|
|
||||||
"label_id": pred_class,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
segmentation -= 1
|
|
||||||
|
|
||||||
# Return segmentation map in run-length encoding (RLE) format
|
# Return segmentation map in run-length encoding (RLE) format
|
||||||
if return_coco_annotation:
|
if return_coco_annotation:
|
||||||
segment_ids = torch.unique(segmentation)
|
segmentation = convert_segmentation_to_rle(segmentation)
|
||||||
|
|
||||||
run_length_encodings = []
|
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||||
for idx in segment_ids:
|
|
||||||
mask = torch.where(segmentation == idx, 1, 0)
|
|
||||||
rle = binary_mask_to_rle(mask)
|
|
||||||
run_length_encodings.append(rle)
|
|
||||||
|
|
||||||
segmentation = run_length_encodings
|
|
||||||
|
|
||||||
results.append({"segmentation": segmentation, "segment_ids": segments})
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def post_process_panoptic_segmentation(
|
def post_process_panoptic_segmentation(
|
||||||
@@ -1237,7 +1288,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
threshold: float = 0.5,
|
threshold: float = 0.5,
|
||||||
overlap_mask_area_threshold: float = 0.8,
|
overlap_mask_area_threshold: float = 0.8,
|
||||||
label_ids_to_fuse: Optional[Set[int]] = None,
|
label_ids_to_fuse: Optional[Set[int]] = None,
|
||||||
target_sizes: List[Tuple] = None,
|
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1250,7 +1301,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||||
instance mask.
|
instance mask.
|
||||||
label_ids_to_fuse (`Set[int]`, *optional*, defaults to `None`):
|
label_ids_to_fuse (`Set[int]`, *optional*):
|
||||||
The labels in this state will have all their instances be fused together. For instance we could say
|
The labels in this state will have all their instances be fused together. For instance we could say
|
||||||
there can only be one sky in an image, but several persons, so the label ID for sky would be in that
|
there can only be one sky in an image, but several persons, so the label ID for sky would be in that
|
||||||
set, but not the one for person.
|
set, but not the one for person.
|
||||||
@@ -1260,13 +1311,15 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
resized.
|
resized.
|
||||||
Returns:
|
Returns:
|
||||||
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
||||||
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`. If
|
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
|
||||||
`target_sizes` is specified, segmentation is resized to the corresponding `target_sizes` entry.
|
`None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized to
|
||||||
- **segment_ids** -- A dictionary that maps segment ids to semantic class ids.
|
the corresponding `target_sizes` entry.
|
||||||
- **id** -- An integer representing the `segment_id`.
|
- **segments_info** -- A dictionary that contains additional information on each segment.
|
||||||
- **label_id** -- An integer representing the segment's label / semantic class id.
|
- **id** -- an integer representing the `segment_id`.
|
||||||
|
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
||||||
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
|
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
|
||||||
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
|
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
|
||||||
|
- **score** -- Prediction score of segment with `segment_id`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if label_ids_to_fuse is None:
|
if label_ids_to_fuse is None:
|
||||||
@@ -1292,67 +1345,22 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
||||||
)
|
)
|
||||||
|
|
||||||
height, width = target_sizes[i][0], target_sizes[i][1]
|
# No mask found
|
||||||
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device)
|
if mask_probs_item.shape[0] <= 0:
|
||||||
|
segmentation = None
|
||||||
segments: List[Dict] = []
|
segments: List[Dict] = []
|
||||||
|
|
||||||
object_detected = mask_probs_item.shape[0] > 0
|
|
||||||
|
|
||||||
if object_detected:
|
|
||||||
# Resize mask to corresponding target_size
|
|
||||||
if target_sizes is not None:
|
|
||||||
mask_probs_item = torch.nn.functional.interpolate(
|
|
||||||
mask_probs_item.unsqueeze(0),
|
|
||||||
size=target_sizes[i],
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=False,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
current_segment_id = 0
|
|
||||||
|
|
||||||
# Weigh each mask by its prediction score
|
|
||||||
mask_probs_item *= pred_scores_item.view(-1, 1, 1)
|
|
||||||
mask_labels_item = mask_probs_item.argmax(0) # [height, width]
|
|
||||||
|
|
||||||
# Keep track of instances of each class
|
|
||||||
stuff_memory_list: Dict[str, int] = {}
|
|
||||||
for k in range(pred_labels_item.shape[0]):
|
|
||||||
pred_class = pred_labels_item[k].item()
|
|
||||||
should_fuse = pred_class in label_ids_to_fuse
|
|
||||||
|
|
||||||
# Get the mask associated with the k class
|
|
||||||
mask_k = mask_labels_item == k
|
|
||||||
mask_k_area = mask_k.sum()
|
|
||||||
|
|
||||||
# Compute the area of all the stuff in query k
|
|
||||||
original_area = (mask_probs_item[k] >= 0.5).sum()
|
|
||||||
mask_exists = mask_k_area > 0 and original_area > 0
|
|
||||||
|
|
||||||
if mask_exists:
|
|
||||||
# Eliminate disconnected tiny segments
|
|
||||||
area_ratio = mask_k_area / original_area
|
|
||||||
if not area_ratio.item() > overlap_mask_area_threshold:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Add corresponding class id
|
# Get segmentation map and segment information of batch item
|
||||||
if pred_class in stuff_memory_list:
|
target_size = target_sizes[i] if target_sizes is not None else None
|
||||||
current_segment_id = stuff_memory_list[pred_class]
|
segmentation, segments = compute_segments(
|
||||||
else:
|
mask_probs_item,
|
||||||
current_segment_id += 1
|
pred_scores_item,
|
||||||
|
pred_labels_item,
|
||||||
# Add current object segment to final segmentation map
|
overlap_mask_area_threshold,
|
||||||
segmentation[mask_k] = current_segment_id
|
label_ids_to_fuse,
|
||||||
segments.append(
|
target_size,
|
||||||
{
|
|
||||||
"id": current_segment_id,
|
|
||||||
"label_id": pred_class,
|
|
||||||
"was_fused": should_fuse,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
if should_fuse:
|
|
||||||
stuff_memory_list[pred_class] = current_segment_id
|
|
||||||
else:
|
|
||||||
segmentation -= 1
|
|
||||||
|
|
||||||
results.append({"segmentation": segmentation, "segment_ids": segments})
|
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -1605,12 +1605,12 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
|||||||
|
|
||||||
>>> # Use the `post_process_panoptic_segmentation` method of `DetrFeatureExtractor` to retrieve post-processed panoptic segmentation maps
|
>>> # Use the `post_process_panoptic_segmentation` method of `DetrFeatureExtractor` to retrieve post-processed panoptic segmentation maps
|
||||||
>>> # Segmentation results are returned as a list of dictionaries
|
>>> # Segmentation results are returned as a list of dictionaries
|
||||||
>>> result = feature_extractor.post_process_panoptic_segmentation(outputs, processed_sizes)
|
>>> result = feature_extractor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
|
||||||
|
|
||||||
>>> # A tensor of shape (height, width) where each value denotes a segment id
|
>>> # A tensor of shape (height, width) where each value denotes a segment id
|
||||||
>>> panoptic_seg = result[0]["segmentation"]
|
>>> panoptic_seg = result[0]["segmentation"]
|
||||||
>>> # Get mapping of segment ids to semantic class ids
|
>>> # Get mapping of segment ids to semantic class ids
|
||||||
>>> panoptic_segments_info = result[0]["segment_ids"]
|
>>> panoptic_segments_info = result[0]["segments_info"]
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|||||||
Reference in New Issue
Block a user