Merge pull request #116 from FDecaYed/deyuf/fp16_with_apex
Change to use apex for better fp16 and multi-gpu support
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -33,7 +34,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .file_utils import cached_path
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -152,22 +153,24 @@ class BertConfig(object):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
|
||||
except ImportError:
|
||||
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
|
||||
class BertLayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-12):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
"""
|
||||
super(BertLayerNorm, self).__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
class BertLayerNorm(nn.Module):
|
||||
def __init__(self, config, variance_epsilon=1e-12):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
"""
|
||||
super(BertLayerNorm, self).__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(config.hidden_size))
|
||||
self.beta = nn.Parameter(torch.zeros(config.hidden_size))
|
||||
self.variance_epsilon = variance_epsilon
|
||||
|
||||
def forward(self, x):
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.gamma * x + self.beta
|
||||
|
||||
def forward(self, x):
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.weight * x + self.bias
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
@@ -180,7 +183,7 @@ class BertEmbeddings(nn.Module):
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = BertLayerNorm(config)
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None):
|
||||
@@ -255,7 +258,7 @@ class BertSelfOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertSelfOutput, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.LayerNorm = BertLayerNorm(config)
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
@@ -294,7 +297,7 @@ class BertOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertOutput, self).__init__()
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = BertLayerNorm(config)
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
@@ -322,7 +325,7 @@ class BertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertEncoder, self).__init__()
|
||||
layer = BertLayer(config)
|
||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
|
||||
all_encoder_layers = []
|
||||
@@ -356,7 +359,7 @@ class BertPredictionHeadTransform(nn.Module):
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act] \
|
||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
||||
self.LayerNorm = BertLayerNorm(config)
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
@@ -439,8 +442,8 @@ class PreTrainedBertModel(nn.Module):
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, BertLayerNorm):
|
||||
module.beta.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.gamma.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.bias.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
@@ -449,7 +452,7 @@ class PreTrainedBertModel(nn.Module):
|
||||
"""
|
||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
|
||||
|
||||
Params:
|
||||
pretrained_model_name: either:
|
||||
- a str with the name of a pre-trained model to load selected in the list of:
|
||||
@@ -505,6 +508,20 @@ class PreTrainedBertModel(nn.Module):
|
||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
state_dict = torch.load(weights_path)
|
||||
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if 'gamma' in key:
|
||||
new_key = key.replace('gamma','weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta','bias')
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key]=state_dict.pop(old_key)
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
|
||||
@@ -53,11 +53,11 @@ class BertAdam(Optimizer):
|
||||
b1: Adams b1. Default: 0.9
|
||||
b2: Adams b2. Default: 0.999
|
||||
e: Adams epsilon. Default: 1e-6
|
||||
weight_decay_rate: Weight decay. Default: 0.01
|
||||
weight_decay: Weight decay. Default: 0.01
|
||||
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
|
||||
"""
|
||||
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
|
||||
b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
|
||||
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
|
||||
max_grad_norm=1.0):
|
||||
if lr is not required and lr < 0.0:
|
||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||
@@ -72,7 +72,7 @@ class BertAdam(Optimizer):
|
||||
if not e >= 0.0:
|
||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
||||
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
|
||||
b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
|
||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
||||
max_grad_norm=max_grad_norm)
|
||||
super(BertAdam, self).__init__(params, defaults)
|
||||
|
||||
@@ -140,8 +140,8 @@ class BertAdam(Optimizer):
|
||||
# Instead we want to decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
if group['weight_decay_rate'] > 0.0:
|
||||
update += group['weight_decay_rate'] * p.data
|
||||
if group['weight_decay'] > 0.0:
|
||||
update += group['weight_decay'] * p.data
|
||||
|
||||
if group['t_total'] != -1:
|
||||
schedule_fct = SCHEDULES[group['schedule']]
|
||||
|
||||
Reference in New Issue
Block a user