[trainer] apex fixes and tests (#9180)
This commit is contained in:
@@ -18,7 +18,7 @@ import unittest
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from transformers import BertTokenizer, EncoderDecoderModel
|
from transformers import BertTokenizer, EncoderDecoderModel
|
||||||
from transformers.file_utils import is_datasets_available
|
from transformers.file_utils import is_apex_available, is_datasets_available
|
||||||
from transformers.integrations import is_fairscale_available
|
from transformers.integrations import is_fairscale_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
@@ -51,6 +51,17 @@ def require_fairscale(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
# a candidate for testing_utils
|
||||||
|
def require_apex(test_case):
    """
    Decorator marking a test that requires apex.

    When ``is_apex_available()`` reports that apex is present, the test is returned untouched; otherwise it is
    wrapped with ``unittest.skip`` so the runner reports it as skipped.
    """
    # Guard clause: nothing to do when apex can be used.
    if is_apex_available():
        return test_case

    skip_without_apex = unittest.skip("test requires apex")
    return skip_without_apex(test_case)
|
||||||
|
|
||||||
|
|
||||||
class TestFinetuneTrainer(TestCasePlus):
|
class TestFinetuneTrainer(TestCasePlus):
|
||||||
def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
|
def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
|
||||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
|
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
|
||||||
@@ -72,6 +83,7 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
def test_finetune_trainer_ddp(self):
|
def test_finetune_trainer_ddp(self):
|
||||||
self.finetune_trainer_quick(distributed=True)
|
self.finetune_trainer_quick(distributed=True)
|
||||||
|
|
||||||
|
# it's crucial to test --sharded_ddp w/ and w/o --fp16
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@require_fairscale
|
@require_fairscale
|
||||||
def test_finetune_trainer_ddp_sharded_ddp(self):
|
def test_finetune_trainer_ddp_sharded_ddp(self):
|
||||||
@@ -82,6 +94,10 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
|
def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
|
||||||
self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
|
self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
|
||||||
|
|
||||||
|
@require_apex
def test_finetune_trainer_apex(self):
    """Run the quick fine-tuning helper with fp16 enabled and the apex fp16 backend selected."""
    # Extra CLI arguments forwarded verbatim to the quick fine-tuning run.
    apex_fp16_flags = "--fp16 --fp16_backend=apex"
    self.finetune_trainer_quick(extra_args_str=apex_fp16_flags)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_finetune_trainer_slow(self):
|
def test_finetune_trainer_slow(self):
|
||||||
# There is a missing call to __init__process_group somewhere
|
# There is a missing call to __init__process_group somewhere
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||||
|
|
||||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||||
from .optimization import AdamW, get_linear_schedule_with_warmup
|
from .optimization import AdamW, get_linear_schedule_with_warmup
|
||||||
@@ -104,13 +104,10 @@ if is_in_notebook():
|
|||||||
|
|
||||||
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
|
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
|
||||||
|
|
||||||
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
if is_apex_available():
|
||||||
if version.parse(torch.__version__) < version.parse("1.6"):
|
|
||||||
from .file_utils import is_apex_available
|
|
||||||
|
|
||||||
if is_apex_available():
|
|
||||||
from apex import amp
|
from apex import amp
|
||||||
else:
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||||
_is_native_amp_available = True
|
_is_native_amp_available = True
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
@@ -309,6 +306,7 @@ class Trainer:
|
|||||||
backend = "amp" if _is_native_amp_available else "apex"
|
backend = "amp" if _is_native_amp_available else "apex"
|
||||||
else:
|
else:
|
||||||
backend = args.fp16_backend
|
backend = args.fp16_backend
|
||||||
|
logger.info(f"Using {backend} fp16 backend")
|
||||||
|
|
||||||
if backend == "amp":
|
if backend == "amp":
|
||||||
self.use_amp = True
|
self.use_amp = True
|
||||||
|
|||||||
Reference in New Issue
Block a user