add doctests for DETR (#17786)
* add: check labels for detr object detection doctests * add: check shapes * add: add detr to documentation_tests.py * fix: make fixup output * fix: add a comment
This commit is contained in:
@@ -1240,6 +1240,8 @@ class DetrModel(DetrPreTrainedModel):
|
|||||||
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> last_hidden_states = outputs.last_hidden_state
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
>>> list(last_hidden_states.shape)
|
||||||
|
[1, 100, 256]
|
||||||
```"""
|
```"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -1399,8 +1401,16 @@ class DetrForObjectDetection(DetrPreTrainedModel):
|
|||||||
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> # model predicts bounding boxes and corresponding COCO classes
|
>>> # model predicts bounding boxes and corresponding COCO classes
|
||||||
>>> logits = outputs.logits
|
>>> bboxes, logits = outputs.pred_boxes, outputs.logits
|
||||||
>>> bboxes = outputs.pred_boxes
|
|
||||||
|
>>> # get probability per object class and remove the no-object class
|
||||||
|
>>> probas_per_class = outputs.logits.softmax(-1)[:, :, :-1]
|
||||||
|
>>> objects_to_keep = probas_per_class.max(-1).values > 0.9
|
||||||
|
|
||||||
|
>>> ids, _ = probas_per_class.max(-1).indices[objects_to_keep].sort()
|
||||||
|
>>> labels = [model.config.id2label[id.item()] for id in ids]
|
||||||
|
>>> labels
|
||||||
|
['cat', 'cat', 'couch', 'remote', 'remote']
|
||||||
```"""
|
```"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
@@ -1556,8 +1566,16 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
|||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> # model predicts COCO classes, bounding boxes, and masks
|
>>> # model predicts COCO classes, bounding boxes, and masks
|
||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
|
>>> list(logits.shape)
|
||||||
|
[1, 100, 251]
|
||||||
|
|
||||||
>>> bboxes = outputs.pred_boxes
|
>>> bboxes = outputs.pred_boxes
|
||||||
|
>>> list(bboxes.shape)
|
||||||
|
[1, 100, 4]
|
||||||
|
|
||||||
>>> masks = outputs.pred_masks
|
>>> masks = outputs.pred_masks
|
||||||
|
>>> list(masks.shape)
|
||||||
|
[1, 100, 200, 267]
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ src/transformers/models/cvt/modeling_cvt.py
|
|||||||
src/transformers/models/data2vec/modeling_data2vec_audio.py
|
src/transformers/models/data2vec/modeling_data2vec_audio.py
|
||||||
src/transformers/models/data2vec/modeling_data2vec_vision.py
|
src/transformers/models/data2vec/modeling_data2vec_vision.py
|
||||||
src/transformers/models/deit/modeling_deit.py
|
src/transformers/models/deit/modeling_deit.py
|
||||||
|
src/transformers/models/detr/modeling_detr.py
|
||||||
src/transformers/models/dpt/modeling_dpt.py
|
src/transformers/models/dpt/modeling_dpt.py
|
||||||
src/transformers/models/electra/modeling_electra.py
|
src/transformers/models/electra/modeling_electra.py
|
||||||
src/transformers/models/electra/modeling_tf_electra.py
|
src/transformers/models/electra/modeling_tf_electra.py
|
||||||
|
|||||||
Reference in New Issue
Block a user