Compare commits

...

7 Commits

Author SHA1 Message Date
Amy Roberts
a7cab3c283 Release: v4.36.2
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2023-12-18 15:06:12 +00:00
dumpmemory
f6d61898b3 Fix bug for checkpoint saving on multi node training setting (#28078)
* add multi-node traning setting

* fix style
2023-12-18 11:19:54 +00:00
Sourab Mangrulkar
64bcf77cc9 fix resuming from ckpt when using FSDP with FULL_STATE_DICT (#27891)
* fix resuming from ckpt when suing FSDP with FULL_STATE_DICT

* update tests

* fix tests
2023-12-18 11:19:45 +00:00
Younes Belkada
780376fcd8 [Modeling / Mixtral] Fix GC + PEFT issues with Mixtral (#28061)
fix for mistral
2023-12-18 11:19:37 +00:00
Younes Belkada
6e4429fd47 [FA-2] Fix fa-2 issue when passing config to from_pretrained (#28043)
* fix fa-2 issue

* fix test

* Update src/transformers/modeling_utils.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* clenaer fix

* up

* add more robust tests

* Update src/transformers/modeling_utils.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* fixup

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pop

* add test

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2023-12-18 11:19:29 +00:00
Joao Gante
f33b061c8c Generate: Mistral/Mixtral FA2 cache fix when going beyond the context window (#28037) 2023-12-18 11:11:55 +00:00
Younes Belkada
d1dec79d0e [core / modeling] Fix training bug with PEFT + GC (#28031)
fix trainign bug
2023-12-18 11:05:36 +00:00
12 changed files with 135 additions and 59 deletions

View File

@@ -428,7 +428,7 @@ install_requires = [
setup(
name="transformers",
version="4.36.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.36.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.36.1"
__version__ = "4.36.2"
from typing import TYPE_CHECKING

View File

@@ -2955,6 +2955,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
**kwargs,
)
else:
# In case one passes a config to `from_pretrained` + "attn_implementation"
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
# Please see: https://github.com/huggingface/transformers/issues/28038
# Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
# we pop attn_implementation from the kwargs but this handles the case where users
# passes manually the config to `from_pretrained`.
config = copy.deepcopy(config)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs
quantizer = None

View File

@@ -578,6 +578,13 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
seq_length_with_past = seq_length
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
@@ -608,13 +615,6 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -1000,6 +1000,13 @@ class LlamaModel(LlamaPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
@@ -1038,13 +1045,6 @@ class LlamaModel(LlamaPreTrainedModel):
# embed positions
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -363,6 +363,12 @@ class MistralFlashAttention2(MistralAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -385,11 +391,16 @@ class MistralFlashAttention2(MistralAttention):
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -400,8 +411,6 @@ class MistralFlashAttention2(MistralAttention):
f" {past_key.shape}"
)
past_key_value = (past_key, past_value)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
@@ -855,6 +864,13 @@ class MistralModel(MistralPreTrainedModel):
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
past_key_values_length = 0
if use_cache:
@@ -899,13 +915,6 @@ class MistralModel(MistralPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -414,6 +414,12 @@ class MixtralFlashAttention2(MixtralAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -436,11 +442,16 @@ class MixtralFlashAttention2(MixtralAttention):
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -451,8 +462,6 @@ class MixtralFlashAttention2(MixtralAttention):
f" {past_key.shape}"
)
past_key_value = (past_key, past_value)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
@@ -1007,6 +1016,13 @@ class MixtralModel(MixtralPreTrainedModel):
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
@@ -1049,13 +1065,6 @@ class MixtralModel(MixtralPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -608,6 +608,13 @@ class PersimmonModel(PersimmonPreTrainedModel):
seq_length_with_past = seq_length
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
@@ -635,13 +642,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -860,6 +860,13 @@ class PhiModel(PhiPreTrainedModel):
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
@@ -890,13 +897,6 @@ class PhiModel(PhiPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -2030,10 +2030,15 @@ class Trainer:
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(resume_from_checkpoint)
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(resume_from_checkpoint)
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
)
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
)
if is_fsdp_ckpt and not self.is_fsdp_enabled:
@@ -2383,7 +2388,9 @@ class Trainer:
self.args.distributed_state.wait_for_everyone()
# Then go through the rewriting process starting on process 0
if staging_output_dir != output_dir:
with self.args.main_process_first(desc="Renaming model checkpoint folder to true location"):
with self.args.main_process_first(
desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node
):
if os.path.exists(staging_output_dir):
os.rename(staging_output_dir, output_dir)

View File

@@ -41,6 +41,7 @@ from transformers.utils import is_accelerate_available, is_torch_bf16_available_
if is_torch_available():
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
from transformers.trainer import FSDP_MODEL_NAME
else:
is_torch_greater_or_equal_than_2_1 = False
@@ -211,6 +212,19 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
# resume from ckpt
checkpoint = os.path.join(output_dir, "checkpoint-115")
resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
is_fsdp_ckpt = os.path.isdir(checkpoint) and (
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(checkpoint)
if os.path.isdir(os.path.join(checkpoint, folder_name))
)
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
)
self.assertTrue(is_fsdp_ckpt)
logs_resume = self.run_cmd_and_get_logs(
use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
)

View File

@@ -1823,6 +1823,16 @@ class TestAttentionImplementation(unittest.TestCase):
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
def test_error_no_flash_available_with_config(self):
with self.assertRaises(ValueError) as cm:
config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
)
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
@@ -1840,6 +1850,21 @@ class TestAttentionImplementation(unittest.TestCase):
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
def test_not_available_flash_with_config(self):
if is_flash_attn_2_available():
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel",
config=config,
attn_implementation="flash_attention_2",
)
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest("This test requires torch<=2.0")