[trainer] apex fixes and tests (#9180)
This commit is contained in:
@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||
|
||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
||||
from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
from .optimization import AdamW, get_linear_schedule_with_warmup
|
||||
@@ -104,13 +104,10 @@ if is_in_notebook():
|
||||
|
||||
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
|
||||
|
||||
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
||||
if version.parse(torch.__version__) < version.parse("1.6"):
|
||||
from .file_utils import is_apex_available
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
else:
|
||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||
_is_native_amp_available = True
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
@@ -309,6 +306,7 @@ class Trainer:
|
||||
backend = "amp" if _is_native_amp_available else "apex"
|
||||
else:
|
||||
backend = args.fp16_backend
|
||||
logger.info(f"Using {backend} fp16 backend")
|
||||
|
||||
if backend == "amp":
|
||||
self.use_amp = True
|
||||
|
||||
Reference in New Issue
Block a user