[trainer] apex fixes and tests (#9180)

This commit is contained in:
Stas Bekman
2020-12-17 16:49:11 -08:00
committed by GitHub
parent 467e9158b4
commit f06d0fadc9
2 changed files with 22 additions and 8 deletions

View File

@@ -18,7 +18,7 @@ import unittest
from unittest.mock import patch
from transformers import BertTokenizer, EncoderDecoderModel
from transformers.file_utils import is_datasets_available
from transformers.file_utils import is_apex_available, is_datasets_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
TestCasePlus,
@@ -51,6 +51,17 @@ def require_fairscale(test_case):
return test_case
# a candidate for testing_utils
def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
if not is_apex_available():
return unittest.skip("test requires apex")(test_case)
else:
return test_case
class TestFinetuneTrainer(TestCasePlus):
def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
@@ -72,6 +83,7 @@ class TestFinetuneTrainer(TestCasePlus):
def test_finetune_trainer_ddp(self):
self.finetune_trainer_quick(distributed=True)
# it's crucial to test --sharded_ddp w/ and w/o --fp16
@require_torch_multi_gpu
@require_fairscale
def test_finetune_trainer_ddp_sharded_ddp(self):
@@ -82,6 +94,10 @@ class TestFinetuneTrainer(TestCasePlus):
def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
@require_apex
def test_finetune_trainer_apex(self):
self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex")
@slow
def test_finetune_trainer_slow(self):
# There is a missing call to __init__process_group somewhere