FSDP tests and checkpointing fixes (#26180)
* add fsdp tests * Update test_fsdp.py * Update test_fsdp.py * fixes * checks * Update trainer.py * fix * fixes for saving/resuming checkpoints * fixes * add tests and delete debug statements * fixing tests * Update test_fsdp.py * fix tests * fix tests * minor nits * fix code style and quality * refactor and modularize test code * reduce the time of tests * reduce the test time * fix test * reduce test time * reduce test time * fix failing tests * fix * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * resolve comments --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8e3980a290
commit
382ba670ed
@@ -61,6 +61,7 @@ from .utils import (
|
||||
is_essentia_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
is_fsdp_available,
|
||||
is_ftfy_available,
|
||||
is_ipex_available,
|
||||
is_jieba_available,
|
||||
@@ -316,6 +317,15 @@ def require_accelerate(test_case):
|
||||
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
|
||||
|
||||
|
||||
def require_fsdp(test_case, min_version: str = "1.12.0"):
|
||||
"""
|
||||
Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_safetensors(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
|
||||
|
||||
Reference in New Issue
Block a user