@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@@ -1622,3 +1623,40 @@ class SampleTSPredictionOutput(ModelOutput):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.FloatTensor = None
|
sequences: torch.FloatTensor = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MaskedImageModelingOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of masked image completion / in-painting models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
|
||||||
|
Reconstruction loss.
|
||||||
|
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
|
Reconstructed / completed images.
|
||||||
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
|
||||||
|
when `config.output_hidden_states=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||||
|
one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
|
||||||
|
(also called feature maps) of the model at the output of each stage.
|
||||||
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
|
||||||
|
`config.output_attentions=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
|
||||||
|
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
||||||
|
the self-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
loss: Optional[torch.FloatTensor] = None
|
||||||
|
reconstruction: torch.FloatTensor = None
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logits(self):
|
||||||
|
warnings.warn(
|
||||||
|
"logits attribute is deprecated and will be removed in version 5 of Transformers."
|
||||||
|
" Please use the reconstruction attribute to retrieve the final output instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
return self.reconstruction
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@@ -55,8 +56,8 @@ class TFBaseModelOutputWithNoAttention(ModelOutput):
|
|||||||
last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
|
last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the model.
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
|
||||||
one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
|
the output of each layer) of shape `(batch_size, num_channels, height, width)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||||
"""
|
"""
|
||||||
@@ -949,3 +950,40 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
|
|||||||
loss: Optional[tf.Tensor] = None
|
loss: Optional[tf.Tensor] = None
|
||||||
logits: tf.Tensor = None
|
logits: tf.Tensor = None
|
||||||
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
|
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TFMaskedImageModelingOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of masked image completion / in-painting models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
|
||||||
|
Reconstruction loss.
|
||||||
|
reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
|
Reconstructed / completed images.
|
||||||
|
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
|
||||||
|
`config.output_hidden_states=True`):
|
||||||
|
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
|
||||||
|
the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
|
||||||
|
feature maps) of the model at the output of each stage.
|
||||||
|
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
|
||||||
|
`config.output_attentions=True`):
|
||||||
|
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
loss: Optional[tf.Tensor] = None
|
||||||
|
reconstruction: tf.Tensor = None
|
||||||
|
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logits(self):
|
||||||
|
warnings.warn(
|
||||||
|
"logits attribute is deprecated and will be removed in version 5 of Transformers."
|
||||||
|
" Please use the reconstruction attribute to retrieve the final output instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
return self.reconstruction
|
||||||
|
|||||||
@@ -26,7 +26,12 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutput,
|
||||||
|
BaseModelOutputWithPooling,
|
||||||
|
ImageClassifierOutput,
|
||||||
|
MaskedImageModelingOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -592,7 +597,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
|
|||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[torch.Tensor] = None,
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
@@ -601,7 +606,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple, MaskedLMOutput]:
|
) -> Union[tuple, MaskedImageModelingOutput]:
|
||||||
r"""
|
r"""
|
||||||
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
|
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
|
||||||
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
||||||
@@ -627,7 +632,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
|
|||||||
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
|
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
|
||||||
|
|
||||||
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
|
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||||
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
|
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
|
||||||
>>> list(reconstructed_pixel_values.shape)
|
>>> list(reconstructed_pixel_values.shape)
|
||||||
[1, 3, 224, 224]
|
[1, 3, 224, 224]
|
||||||
```"""
|
```"""
|
||||||
@@ -670,9 +675,9 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
|
|||||||
output = (reconstructed_pixel_values,) + outputs[1:]
|
output = (reconstructed_pixel_values,) + outputs[1:]
|
||||||
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
|
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
|
||||||
|
|
||||||
return MaskedLMOutput(
|
return MaskedImageModelingOutput(
|
||||||
loss=masked_im_loss,
|
loss=masked_im_loss,
|
||||||
logits=reconstructed_pixel_values,
|
reconstruction=reconstructed_pixel_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from ...modeling_tf_outputs import (
|
|||||||
TFBaseModelOutput,
|
TFBaseModelOutput,
|
||||||
TFBaseModelOutputWithPooling,
|
TFBaseModelOutputWithPooling,
|
||||||
TFImageClassifierOutput,
|
TFImageClassifierOutput,
|
||||||
TFMaskedLMOutput,
|
TFMaskedImageModelingOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
@@ -769,7 +769,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
|
|||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=TFMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[tf.Tensor] = None,
|
pixel_values: Optional[tf.Tensor] = None,
|
||||||
@@ -779,7 +779,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Union[tuple, TFMaskedLMOutput]:
|
) -> Union[tuple, TFMaskedImageModelingOutput]:
|
||||||
r"""
|
r"""
|
||||||
bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`):
|
bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`):
|
||||||
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
||||||
@@ -805,7 +805,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
|
|||||||
>>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool)
|
>>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool)
|
||||||
|
|
||||||
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
|
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||||
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
|
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
|
||||||
>>> list(reconstructed_pixel_values.shape)
|
>>> list(reconstructed_pixel_values.shape)
|
||||||
[1, 3, 224, 224]
|
[1, 3, 224, 224]
|
||||||
```"""
|
```"""
|
||||||
@@ -860,18 +860,20 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
|
|||||||
output = (reconstructed_pixel_values,) + outputs[1:]
|
output = (reconstructed_pixel_values,) + outputs[1:]
|
||||||
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
|
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
|
||||||
|
|
||||||
return TFMaskedLMOutput(
|
return TFMaskedImageModelingOutput(
|
||||||
loss=masked_im_loss,
|
loss=masked_im_loss,
|
||||||
logits=reconstructed_pixel_values,
|
reconstruction=reconstructed_pixel_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
|
def serving_output(self, output: TFMaskedImageModelingOutput) -> TFMaskedImageModelingOutput:
|
||||||
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
return TFMaskedLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
|
return TFMaskedImageModelingOutput(
|
||||||
|
reconstruction=output.reconstruction, hidden_states=hidden_states, attentions=attentions
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
|
|||||||
@@ -25,7 +25,12 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutput,
|
||||||
|
BaseModelOutputWithPooling,
|
||||||
|
ImageClassifierOutput,
|
||||||
|
MaskedImageModelingOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -647,7 +652,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
|||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[torch.Tensor] = None,
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
@@ -657,7 +662,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: Optional[bool] = None,
|
interpolate_pos_encoding: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple, MaskedLMOutput]:
|
) -> Union[tuple, MaskedImageModelingOutput]:
|
||||||
r"""
|
r"""
|
||||||
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
|
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
|
||||||
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
||||||
@@ -683,7 +688,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
|||||||
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
|
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
|
||||||
|
|
||||||
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
|
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||||
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
|
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
|
||||||
>>> list(reconstructed_pixel_values.shape)
|
>>> list(reconstructed_pixel_values.shape)
|
||||||
[1, 3, 224, 224]
|
[1, 3, 224, 224]
|
||||||
```"""
|
```"""
|
||||||
@@ -727,9 +732,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
|||||||
output = (reconstructed_pixel_values,) + outputs[1:]
|
output = (reconstructed_pixel_values,) + outputs[1:]
|
||||||
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
|
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
|
||||||
|
|
||||||
return MaskedLMOutput(
|
return MaskedImageModelingOutput(
|
||||||
loss=masked_im_loss,
|
loss=masked_im_loss,
|
||||||
logits=reconstructed_pixel_values,
|
reconstruction=reconstructed_pixel_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ class DeiTModelTester:
|
|||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
# test greyscale images
|
# test greyscale images
|
||||||
@@ -156,7 +156,7 @@ class DeiTModelTester:
|
|||||||
|
|
||||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
config.num_labels = self.type_sequence_label_size
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ class TFDeiTModelTester:
|
|||||||
model = TFDeiTForMaskedImageModeling(config=config)
|
model = TFDeiTForMaskedImageModeling(config=config)
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
# test greyscale images
|
# test greyscale images
|
||||||
@@ -139,7 +139,7 @@ class TFDeiTModelTester:
|
|||||||
|
|
||||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
config.num_labels = self.type_sequence_label_size
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ class ViTModelTester:
|
|||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
# test greyscale images
|
# test greyscale images
|
||||||
@@ -145,7 +145,7 @@ class ViTModelTester:
|
|||||||
|
|
||||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
config.num_labels = self.type_sequence_label_size
|
||||||
|
|||||||
Reference in New Issue
Block a user