diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 345a2cba76..621a993aa9 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -607,7 +607,7 @@ def is_accelerate_available(min_version: str = None): def is_fsdp_available(min_version: str = "1.12.0"): - return version.parse(_torch_version) >= version.parse(min_version) + return is_torch_available() and version.parse(_torch_version) >= version.parse(min_version) def is_optimum_available(): diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index 1a968f68fa..f9dd300626 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -79,6 +79,7 @@ if is_torch_available(): # hack to restore original logging level pre #21700 get_regression_trainer = partial(tests.trainer.test_trainer.get_regression_trainer, log_level="info") +require_fsdp_version = require_fsdp if is_accelerate_available(): from accelerate.utils.constants import ( FSDP_PYTORCH_VERSION,