deepspeed + grad accum (#9622)
This commit is contained in:
@@ -112,6 +112,11 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
def test_finetune_trainer_deepspeed(self):
|
def test_finetune_trainer_deepspeed(self):
|
||||||
self.finetune_trainer_quick(deepspeed=True)
|
self.finetune_trainer_quick(deepspeed=True)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
@require_deepspeed
|
||||||
|
def test_finetune_trainer_deepspeed_grad_acum(self):
|
||||||
|
self.finetune_trainer_quick(deepspeed=True, extra_args_str="--gradient_accumulation_steps 2")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_finetune_trainer_slow(self):
|
def test_finetune_trainer_slow(self):
|
||||||
# There is a missing call to init_process_group somewhere
|
# There is a missing call to init_process_group somewhere
|
||||||
|
|||||||
@@ -931,7 +931,9 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Optimizer step
|
# Optimizer step
|
||||||
if is_torch_tpu_available():
|
if self.deepspeed:
|
||||||
|
self.deepspeed.step()
|
||||||
|
elif is_torch_tpu_available():
|
||||||
xm.optimizer_step(self.optimizer)
|
xm.optimizer_step(self.optimizer)
|
||||||
elif self.use_amp:
|
elif self.use_amp:
|
||||||
self.scaler.step(self.optimizer)
|
self.scaler.step(self.optimizer)
|
||||||
|
|||||||
Reference in New Issue
Block a user