Add the auto_find_batch_size capability from Accelerate into Trainer (#17068)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

- Adds auto_batch_size finder 
- Moves training loop to an inner training loop
This commit is contained in:
Zachary Mueller
2022-05-09 12:29:18 -04:00
committed by GitHub
parent df735d1317
commit 2fbb237967
11 changed files with 166 additions and 3 deletions

View File

@@ -21,6 +21,7 @@ import os
import random
import re
import subprocess
import sys
import tempfile
import time
import unittest
@@ -58,6 +59,7 @@ from transformers.testing_utils import (
require_torch_bf16,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
require_torch_tf32,
require_torch_up_to_2_gpus,
require_wandb,
@@ -1075,6 +1077,41 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertAlmostEqual(a, a1, delta=1e-8)
self.assertAlmostEqual(b, b1, delta=1e-8)
@slow
@require_torch_non_multi_gpu
def test_auto_batch_size_finder(self):
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
SRC_DIR = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "..", "examples", "pytorch", "text-classification")
)
sys.path.append(SRC_DIR)
import run_glue
with tempfile.TemporaryDirectory() as tmpdir:
testargs = f"""
run_glue.py
--model_name_or_path distilbert-base-uncased
--task_name mrpc
--do_train
--do_eval
--max_seq_len 128
--per_device_train_batch_size 4096
--learning_rate 2e-5
--num_train_epochs 1
--output_dir {tmpdir}
--auto_find_batch_size 0
""".split()
with self.assertRaises(RuntimeError):
with patch.object(sys, "argv", testargs):
run_glue.main()
testargs[-1] = "1"
with patch.object(sys, "argv", testargs):
run_glue.main()
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
def test_training_with_resume_from_checkpoint_false(self):
train_dataset = RegressionDataset(length=128)

View File

@@ -18,7 +18,8 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch
from transformers.testing_utils import require_accelerate, require_torch
from transformers.trainer_utils import find_executable_batch_size
from transformers.utils import is_torch_available
@@ -420,3 +421,39 @@ class TrainerUtilsTest(unittest.TestCase):
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
@require_accelerate
def test_executable_batch_size(self):
batch_sizes = []
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
def mock_training_loop_function(batch_size):
nonlocal batch_sizes
batch_sizes.append(batch_size)
if batch_size > 16:
raise RuntimeError("CUDA out of memory.")
mock_training_loop_function()
self.assertEqual(batch_sizes, [64, 32, 16])
@require_accelerate
def test_executable_batch_size_no_search(self):
batch_sizes = []
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
def mock_training_loop_function(batch_size):
nonlocal batch_sizes
batch_sizes.append(batch_size)
mock_training_loop_function()
self.assertEqual(batch_sizes, [64])
@require_accelerate
def test_executable_batch_size_with_error(self):
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
def mock_training_loop_function(batch_size):
raise RuntimeError("CUDA out of memory.")
with self.assertRaises(RuntimeError) as cm:
mock_training_loop_function()
self.assertEqual("CUDA out of memory", cm.args[0])