Add the auto_find_batch_size capability from Accelerate into Trainer (#17068)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

- Adds auto_batch_size finder
- Moves training loop to an inner training loop
This commit is contained in:
@@ -21,6 +21,7 @@ import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
@@ -58,6 +59,7 @@ from transformers.testing_utils import (
|
||||
require_torch_bf16,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_gpus,
|
||||
require_wandb,
|
||||
@@ -1075,6 +1077,41 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertAlmostEqual(a, a1, delta=1e-8)
|
||||
self.assertAlmostEqual(b, b1, delta=1e-8)
|
||||
|
||||
@slow
@require_torch_non_multi_gpu
def test_auto_batch_size_finder(self):
    """Run the GLUE example with an oversized batch, with and without `--auto_find_batch_size`.

    `run_glue.main()` is invoked twice with `--per_device_train_batch_size 4096`:

    1. with `--auto_find_batch_size 0`, the run is expected to raise a
       `RuntimeError` (presumably an out-of-memory error; only the exception
       type is checked);
    2. with `--auto_find_batch_size 1`, the run is expected to complete
       without raising.

    Each run gets its own temporary output directory, so the second run
    starts from an empty directory and both directories are removed
    afterwards.
    """
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True

    SRC_DIR = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "..", "examples", "pytorch", "text-classification")
    )
    # Make the example script importable as a module.
    sys.path.append(SRC_DIR)
    import run_glue

    def build_argv(output_dir, auto_find_batch_size):
        """Return the `sys.argv` list for one `run_glue.py` invocation."""
        fixed_args = """
            run_glue.py
            --model_name_or_path distilbert-base-uncased
            --task_name mrpc
            --do_train
            --do_eval
            --max_seq_len 128
            --per_device_train_batch_size 4096
            --learning_rate 2e-5
            --num_train_epochs 1
            """.split()
        # The path is appended as its own token so that whitespace in the
        # temporary directory name cannot be split into several arguments.
        return [*fixed_args, "--output_dir", output_dir, "--auto_find_batch_size", auto_find_batch_size]

    # Without the batch size finder the oversized batch must surface as an error.
    with tempfile.TemporaryDirectory() as tmpdir:
        with self.assertRaises(RuntimeError), patch.object(sys, "argv", build_argv(tmpdir, "0")):
            run_glue.main()

    # With the finder enabled the same configuration must run to completion.
    with tempfile.TemporaryDirectory() as tmpdir:
        with patch.object(sys, "argv", build_argv(tmpdir, "1")):
            run_glue.main()
|
||||
|
||||
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
||||
def test_training_with_resume_from_checkpoint_false(self):
|
||||
train_dataset = RegressionDataset(length=128)
|
||||
|
||||
@@ -18,7 +18,8 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.testing_utils import require_accelerate, require_torch
|
||||
from transformers.trainer_utils import find_executable_batch_size
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
@@ -420,3 +421,39 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
|
||||
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
|
||||
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
|
||||
|
||||
@require_accelerate
def test_executable_batch_size(self):
    """With the search enabled, the loop is retried with smaller batch sizes.

    The mock loop records every batch size it is called with and raises a
    `RuntimeError("CUDA out of memory.")` for any size above 16. Starting
    from 64, the recorded sizes are expected to be 64, 32 and 16.
    """
    attempted_sizes = []

    @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
    def mock_training_loop_function(batch_size):
        # Appending mutates the enclosing list in place, so no rebinding is needed.
        attempted_sizes.append(batch_size)
        if batch_size <= 16:
            return
        raise RuntimeError("CUDA out of memory.")

    mock_training_loop_function()

    self.assertEqual(attempted_sizes, [64, 32, 16])
|
||||
|
||||
@require_accelerate
def test_executable_batch_size_no_search(self):
    """With the search disabled, the loop is called once with the starting batch size.

    The mock loop only records the batch size it receives; the recorded
    list is expected to contain the single starting value, 64.
    """
    seen_sizes = []

    @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
    def mock_training_loop_function(batch_size):
        # Appending mutates the enclosing list in place, so no rebinding is needed.
        seen_sizes.append(batch_size)

    mock_training_loop_function()

    self.assertEqual(seen_sizes, [64])
|
||||
|
||||
@require_accelerate
def test_executable_batch_size_with_error(self):
    """With the search disabled, an error raised by the loop reaches the caller.

    The mock loop always raises `RuntimeError("CUDA out of memory.")`. The
    test checks that calling the decorated function raises a `RuntimeError`
    and that the raised exception still carries the out-of-memory message.
    """

    @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
    def mock_training_loop_function(batch_size):
        raise RuntimeError("CUDA out of memory.")

    with self.assertRaises(RuntimeError) as cm:
        mock_training_loop_function()

    # The caught exception lives on `cm.exception`; the context object itself
    # has no `args`. The check is placed after the `with` block because
    # nothing following the raising call inside the block would ever run.
    self.assertIn("CUDA out of memory", str(cm.exception))
|
||||
|
||||
Reference in New Issue
Block a user