Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a7cab3c283 | ||
|
|
f6d61898b3 | ||
|
|
64bcf77cc9 | ||
|
|
780376fcd8 | ||
|
|
6e4429fd47 | ||
|
|
f33b061c8c | ||
|
|
d1dec79d0e |
2
setup.py
2
setup.py
@@ -428,7 +428,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.36.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.36.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.36.1"
|
||||
__version__ = "4.36.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -2955,6 +2955,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# In case one passes a config to `from_pretrained` + "attn_implementation"
|
||||
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
|
||||
# Please see: https://github.com/huggingface/transformers/issues/28038
|
||||
|
||||
# Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
|
||||
# we pop attn_implementation from the kwargs but this handles the case where users
|
||||
# passes manually the config to `from_pretrained`.
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
|
||||
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
|
||||
config._attn_implementation = kwarg_attn_imp
|
||||
model_kwargs = kwargs
|
||||
|
||||
quantizer = None
|
||||
|
||||
@@ -578,6 +578,13 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
@@ -608,13 +615,6 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
@@ -1000,6 +1000,13 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
past_key_values_length = 0
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
@@ -1038,13 +1045,6 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
@@ -363,6 +363,12 @@ class MistralFlashAttention2(MistralAttention):
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
@@ -385,11 +391,16 @@ class MistralFlashAttention2(MistralAttention):
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
|
||||
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
||||
if (
|
||||
getattr(self.config, "sliding_window", None) is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
and cache_has_contents
|
||||
):
|
||||
slicing_tokens = 1 - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
past_key = past_key_value[self.layer_idx][0]
|
||||
past_value = past_key_value[self.layer_idx][1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
@@ -400,8 +411,6 @@ class MistralFlashAttention2(MistralAttention):
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, slicing_tokens:]
|
||||
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
||||
@@ -855,6 +864,13 @@ class MistralModel(MistralPreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if use_cache:
|
||||
@@ -899,13 +915,6 @@ class MistralModel(MistralPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
@@ -414,6 +414,12 @@ class MixtralFlashAttention2(MixtralAttention):
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
@@ -436,11 +442,16 @@ class MixtralFlashAttention2(MixtralAttention):
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
|
||||
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
||||
if (
|
||||
getattr(self.config, "sliding_window", None) is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
and cache_has_contents
|
||||
):
|
||||
slicing_tokens = 1 - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
past_key = past_key_value[self.layer_idx][0]
|
||||
past_value = past_key_value[self.layer_idx][1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
@@ -451,8 +462,6 @@ class MixtralFlashAttention2(MixtralAttention):
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, slicing_tokens:]
|
||||
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
||||
@@ -1007,6 +1016,13 @@ class MixtralModel(MixtralPreTrainedModel):
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
@@ -1049,13 +1065,6 @@ class MixtralModel(MixtralPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
@@ -608,6 +608,13 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
@@ -635,13 +642,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
@@ -860,6 +860,13 @@ class PhiModel(PhiPreTrainedModel):
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
@@ -890,13 +897,6 @@ class PhiModel(PhiPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
@@ -2030,10 +2030,15 @@ class Trainer:
|
||||
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
|
||||
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
|
||||
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
|
||||
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
|
||||
FSDP_MODEL_NAME in folder_name
|
||||
for folder_name in os.listdir(resume_from_checkpoint)
|
||||
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
|
||||
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
|
||||
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
|
||||
any(
|
||||
FSDP_MODEL_NAME in folder_name
|
||||
for folder_name in os.listdir(resume_from_checkpoint)
|
||||
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
|
||||
)
|
||||
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
|
||||
or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
|
||||
)
|
||||
|
||||
if is_fsdp_ckpt and not self.is_fsdp_enabled:
|
||||
@@ -2383,7 +2388,9 @@ class Trainer:
|
||||
self.args.distributed_state.wait_for_everyone()
|
||||
# Then go through the rewriting process starting on process 0
|
||||
if staging_output_dir != output_dir:
|
||||
with self.args.main_process_first(desc="Renaming model checkpoint folder to true location"):
|
||||
with self.args.main_process_first(
|
||||
desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node
|
||||
):
|
||||
if os.path.exists(staging_output_dir):
|
||||
os.rename(staging_output_dir, output_dir)
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ from transformers.utils import is_accelerate_available, is_torch_bf16_available_
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
|
||||
from transformers.trainer import FSDP_MODEL_NAME
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_1 = False
|
||||
|
||||
@@ -211,6 +212,19 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
||||
# resume from ckpt
|
||||
checkpoint = os.path.join(output_dir, "checkpoint-115")
|
||||
resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
|
||||
|
||||
is_fsdp_ckpt = os.path.isdir(checkpoint) and (
|
||||
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
|
||||
any(
|
||||
FSDP_MODEL_NAME in folder_name
|
||||
for folder_name in os.listdir(checkpoint)
|
||||
if os.path.isdir(os.path.join(checkpoint, folder_name))
|
||||
)
|
||||
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
|
||||
or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
|
||||
)
|
||||
self.assertTrue(is_fsdp_ckpt)
|
||||
|
||||
logs_resume = self.run_cmd_and_get_logs(
|
||||
use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
|
||||
)
|
||||
|
||||
@@ -1823,6 +1823,16 @@ class TestAttentionImplementation(unittest.TestCase):
|
||||
|
||||
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
|
||||
|
||||
def test_error_no_flash_available_with_config(self):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
|
||||
|
||||
_ = AutoModel.from_pretrained(
|
||||
"hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
|
||||
|
||||
def test_error_wrong_attn_implementation(self):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
|
||||
@@ -1840,6 +1850,21 @@ class TestAttentionImplementation(unittest.TestCase):
|
||||
|
||||
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
|
||||
|
||||
def test_not_available_flash_with_config(self):
|
||||
if is_flash_attn_2_available():
|
||||
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")
|
||||
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
|
||||
|
||||
with self.assertRaises(ImportError) as cm:
|
||||
_ = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-GPTBigCodeModel",
|
||||
config=config,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
|
||||
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
|
||||
|
||||
def test_not_available_sdpa(self):
|
||||
if is_torch_sdpa_available():
|
||||
self.skipTest("This test requires torch<=2.0")
|
||||
|
||||
Reference in New Issue
Block a user