add gather_use_object arguments (#31514)
* add gather_use_object arguments * fix name and pass the CI test for Seq2SeqTrainer * make style * make it to functools * fix typo * add accelerate version: * adding warning * Update src/transformers/trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * make style * Update src/transformers/training_args.py * check function move to initial part * add test for eval_use_gather_object --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
82a1fc7256
commit
cb298978ad
@@ -132,6 +132,7 @@ if is_torch_available():
|
||||
|
||||
# for version specific tests in TrainerIntegrationTest
|
||||
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
|
||||
require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
|
||||
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
|
||||
if is_accelerate_available():
|
||||
from accelerate import Accelerator
|
||||
@@ -3565,6 +3566,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertIn("torch_dtype", args_dict)
|
||||
self.assertEqual(args_dict["torch_dtype"], dtype)
|
||||
|
||||
@require_accelerate_version_min_0_30
|
||||
def test_eval_use_gather_object(self):
|
||||
train_dataset = RegressionDataset()
|
||||
eval_dataset = RegressionDataset()
|
||||
model = RegressionDictModel()
|
||||
args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True)
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||
trainer.train()
|
||||
_ = trainer.evaluate()
|
||||
_ = trainer.predict(eval_dataset)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user