Compare commits

...

13 Commits

Author SHA1 Message Date
Amy Roberts
a7cab3c283 Release: v4.36.2
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2023-12-18 15:06:12 +00:00
dumpmemory
f6d61898b3 Fix bug for checkpoint saving on multi node training setting (#28078)
* add multi-node traning setting

* fix style
2023-12-18 11:19:54 +00:00
Sourab Mangrulkar
64bcf77cc9 fix resuming from ckpt when using FSDP with FULL_STATE_DICT (#27891)
* fix resuming from ckpt when suing FSDP with FULL_STATE_DICT

* update tests

* fix tests
2023-12-18 11:19:45 +00:00
Younes Belkada
780376fcd8 [Modeling / Mixtral] Fix GC + PEFT issues with Mixtral (#28061)
fix for mistral
2023-12-18 11:19:37 +00:00
Younes Belkada
6e4429fd47 [FA-2] Fix fa-2 issue when passing config to from_pretrained (#28043)
* fix fa-2 issue

* fix test

* Update src/transformers/modeling_utils.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* clenaer fix

* up

* add more robust tests

* Update src/transformers/modeling_utils.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* fixup

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pop

* add test

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2023-12-18 11:19:29 +00:00
Joao Gante
f33b061c8c Generate: Mistral/Mixtral FA2 cache fix when going beyond the context window (#28037) 2023-12-18 11:11:55 +00:00
Younes Belkada
d1dec79d0e [core / modeling] Fix training bug with PEFT + GC (#28031)
fix trainign bug
2023-12-18 11:05:36 +00:00
ArthurZucker
c48787f347 fix seamless import
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2023-12-14 01:23:29 -05:00
ArthurZucker
bd6541006b Release: v4.36.1 2023-12-14 00:44:30 -05:00
Zach Mueller
6342b9bd20 Fix bug with rotating checkpoints (#28009)
* Fix bug

* Write test

* Keep back old modification for grad accum steps

* Whitespace...

* Whitespace again

* Race condition

* Wait for everyone
2023-12-14 00:39:06 -05:00
fxmarty
5b7d5bd290 Fix SDPA correctness following torch==2.1.2 regression (#27973)
* fix sdpa with non-contiguous inputs for gpt_bigcode

* fix other archs

* add currently comment

* format
2023-12-14 00:38:42 -05:00
Arthur
6c3c0dc72a Hot-fix-mixstral-loss (#27948)
* fix loss computation

* compute on GPU if possible
2023-12-14 00:38:08 -05:00
Arthur
a5ee6f06f1 [Tokenizer Serialization] Fix the broken serialisation (#27099)
* nits

* nits

* actual fix

* style

* ze fix

* fix fix fix style
2023-12-14 00:37:33 -05:00
21 changed files with 215 additions and 76 deletions

View File

@@ -428,7 +428,7 @@ install_requires = [
setup(
name="transformers",
version="4.36.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.36.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.36.0"
__version__ = "4.36.2"
from typing import TYPE_CHECKING

View File

@@ -2955,6 +2955,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
**kwargs,
)
else:
# In case one passes a config to `from_pretrained` + "attn_implementation"
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
# Please see: https://github.com/huggingface/transformers/issues/28038
# Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
# we pop attn_implementation from the kwargs but this handles the case where users
# passes manually the config to `from_pretrained`.
config = copy.deepcopy(config)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs
quantizer = None

View File

@@ -583,6 +583,8 @@ class BartSdpaAttention(BartAttention):
query_states = self._shape(query_states, tgt_len, bsz)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,

View File

@@ -578,6 +578,13 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
seq_length_with_past = seq_length
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
@@ -608,13 +615,6 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -447,6 +447,13 @@ class FalconAttention(nn.Module):
else:
present = None
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
if alibi is None:
if self._use_sdpa and not output_attentions:
attn_output = F.scaled_dot_product_attention(

View File

@@ -532,24 +532,37 @@ class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
if self.multi_query:
query_length = query_shape[1]
# NOTE: Maybe there is better than this?
# SDPA requires the dimension [..., sequence_length, head_dim].
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
# Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
key = key.unsqueeze(1)
value = value.unsqueeze(1)
# Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to mem-efficient attention
# and flash attention (No available kernel. Aborting execution.) from the shapes
# Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to memory-efficient backend
# and flash attention backend (No available kernel. Aborting execution.) from the shapes
# query = [batch_size, num_heads, query_length, head_dim]
# key = [batch_size, 1, past_length, head_dim]
# value = [batch_size, 1, past_length, head_dim]
# which is unfortunate. Hopefully can be improved in the future. These expand should not be too expansive as they do not do memory copy.
key = key.expand(-1, self.num_heads, -1, -1)
value = value.expand(-1, self.num_heads, -1, -1)
#
# so we could do:
#
# key = key.expand(-1, self.num_heads, -1, -1)
# value = value.expand(-1, self.num_heads, -1, -1)
#
# However SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# so we always dispatch to the math path: https://github.com/pytorch/pytorch/issues/112577.
# Arguably we could still do expand + contiguous when `query.device.type == "cuda"` in order to dispatch on memory-efficient
# backend, but it feels very hacky.
else:
query_length = query_shape[-1]
# See the comment above.
if query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query,
key,

View File

@@ -688,6 +688,13 @@ class IdeficsAttention(nn.Module):
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = nn.functional.scaled_dot_product_attention(
query_states,
key_states,

View File

@@ -506,7 +506,6 @@ class LlamaFlashAttention2(LlamaAttention):
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -701,6 +700,7 @@ class LlamaSdpaAttention(LlamaAttention):
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -716,6 +716,13 @@ class LlamaSdpaAttention(LlamaAttention):
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@@ -993,6 +1000,13 @@ class LlamaModel(LlamaPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
@@ -1031,13 +1045,6 @@ class LlamaModel(LlamaPreTrainedModel):
# embed positions
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -363,6 +363,12 @@ class MistralFlashAttention2(MistralAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -385,11 +391,16 @@ class MistralFlashAttention2(MistralAttention):
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -400,8 +411,6 @@ class MistralFlashAttention2(MistralAttention):
f" {past_key.shape}"
)
past_key_value = (past_key, past_value)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
@@ -855,6 +864,13 @@ class MistralModel(MistralPreTrainedModel):
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
past_key_values_length = 0
if use_cache:
@@ -899,13 +915,6 @@ class MistralModel(MistralPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -95,7 +95,8 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tenso
if isinstance(gate_logits, tuple):
# cat along the layers?
gate_logits = torch.cat(gate_logits, dim=0)
compute_device = gate_logits[0].device
gate_logits = torch.cat([gate.to(compute_device) for gate in gate_logits], dim=0)
routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
routing_weights = routing_weights.softmax(dim=-1)
@@ -413,6 +414,12 @@ class MixtralFlashAttention2(MixtralAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -435,11 +442,16 @@ class MixtralFlashAttention2(MixtralAttention):
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -450,8 +462,6 @@ class MixtralFlashAttention2(MixtralAttention):
f" {past_key.shape}"
)
past_key_value = (past_key, past_value)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
@@ -1006,6 +1016,13 @@ class MixtralModel(MixtralPreTrainedModel):
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
@@ -1048,13 +1065,6 @@ class MixtralModel(MixtralPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -145,6 +145,8 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
from_slow = kwargs.pop("from_slow", None)
from_slow = from_slow or str(pad_token) != "<pad>" or str(eos_token) != "</s>" or str(unk_token) != "<unk>"
kwargs.pop("added_tokens_decoder", {})
super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,

View File

@@ -608,6 +608,13 @@ class PersimmonModel(PersimmonPreTrainedModel):
seq_length_with_past = seq_length
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
@@ -635,13 +642,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -860,6 +860,13 @@ class PhiModel(PhiPreTrainedModel):
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
@@ -890,13 +897,6 @@ class PhiModel(PhiPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

View File

@@ -25,12 +25,14 @@ from ...tokenization_utils import (
TextInput,
)
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import PaddingStrategy, logging
from .tokenization_seamless_m4t import (
SeamlessM4TTokenizer,
)
from ...utils import PaddingStrategy, is_sentencepiece_available, logging
if is_sentencepiece_available():
from .tokenization_seamless_m4t import SeamlessM4TTokenizer
else:
SeamlessM4TTokenizer = None
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}

View File

@@ -764,6 +764,8 @@ class WhisperSdpaAttention(WhisperAttention):
query_states = self._shape(query_states, tgt_len, bsz)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,

View File

@@ -2235,7 +2235,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# allows converting a fast -> slow: add the `tokenizer.json`'s `"added_tokens"` to the slow tokenizer
# if `tokenizer_config.json` is `None`
if "Fast" not in cls.__name__ and tokenizer_file is not None:
if tokenizer_file is not None:
# This is for slow so can be done before
with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle:
tokenizer_file_handle = json.load(tokenizer_file_handle)
@@ -2247,14 +2247,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# end legacy
# Passing AddedTokens and not strings to the class to prevent it from casting the string to a different AddedToken
# convert {'__type': 'AddedToken', 'content': '<ent>', 'lstrip': False, 'normalized': True, ...} to AddedTokens
init_kwargs["added_tokens_decoder"] = added_tokens_decoder
init_kwargs = cls.convert_added_tokens(init_kwargs, save=False)
for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys():
if added_tokens_map != {} and init_kwargs[key] is not None:
if key != "additional_special_tokens":
init_kwargs[key] = added_tokens_map.get(init_kwargs[key], init_kwargs[key])
init_kwargs[key] = added_tokens_map.get(str(init_kwargs[key]), init_kwargs[key])
init_kwargs["added_tokens_decoder"] = added_tokens_decoder
# convert {'__type': 'AddedToken', 'content': '<ent>', 'lstrip': False, 'normalized': True, ...} to AddedTokens
init_kwargs = cls.convert_added_tokens(init_kwargs, save=False)
# Instantiate the tokenizer.
try:
tokenizer = cls(*init_inputs, **init_kwargs)

View File

@@ -2030,10 +2030,15 @@ class Trainer:
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(resume_from_checkpoint)
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(resume_from_checkpoint)
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
)
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
)
if is_fsdp_ckpt and not self.is_fsdp_enabled:
@@ -2379,8 +2384,15 @@ class Trainer:
self._push_from_checkpoint(staging_output_dir)
# Place checkpoint in final location after all saving is finished.
# First wait for everyone to finish writing
self.args.distributed_state.wait_for_everyone()
# Then go through the rewriting process starting on process 0
if staging_output_dir != output_dir:
os.rename(staging_output_dir, output_dir)
with self.args.main_process_first(
desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node
):
if os.path.exists(staging_output_dir):
os.rename(staging_output_dir, output_dir)
# Maybe delete some older checkpoints.
if self.args.should_save:

View File

@@ -41,6 +41,7 @@ from transformers.utils import is_accelerate_available, is_torch_bf16_available_
if is_torch_available():
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
from transformers.trainer import FSDP_MODEL_NAME
else:
is_torch_greater_or_equal_than_2_1 = False
@@ -211,6 +212,19 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
# resume from ckpt
checkpoint = os.path.join(output_dir, "checkpoint-115")
resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
is_fsdp_ckpt = os.path.isdir(checkpoint) and (
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(checkpoint)
if os.path.isdir(os.path.join(checkpoint, folder_name))
)
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
)
self.assertTrue(is_fsdp_ckpt)
logs_resume = self.run_cmd_and_get_logs(
use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
)

View File

@@ -1823,6 +1823,16 @@ class TestAttentionImplementation(unittest.TestCase):
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
def test_error_no_flash_available_with_config(self):
with self.assertRaises(ValueError) as cm:
config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
)
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
@@ -1840,6 +1850,21 @@ class TestAttentionImplementation(unittest.TestCase):
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
def test_not_available_flash_with_config(self):
if is_flash_attn_2_available():
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel",
config=config,
attn_implementation="flash_attention_2",
)
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest("This test requires torch<=2.0")

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict
import numpy as np
@@ -236,6 +237,20 @@ if __name__ == "__main__":
trainer.args.eval_accumulation_steps = None
# Check that saving does indeed work with temp dir rotation
# If this fails, will see a FileNotFoundError
model = RegressionModel()
training_args.max_steps = 1
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda x: 1)
trainer = Trainer(
model, training_args, optimizers=(opt, sched), data_collator=DummyDataCollator(), eval_dataset=dataset
)
trainer._save_checkpoint(model=None, trial=None)
# Check that the temp folder does not exist
assert not (Path(training_args.output_dir) / "tmp-checkpoint-0").exists()
assert (Path(training_args.output_dir) / "checkpoint-0").exists()
# Check that `dispatch_batches=False` will work on a finite iterable dataset
train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)