Add the auto_find_batch_size capability from Accelerate into Trainer (#17068)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> - Adds auto_batch_size finder - Moves training loop to an inner training loop
This commit is contained in:
@@ -40,6 +40,7 @@ from .integrations import (
|
||||
is_wandb_available,
|
||||
)
|
||||
from .utils import (
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_detectron2_available,
|
||||
@@ -238,6 +239,13 @@ def require_git_lfs(test_case):
|
||||
return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)
|
||||
|
||||
|
||||
def require_accelerate(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
|
||||
|
||||
|
||||
def require_rjieba(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
|
||||
|
||||
Reference in New Issue
Block a user