Add the auto_find_batch_size capability from Accelerate into Trainer (#17068)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> - Adds auto_batch_size finder - Moves training loop to an inner training loop
This commit is contained in:
@@ -21,6 +21,7 @@ import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
@@ -58,6 +59,7 @@ from transformers.testing_utils import (
|
||||
require_torch_bf16,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_gpus,
|
||||
require_wandb,
|
||||
@@ -1075,6 +1077,41 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertAlmostEqual(a, a1, delta=1e-8)
|
||||
self.assertAlmostEqual(b, b1, delta=1e-8)
|
||||
|
||||
@slow
|
||||
@require_torch_non_multi_gpu
|
||||
def test_auto_batch_size_finder(self):
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
SRC_DIR = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "examples", "pytorch", "text-classification")
|
||||
)
|
||||
sys.path.append(SRC_DIR)
|
||||
import run_glue
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
testargs = f"""
|
||||
run_glue.py
|
||||
--model_name_or_path distilbert-base-uncased
|
||||
--task_name mrpc
|
||||
--do_train
|
||||
--do_eval
|
||||
--max_seq_len 128
|
||||
--per_device_train_batch_size 4096
|
||||
--learning_rate 2e-5
|
||||
--num_train_epochs 1
|
||||
--output_dir {tmpdir}
|
||||
--auto_find_batch_size 0
|
||||
""".split()
|
||||
with self.assertRaises(RuntimeError):
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_glue.main()
|
||||
|
||||
testargs[-1] = "1"
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_glue.main()
|
||||
|
||||
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
||||
def test_training_with_resume_from_checkpoint_false(self):
|
||||
train_dataset = RegressionDataset(length=128)
|
||||
|
||||
Reference in New Issue
Block a user