FSDP tests and checkpointing fixes (#26180)

* add fsdp tests

* Update test_fsdp.py

* Update test_fsdp.py

* fixes

* checks

* Update trainer.py

* fix

* fixes for saving/resuming checkpoints

* fixes

* add tests and delete debug statements

* fixing tests

* Update test_fsdp.py

* fix tests

* fix tests

* minor nits

* fix code style and quality

* refactor and modularize test code

* reduce the time of tests

* reduce the test time

* fix test

* reduce test time

* reduce test time

* fix failing tests

* fix

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* resolve comments

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Sourab Mangrulkar
2023-09-20 10:26:16 +05:30
committed by GitHub
parent 8e3980a290
commit 382ba670ed
6 changed files with 310 additions and 24 deletions

View File

@@ -61,6 +61,7 @@ from .utils import (
is_essentia_available,
is_faiss_available,
is_flax_available,
is_fsdp_available,
is_ftfy_available,
is_ipex_available,
is_jieba_available,
@@ -316,6 +317,15 @@ def require_accelerate(test_case):
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
def require_fsdp(test_case, min_version: str = "1.12.0"):
"""
Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed.
"""
return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(
test_case
)
def require_safetensors(test_case):
"""
Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.