Add doctests to Perceiver examples (#19129)
* Fix bug in example and add to tests * Fix failing tests * Check the size of logits * Code style * Try again... * Add expected loss for PerceiverForMaskedLM doctest Co-authored-by: Steven Anton <antonstv@amazon.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -801,6 +801,8 @@ class PerceiverModel(PerceiverPreTrainedModel):
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(inputs=inputs)
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits.shape)
|
||||
[1, 2]
|
||||
|
||||
>>> # to train, one can train the model using standard cross-entropy:
|
||||
>>> criterion = torch.nn.CrossEntropyLoss()
|
||||
@@ -810,6 +812,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
|
||||
|
||||
>>> # EXAMPLE 2: using the Perceiver to classify images
|
||||
>>> # - we define an ImagePreprocessor, which can be used to embed images
|
||||
>>> config = PerceiverConfig(image_size=224)
|
||||
>>> preprocessor = PerceiverImagePreprocessor(
|
||||
... config,
|
||||
... prep_type="conv1x1",
|
||||
@@ -844,6 +847,8 @@ class PerceiverModel(PerceiverPreTrainedModel):
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(inputs=inputs)
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits.shape)
|
||||
[1, 2]
|
||||
|
||||
>>> # to train, one can train the model using standard cross-entropy:
|
||||
>>> criterion = torch.nn.CrossEntropyLoss()
|
||||
@@ -1017,7 +1022,12 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
|
||||
|
||||
>>> outputs = model(**inputs, labels=labels)
|
||||
>>> loss = outputs.loss
|
||||
>>> round(loss.item(), 2)
|
||||
19.87
|
||||
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits.shape)
|
||||
[1, 2048, 262]
|
||||
|
||||
>>> # inference
|
||||
>>> text = "This is an incomplete sentence where some words are missing."
|
||||
@@ -1030,6 +1040,8 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**encoding)
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits.shape)
|
||||
[1, 2048, 262]
|
||||
|
||||
>>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist()
|
||||
>>> tokenizer.decode(masked_tokens_predictions)
|
||||
@@ -1128,6 +1140,8 @@ class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
|
||||
>>> inputs = tokenizer(text, return_tensors="pt").input_ids
|
||||
>>> outputs = model(inputs=inputs)
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits.shape)
|
||||
[1, 2]
|
||||
```"""
|
||||
if inputs is not None and input_ids is not None:
|
||||
raise ValueError("You cannot use both `inputs` and `input_ids`")
|
||||
@@ -1265,9 +1279,13 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
|
||||
>>> outputs = model(inputs=inputs)
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits.shape)
|
||||
[1, 1000]
|
||||
|
||||
>>> # model predicts one of the 1000 ImageNet classes
|
||||
>>> predicted_class_idx = logits.argmax(-1).item()
|
||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
Predicted class: tabby, tabby cat
|
||||
```"""
|
||||
if inputs is not None and pixel_values is not None:
|
||||
raise ValueError("You cannot use both `inputs` and `pixel_values`")
|
||||
@@ -1402,9 +1420,13 @@ class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
|
||||
>>> outputs = model(inputs=inputs)
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits.shape)
|
||||
[1, 1000]
|
||||
|
||||
>>> # model predicts one of the 1000 ImageNet classes
|
||||
>>> predicted_class_idx = logits.argmax(-1).item()
|
||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
Predicted class: tabby, tabby cat
|
||||
```"""
|
||||
if inputs is not None and pixel_values is not None:
|
||||
raise ValueError("You cannot use both `inputs` and `pixel_values`")
|
||||
@@ -1539,9 +1561,13 @@ class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
|
||||
>>> outputs = model(inputs=inputs)
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits.shape)
|
||||
[1, 1000]
|
||||
|
||||
>>> # model predicts one of the 1000 ImageNet classes
|
||||
>>> predicted_class_idx = logits.argmax(-1).item()
|
||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
Predicted class: tabby, tabby cat
|
||||
```"""
|
||||
if inputs is not None and pixel_values is not None:
|
||||
raise ValueError("You cannot use both `inputs` and `pixel_values`")
|
||||
@@ -1689,6 +1715,8 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
|
||||
>>> patches = torch.randn(1, 2, 27, 368, 496)
|
||||
>>> outputs = model(inputs=patches)
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits.shape)
|
||||
[1, 368, 496, 2]
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@@ -1915,6 +1943,14 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
|
||||
|
||||
>>> outputs = model(inputs=inputs, subsampled_output_points=subsampling)
|
||||
>>> logits = outputs.logits
|
||||
>>> list(logits["audio"].shape)
|
||||
[1, 240]
|
||||
|
||||
>>> list(logits["image"].shape)
|
||||
[1, 6272, 3]
|
||||
|
||||
>>> list(logits["label"].shape)
|
||||
[1, 700]
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@@ -2925,7 +2961,6 @@ class PerceiverAudioPostprocessor(nn.Module):
|
||||
self.classifier = nn.Linear(in_channels, config.samples_per_patch)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
|
||||
|
||||
logits = self.classifier(inputs)
|
||||
return torch.reshape(logits, [inputs.shape[0], -1])
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ src/transformers/models/opt/modeling_opt.py
|
||||
src/transformers/models/opt/modeling_tf_opt.py
|
||||
src/transformers/models/owlvit/modeling_owlvit.py
|
||||
src/transformers/models/pegasus/modeling_pegasus.py
|
||||
src/transformers/models/perceiver/modeling_perceiver.py
|
||||
src/transformers/models/plbart/modeling_plbart.py
|
||||
src/transformers/models/poolformer/modeling_poolformer.py
|
||||
src/transformers/models/reformer/modeling_reformer.py
|
||||
|
||||
Reference in New Issue
Block a user