add gather_use_object arguments (#31514)

* add gather_use_object arguments

* fix name and pass the CI test for Seq2SeqTrainer

* make style

* make it to functools

* fix typo

* add accelerate version:

* adding warning

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* make style

* Update src/transformers/training_args.py

* check function move to initial part

* add test for eval_use_gather_object

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Sangbum Daniel Choi
2024-06-28 21:50:27 +09:00
committed by GitHub
parent 82a1fc7256
commit cb298978ad
3 changed files with 34 additions and 1 deletions

View File

@@ -132,6 +132,7 @@ if is_torch_available():
# Helpers for version-specific tests in TrainerIntegrationTest.
# Each `require_accelerate_version_min_*` name is `require_accelerate` with its
# `min_version` keyword pre-bound through `partial`, so a test can be decorated
# with the minimum accelerate release it needs.
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
# Result of `is_accelerate_available("0.28")`, evaluated once at import time.
# NOTE(review): presumably truthy only when accelerate >= 0.28 is installed — confirm
# against the `is_accelerate_available` signature.
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
if is_accelerate_available():
from accelerate import Accelerator
@@ -3565,6 +3566,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertIn("torch_dtype", args_dict)
self.assertEqual(args_dict["torch_dtype"], dtype)
@require_accelerate_version_min_0_30
def test_eval_use_gather_object(self):
    """Smoke-test a Trainer configured with `eval_use_gather_object=True`.

    Runs `train()`, `evaluate()` and `predict()` end to end on the regression
    fixtures. The test passes as long as none of these calls raise; the
    returned values are intentionally discarded.
    """
    # Local import: keeps this test self-contained regardless of the module's
    # top-level imports.
    import tempfile

    train_dataset = RegressionDataset()
    eval_dataset = RegressionDataset()
    model = RegressionDictModel()
    # Use a throwaway output directory so the run does not leave a
    # `./regression` folder (checkpoints, logs) in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        args = TrainingArguments(tmp_dir, report_to="none", eval_use_gather_object=True)
        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
        trainer.train()
        _ = trainer.evaluate()
        _ = trainer.predict(eval_dataset)
@require_torch
@is_staging_test