Compare commits

...

2 Commits

Author SHA1 Message Date
Arthur Zucker
052e652d6d v4.46.3
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-11-18 21:10:52 +01:00
Wing Lian
e01a61aeab FSDP grad accum fix (#34645)
* add gradient accumulation steps tests for fsdp

* invert no_sync context to fix training for fsdp
2024-11-18 20:07:21 +01:00
4 changed files with 15 additions and 3 deletions

View File

@@ -435,7 +435,7 @@ install_requires = [
setup(
name="transformers",
version="4.46.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.46.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.46.2"
__version__ = "4.46.3"
from typing import TYPE_CHECKING

View File

@@ -2474,7 +2474,7 @@ class Trainer:
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():

View File

@@ -224,6 +224,18 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
cmd = launcher + script + args + fsdp_args
execute_subprocess_async(cmd, env=self.get_env())
@parameterized.expand(params, name_func=_parameterized_custom_name_func)
@require_torch_multi_accelerator
@slow
def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype):
launcher = get_launcher(distributed=True, use_accelerate=False)
output_dir = self.get_auto_remove_tmp_dir()
args = self.get_base_args(output_dir, 1, 50).split() + [f"--{dtype}", "--gradient_accumulation_steps", "2"]
fsdp_args = ["--fsdp", f"{sharding_strategy} auto_wrap", "--fsdp_transformer_layer_cls_to_wrap", "BertLayer"]
script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
cmd = launcher + script + args + fsdp_args
execute_subprocess_async(cmd, env=self.get_env())
@parameterized.expand(dtypes)
@require_torch_multi_accelerator
@slow