Revert changes in logit size for semantic segmentation models (#15722)
* Revert changes in logit size for semantic segmentation models * Address review comments
This commit is contained in:
@@ -822,8 +822,17 @@ class SemanticSegmentationModelOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||||
Classification (or regression if config.num_labels==1) loss.
|
Classification (or regression if config.num_labels==1) loss.
|
||||||
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, height, width)`):
|
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
|
||||||
Classification scores for each pixel.
|
Classification scores for each pixel.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
|
||||||
|
to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
|
||||||
|
original image size as post-processing. You should always check your logits shape and resize as needed.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
shape `(batch_size, patch_size, hidden_size)`.
|
shape `(batch_size, patch_size, hidden_size)`.
|
||||||
|
|||||||
@@ -93,10 +93,6 @@ class BeitConfig(PretrainedConfig):
|
|||||||
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
|
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
|
||||||
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
|
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
|
||||||
The index that is ignored by the loss function of the semantic segmentation model.
|
The index that is ignored by the loss function of the semantic segmentation model.
|
||||||
legacy_output (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
|
|
||||||
|
|
||||||
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -145,7 +141,6 @@ class BeitConfig(PretrainedConfig):
|
|||||||
auxiliary_num_convs=1,
|
auxiliary_num_convs=1,
|
||||||
auxiliary_concat_input=False,
|
auxiliary_concat_input=False,
|
||||||
semantic_loss_ignore_index=255,
|
semantic_loss_ignore_index=255,
|
||||||
legacy_output=False,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -181,4 +176,3 @@ class BeitConfig(PretrainedConfig):
|
|||||||
self.auxiliary_num_convs = auxiliary_num_convs
|
self.auxiliary_num_convs = auxiliary_num_convs
|
||||||
self.auxiliary_concat_input = auxiliary_concat_input
|
self.auxiliary_concat_input = auxiliary_concat_input
|
||||||
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
||||||
self.legacy_output = legacy_output
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
import collections.abc
|
import collections.abc
|
||||||
import math
|
import math
|
||||||
import warnings
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -1121,8 +1120,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def compute_loss(self, upsampled_logits, auxiliary_logits, labels):
|
def compute_loss(self, logits, auxiliary_logits, labels):
|
||||||
# upsample logits to the images' original size
|
# upsample logits to the images' original size
|
||||||
|
upsampled_logits = nn.functional.interpolate(
|
||||||
|
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
if auxiliary_logits is not None:
|
if auxiliary_logits is not None:
|
||||||
upsampled_auxiliary_logits = nn.functional.interpolate(
|
upsampled_auxiliary_logits = nn.functional.interpolate(
|
||||||
auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
||||||
@@ -1145,17 +1147,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
legacy_output=None,
|
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
||||||
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
||||||
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
||||||
legacy_output (`bool`, *optional*):
|
|
||||||
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
|
|
||||||
to `self.config.legacy_output`.
|
|
||||||
|
|
||||||
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@@ -1181,14 +1177,6 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
|||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
|
|
||||||
if not legacy_output:
|
|
||||||
warnings.warn(
|
|
||||||
"The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
|
|
||||||
"You can activate the previous behavior by passing `legacy_output=True` to this call or the "
|
|
||||||
"configuration of this model (only until v5, then that argument will be removed).",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = self.beit(
|
outputs = self.beit(
|
||||||
pixel_values,
|
pixel_values,
|
||||||
@@ -1216,10 +1204,6 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
|||||||
|
|
||||||
logits = self.decode_head(features)
|
logits = self.decode_head(features)
|
||||||
|
|
||||||
upsampled_logits = nn.functional.interpolate(
|
|
||||||
logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
|
|
||||||
auxiliary_logits = None
|
auxiliary_logits = None
|
||||||
if self.auxiliary_head is not None:
|
if self.auxiliary_head is not None:
|
||||||
auxiliary_logits = self.auxiliary_head(features)
|
auxiliary_logits = self.auxiliary_head(features)
|
||||||
@@ -1229,26 +1213,18 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
|||||||
if self.config.num_labels == 1:
|
if self.config.num_labels == 1:
|
||||||
raise ValueError("The number of labels should be greater than one")
|
raise ValueError("The number of labels should be greater than one")
|
||||||
else:
|
else:
|
||||||
loss = self.compute_loss(upsampled_logits, auxiliary_logits, labels)
|
loss = self.compute_loss(logits, auxiliary_logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
else:
|
else:
|
||||||
output = (logits if legacy_output else upsampled_logits,) + outputs[3:]
|
output = (logits,) + outputs[3:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
if legacy_output:
|
return SemanticSegmentationModelOutput(
|
||||||
return SequenceClassifierOutput(
|
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return SemanticSegmentationModelOutput(
|
|
||||||
loss=loss,
|
|
||||||
logits=upsampled_logits,
|
|
||||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -83,10 +83,6 @@ class SegformerConfig(PretrainedConfig):
|
|||||||
required for the semantic segmentation model.
|
required for the semantic segmentation model.
|
||||||
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
|
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
|
||||||
The index that is ignored by the loss function of the semantic segmentation model.
|
The index that is ignored by the loss function of the semantic segmentation model.
|
||||||
legacy_output (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
|
|
||||||
|
|
||||||
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -128,7 +124,6 @@ class SegformerConfig(PretrainedConfig):
|
|||||||
is_encoder_decoder=False,
|
is_encoder_decoder=False,
|
||||||
reshape_last_stage=True,
|
reshape_last_stage=True,
|
||||||
semantic_loss_ignore_index=255,
|
semantic_loss_ignore_index=255,
|
||||||
legacy_output=False,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -154,4 +149,3 @@ class SegformerConfig(PretrainedConfig):
|
|||||||
self.decoder_hidden_size = decoder_hidden_size
|
self.decoder_hidden_size = decoder_hidden_size
|
||||||
self.reshape_last_stage = reshape_last_stage
|
self.reshape_last_stage = reshape_last_stage
|
||||||
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
||||||
self.legacy_output = legacy_output
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import math
|
import math
|
||||||
import warnings
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -697,17 +696,11 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
legacy_output=None,
|
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
||||||
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
||||||
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
||||||
legacy_output (`bool`, *optional*):
|
|
||||||
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
|
|
||||||
to `self.config.legacy_output`.
|
|
||||||
|
|
||||||
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@@ -732,14 +725,6 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
|
|||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
|
|
||||||
if not legacy_output:
|
|
||||||
warnings.warn(
|
|
||||||
"The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
|
|
||||||
"You can activate the previous behavior by passing `legacy_output=True` to this call or the "
|
|
||||||
"configuration of this model (only until v5, then that argument will be removed).",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = self.segformer(
|
outputs = self.segformer(
|
||||||
pixel_values,
|
pixel_values,
|
||||||
@@ -752,37 +737,28 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
|
|||||||
|
|
||||||
logits = self.decode_head(encoder_hidden_states)
|
logits = self.decode_head(encoder_hidden_states)
|
||||||
|
|
||||||
upsampled_logits = nn.functional.interpolate(
|
|
||||||
logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.config.num_labels == 1:
|
if self.config.num_labels == 1:
|
||||||
raise ValueError("The number of labels should be greater than one")
|
raise ValueError("The number of labels should be greater than one")
|
||||||
else:
|
else:
|
||||||
# upsample logits to the images' original size
|
# upsample logits to the images' original size
|
||||||
|
upsampled_logits = nn.functional.interpolate(
|
||||||
|
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
|
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
|
||||||
loss = loss_fct(upsampled_logits, labels)
|
loss = loss_fct(upsampled_logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
output = (logits if legacy_output else upsampled_logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
else:
|
else:
|
||||||
output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
if legacy_output:
|
return SemanticSegmentationModelOutput(
|
||||||
return SequenceClassifierOutput(
|
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return SemanticSegmentationModelOutput(
|
|
||||||
loss=loss,
|
|
||||||
logits=upsampled_logits,
|
|
||||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -162,11 +162,11 @@ class BeitModelTester:
|
|||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
|
result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
|
||||||
)
|
)
|
||||||
result = model(pixel_values, labels=pixel_labels)
|
result = model(pixel_values, labels=pixel_labels)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
|
result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
@@ -533,14 +533,14 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# verify the logits
|
# verify the logits
|
||||||
expected_shape = torch.Size((1, 150, 640, 640))
|
expected_shape = torch.Size((1, 150, 160, 160))
|
||||||
self.assertEqual(logits.shape, expected_shape)
|
self.assertEqual(logits.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[
|
[
|
||||||
[[-4.9225, -4.9225, -4.6066], [-4.9225, -4.9225, -4.6066], [-4.6675, -4.6675, -4.3617]],
|
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||||
[[-5.8168, -5.8168, -5.5163], [-5.8168, -5.8168, -5.5163], [-5.5728, -5.5728, -5.2842]],
|
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||||
[[-0.0078, -0.0078, 0.4926], [-0.0078, -0.0078, 0.4926], [0.3664, 0.3664, 0.8309]],
|
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||||
]
|
]
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
|
|
||||||
|
|||||||
@@ -135,11 +135,11 @@ class SegformerModelTester:
|
|||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
|
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
|
||||||
)
|
)
|
||||||
result = model(pixel_values, labels=labels)
|
result = model(pixel_values, labels=labels)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
|
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
@@ -363,14 +363,14 @@ class SegformerModelIntegrationTest(unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(pixel_values)
|
outputs = model(pixel_values)
|
||||||
|
|
||||||
expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
|
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
|
||||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[
|
[
|
||||||
[[-4.6309, -4.6309, -4.7425], [-4.6309, -4.6309, -4.7425], [-4.7011, -4.7011, -4.8136]],
|
[[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
|
||||||
[[-12.1391, -12.1391, -12.2858], [-12.1391, -12.1391, -12.2858], [-12.2309, -12.2309, -12.3758]],
|
[[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
|
||||||
[[-12.5134, -12.5134, -12.6328], [-12.5134, -12.5134, -12.6328], [-12.5576, -12.5576, -12.6865]],
|
[[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
|
||||||
]
|
]
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
||||||
@@ -392,14 +392,14 @@ class SegformerModelIntegrationTest(unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(pixel_values)
|
outputs = model(pixel_values)
|
||||||
|
|
||||||
expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
|
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
|
||||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[
|
[
|
||||||
[[-13.5729, -13.5729, -13.6149], [-13.5729, -13.5729, -13.6149], [-13.6697, -13.6697, -13.7224]],
|
[[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
|
||||||
[[-17.1638, -17.1638, -17.0022], [-17.1638, -17.1638, -17.0022], [-17.1754, -17.1754, -17.0358]],
|
[[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
|
||||||
[[-3.6452, -3.6452, -3.5670], [-3.6452, -3.6452, -3.5670], [-3.5744, -3.5744, -3.5079]],
|
[[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
|
||||||
]
|
]
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
|
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
|
||||||
|
|||||||
Reference in New Issue
Block a user