Add the auto_find_batch_size capability from Accelerate into Trainer (#17068)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

- Adds auto_batch_size finder 
- Moves training loop to an inner training loop
This commit is contained in:
Zachary Mueller
2022-05-09 12:29:18 -04:00
committed by GitHub
parent df735d1317
commit 2fbb237967
11 changed files with 166 additions and 3 deletions

View File

@@ -18,7 +18,8 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch
from transformers.testing_utils import require_accelerate, require_torch
from transformers.trainer_utils import find_executable_batch_size
from transformers.utils import is_torch_available
@@ -420,3 +421,39 @@ class TrainerUtilsTest(unittest.TestCase):
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
@require_accelerate
def test_executable_batch_size(self):
batch_sizes = []
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
def mock_training_loop_function(batch_size):
nonlocal batch_sizes
batch_sizes.append(batch_size)
if batch_size > 16:
raise RuntimeError("CUDA out of memory.")
mock_training_loop_function()
self.assertEqual(batch_sizes, [64, 32, 16])
@require_accelerate
def test_executable_batch_size_no_search(self):
batch_sizes = []
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
def mock_training_loop_function(batch_size):
nonlocal batch_sizes
batch_sizes.append(batch_size)
mock_training_loop_function()
self.assertEqual(batch_sizes, [64])
@require_accelerate
def test_executable_batch_size_with_error(self):
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
def mock_training_loop_function(batch_size):
raise RuntimeError("CUDA out of memory.")
with self.assertRaises(RuntimeError) as cm:
mock_training_loop_function()
self.assertEqual("CUDA out of memory", cm.args[0])