Compare commits

...

6 Commits

Author SHA1 Message Date
Arthur Zucker
868d36d29e Patch release v4.43.4
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-08-05 12:40:38 +02:00
Raushan Turganbay
5cea2e73ef Resize embeds with DeepSpeed (#32214)
* fix resize when deepspeed

* deepsped uses new embeds

* we needed this
2024-08-05 12:40:01 +02:00
Arthur Zucker
47c29ccfaf Patch release v4.43.3
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-07-26 17:09:09 +02:00
João Nadkarni
54bc29c1ba don't log base model architecture in wandb if log model is false (#32143)
* don't log base model architecture in wandb is log model is false

* Update src/transformers/integrations/integration_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* convert log model setting into an enum

* fix formatting

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2024-07-26 17:01:59 +02:00
Kashif Rasul
cc75146d0e [BigBird Pegasus] set _supports_param_buffer_assignment to False (#32222)
set _supports_param_buffer_assignment to False
2024-07-26 17:01:59 +02:00
Sanchit Gandhi
cd06184cc4 [whisper] fix short-form output type (#32178)
* [whisper] fix short-form output type

* add test

* make style

* update long-form tests

* fixes

* last fix

* finalise test
2024-07-26 17:01:59 +02:00
7 changed files with 108 additions and 51 deletions

View File

@@ -430,7 +430,7 @@ install_requires = [
setup(
name="transformers",
version="4.43.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.43.4", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.43.2"
__version__ = "4.43.4"
from typing import TYPE_CHECKING

View File

@@ -26,6 +26,7 @@ import shutil
import sys
import tempfile
from dataclasses import asdict, fields
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
@@ -726,6 +727,35 @@ def save_model_architecture_to_file(model: Any, output_dir: str):
print(model, file=f)
class WandbLogModel(str, Enum):
"""Enum of possible log model values in W&B."""
CHECKPOINT = "checkpoint"
END = "end"
FALSE = "false"
@property
def is_enabled(self) -> bool:
"""Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)
@classmethod
def _missing_(cls, value: Any) -> "WandbLogModel":
if not isinstance(value, str):
raise ValueError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
if value.upper() in ENV_VARS_TRUE_VALUES:
DeprecationWarning(
f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
"version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
)
logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
return WandbLogModel.END
logger.warning(
f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
)
return WandbLogModel.FALSE
class WandbCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
@@ -740,16 +770,7 @@ class WandbCallback(TrainerCallback):
self._wandb = wandb
self._initialized = False
# log model
if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
DeprecationWarning(
f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
"version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
)
logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
self._log_model = "end"
else:
self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()
self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))
def setup(self, args, state, model, **kwargs):
"""
@@ -834,37 +855,38 @@ class WandbCallback(TrainerCallback):
logger.info("Could not log the number of model parameters in Weights & Biases.")
# log the initial model architecture to an artifact
with tempfile.TemporaryDirectory() as temp_dir:
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
model_artifact = self._wandb.Artifact(
name=model_name,
type="model",
metadata={
"model_config": model.config.to_dict() if hasattr(model, "config") else None,
"num_parameters": self._wandb.config.get("model/num_parameters"),
"initial_model": True,
},
)
# add the architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)
if self._log_model.is_enabled:
with tempfile.TemporaryDirectory() as temp_dir:
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
model_artifact = self._wandb.Artifact(
name=model_name,
type="model",
metadata={
"model_config": model.config.to_dict() if hasattr(model, "config") else None,
"num_parameters": self._wandb.config.get("model/num_parameters"),
"initial_model": True,
},
)
# add the architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with model_artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
for f in Path(temp_dir).glob("*"):
if f.is_file():
with model_artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
badge_markdown = (
f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
f'-28.svg" alt="Visualize in Weights & Biases" width="20'
f'0" height="32"/>]({self._wandb.run.get_url()})'
)
badge_markdown = (
f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
f'-28.svg" alt="Visualize in Weights & Biases" width="20'
f'0" height="32"/>]({self._wandb.run.get_url()})'
)
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
@@ -880,7 +902,7 @@ class WandbCallback(TrainerCallback):
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return
if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
from ..trainer import Trainer
fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
@@ -938,7 +960,7 @@ class WandbCallback(TrainerCallback):
self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
def on_save(self, args, state, control, **kwargs):
if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero:
checkpoint_metadata = {
k: v
for k, v in dict(self._wandb.summary).items()

View File

@@ -1980,12 +1980,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if new_num_tokens is None and pad_to_multiple_of is None:
return model_embeds
# Since we are basically resuing the same old embeddings with new weight values, gathering is required
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
vocab_size = model_embeds.weight.shape[0]
else:
vocab_size = model_embeds.weight.shape[0]
# Update base model and current model config
if hasattr(self.config, "text_config"):
self.config.text_config.vocab_size = model_embeds.weight.shape[0]
self.config.text_config.vocab_size = vocab_size
else:
self.config.vocab_size = model_embeds.weight.shape[0]
self.vocab_size = model_embeds.weight.shape[0]
self.config.vocab_size = vocab_size
self.vocab_size = vocab_size
# Tie weights again if needed
self.tie_weights()
@@ -2139,7 +2149,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.weight = new_embeddings.weight
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
# If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`

View File

@@ -1569,6 +1569,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_param_buffer_assignment = False
def _init_weights(self, module):
std = self.config.init_std

View File

@@ -498,7 +498,7 @@ class WhisperGenerationMixin:
# 3. Make sure generation config is correctly set
# Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
self._set_return_outputs(
return_dict_in_generate = self._set_return_outputs(
return_dict_in_generate=return_dict_in_generate,
return_token_timestamps=return_token_timestamps,
logprob_threshold=logprob_threshold,
@@ -732,7 +732,7 @@ class WhisperGenerationMixin:
else:
outputs = sequences
if generation_config.return_dict_in_generate:
if return_dict_in_generate and generation_config.return_dict_in_generate:
dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
if num_return_sequences > 1:
@@ -1109,18 +1109,20 @@ class WhisperGenerationMixin:
def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
if return_dict_in_generate is None:
return_dict_in_generate = generation_config.return_dict_in_generate
else:
generation_config.return_dict_in_generate = return_dict_in_generate
generation_config.return_token_timestamps = return_token_timestamps
if return_token_timestamps:
return_dict_in_generate = True
generation_config.return_dict_in_generate = True
generation_config.output_attentions = True
generation_config.output_scores = True
if logprob_threshold is not None:
return_dict_in_generate = True
generation_config.return_dict_in_generate = True
generation_config.output_scores = True
generation_config.return_dict_in_generate = return_dict_in_generate
return return_dict_in_generate
def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
if not is_shortform:

View File

@@ -26,6 +26,7 @@ import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from parameterized import parameterized
import transformers
from transformers import WhisperConfig
@@ -72,6 +73,7 @@ if is_torch_available():
BeamSearchEncoderDecoderOutput,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateEncoderDecoderOutput,
PhrasalConstraint,
)
from transformers.generation.logits_process import LogitsProcessor
@@ -1820,6 +1822,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@parameterized.expand([(True,), (False,)])
def test_generate_output_type(self, return_dict_in_generate):
expected_output_type = GenerateEncoderDecoderOutput if return_dict_in_generate else torch.Tensor
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs()
model = model_class(config).to(torch_device).eval()
# short-form generation without fallback
pred_ids = model.generate(**inputs, return_dict_in_generate=return_dict_in_generate)
assert isinstance(pred_ids, expected_output_type)
# short-form generation with fallback
pred_ids = model.generate(
**inputs,
logprob_threshold=-1.0,
temperature=[0.0, 0.1],
return_dict_in_generate=return_dict_in_generate,
)
assert isinstance(pred_ids, expected_output_type)
@require_torch
@require_torchaudio