Add the auto_find_batch_size capability from Accelerate into Trainer (#17068)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

- Adds auto_batch_size finder 
- Moves training loop to an inner training loop
This commit is contained in:
Zachary Mueller
2022-05-09 12:29:18 -04:00
committed by GitHub
parent df735d1317
commit 2fbb237967
11 changed files with 166 additions and 3 deletions

View File

@@ -40,6 +40,7 @@ from .integrations import (
is_wandb_available,
)
from .utils import (
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_detectron2_available,
@@ -238,6 +239,13 @@ def require_git_lfs(test_case):
return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)
def require_accelerate(test_case):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
"""
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
def require_rjieba(test_case):
"""
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.