change to apex for better fp16 and multi-gpu support

This commit is contained in:
Deyu Fu
2018-12-05 15:07:40 -08:00
parent a3a3180c86
commit c8ea286048
6 changed files with 142 additions and 169 deletions

View File

@@ -1,5 +1,6 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,10 +31,14 @@ import shutil
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")
from .file_utils import cached_path
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
@@ -180,7 +185,7 @@ class BertEmbeddings(nn.Module):
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config)
self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
@@ -255,7 +260,7 @@ class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config)
self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -294,7 +299,7 @@ class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config)
self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -322,7 +327,7 @@ class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
@@ -356,7 +361,7 @@ class BertPredictionHeadTransform(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = BertLayerNorm(config)
self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@@ -438,6 +443,9 @@ class PreTrainedBertModel(nn.Module):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, FusedLayerNorm):
module.bias.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.beta.data.normal_(mean=0.0, std=self.config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -449,7 +457,7 @@ class PreTrainedBertModel(nn.Module):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
@@ -505,6 +513,20 @@ class PreTrainedBertModel(nn.Module):
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma','weight')
if 'beta' in key:
new_key = key.replace('beta','bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key]=state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []