Compare commits

...

5 Commits

Author SHA1 Message Date
Sylvain Gugger
68287689f2 Patch release: v4.27.2
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2023-03-20 12:02:35 -04:00
Sylvain Gugger
1e39734c4b Fix balanced and auto device_map (#22271) 2023-03-20 12:01:08 -04:00
Lysandre
2355e46395 Release: v4.27.1
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2023-03-15 15:39:22 -04:00
Sylvain Gugger
659ef0b5fe Regression pipeline device (#22190)
* Fix regression in pipeline when device=-1 is passed

* Add regression test
2023-03-15 14:14:23 -04:00
amyeroberts
36ed7508b0 Revert 22152 MaskedImageCompletionOutput changes (#22187)
Revert changes
2023-03-15 14:00:33 -04:00
8 changed files with 21 additions and 45 deletions

View File

@@ -418,7 +418,7 @@ install_requires = [
setup(
name="transformers",
version="4.27.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.27.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.27.0"
__version__ = "4.27.2"
from typing import TYPE_CHECKING

View File

@@ -1281,34 +1281,6 @@ class ImageSuperResolutionOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskedImageCompletionOutput(ModelOutput):
"""
Base class for outputs of masked image completion / in-painting models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Reconstruction loss.
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed / completed images.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
(also called feature maps) of the model at the output of each stage.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
reconstruction: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Wav2Vec2BaseModelOutput(ModelOutput):
"""

View File

@@ -2563,7 +2563,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
kwargs = {"no_split_module_classes": no_split_modules, "max_memory": max_memory}
kwargs = {"no_split_module_classes": no_split_modules}
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
kwargs["special_dtypes"] = special_dtypes
elif len(special_dtypes) > 0:
@@ -2578,6 +2578,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
low_zero=(device_map == "balanced_low_0"),
**kwargs,
)
kwargs["max_memory"] = max_memory
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs)

View File

@@ -25,12 +25,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
MaskedImageCompletionOutput,
)
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
@@ -648,7 +643,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
@@ -658,7 +653,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedImageCompletionOutput]:
) -> Union[tuple, MaskedLMOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
@@ -728,9 +723,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
output = (reconstructed_pixel_values,) + outputs[1:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
return MaskedImageCompletionOutput(
return MaskedLMOutput(
loss=masked_im_loss,
reconstruction=reconstructed_pixel_values,
logits=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -769,8 +769,8 @@ class Pipeline(_ScikitCompat):
self.modelcard = modelcard
self.framework = framework
if self.framework == "pt" and device is not None:
self.model = self.model.to(device=device)
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
self.model.to(device)
if device is None:
# `accelerate` device map

View File

@@ -134,7 +134,7 @@ class ViTModelTester:
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
)
# test greyscale images
@@ -145,7 +145,7 @@ class ViTModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size

View File

@@ -484,6 +484,14 @@ class PipelineUtilsTest(unittest.TestCase):
outputs = list(dataset)
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
def test_pipeline_negative_device(self):
# To avoid regressing, pipeline used to accept device=-1
classifier = pipeline("text-generation", "hf-internal-testing/tiny-random-bert", device=-1)
expected_output = [{"generated_text": ANY(str)}]
actual_output = classifier("Test input.")
self.assertEqual(expected_output, actual_output)
@slow
@require_torch
def test_load_default_pipelines_pt(self):