[trainer] apex fixes and tests (#9180)

This commit is contained in:
Stas Bekman
2020-12-17 16:49:11 -08:00
committed by GitHub
parent 467e9158b4
commit f06d0fadc9
2 changed files with 22 additions and 8 deletions

View File

@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .modeling_utils import PreTrainedModel
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .optimization import AdamW, get_linear_schedule_with_warmup
@@ -104,13 +104,10 @@ if is_in_notebook():
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
from .file_utils import is_apex_available
if is_apex_available():
from apex import amp
if is_apex_available():
from apex import amp
else:
if version.parse(torch.__version__) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
@@ -309,6 +306,7 @@ class Trainer:
backend = "amp" if _is_native_amp_available else "apex"
else:
backend = args.fp16_backend
logger.info(f"Using {backend} fp16 backend")
if backend == "amp":
self.use_amp = True