rewamp optimization
This commit is contained in:
@@ -25,19 +25,21 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||||
TensorDataset)
|
TensorDataset)
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_transformers import WEIGHTS_NAME
|
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||||
from pytorch_transformers import (BertConfig, BertForSequenceClassification,
|
BertForSequenceClassification, BertTokenizer,
|
||||||
BertTokenizer, XLMConfig,
|
XLMConfig, XLMForSequenceClassification,
|
||||||
XLMForSequenceClassification, XLMTokenizer,
|
XLMTokenizer, XLNetConfig,
|
||||||
XLNetConfig, XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
XLNetTokenizer)
|
XLNetTokenizer)
|
||||||
from pytorch_transformers.optimization import BertAdam
|
|
||||||
|
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||||
|
|
||||||
from utils_glue import (compute_metrics, convert_examples_to_features,
|
from utils_glue import (compute_metrics, convert_examples_to_features,
|
||||||
output_modes, processors)
|
output_modes, processors)
|
||||||
|
|
||||||
@@ -56,24 +58,24 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
tb_writer = SummaryWriter()
|
tb_writer = SummaryWriter()
|
||||||
|
|
||||||
args.train_batch_size = args.per_gpu_train_batch_size * args.n_gpu
|
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
if args.max_steps > 0:
|
if args.max_steps > 0:
|
||||||
num_train_optimization_steps = args.max_steps
|
t_total = args.max_steps
|
||||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||||
else:
|
else:
|
||||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ['bias', 'LayerNorm.weight']
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||||
]
|
]
|
||||||
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate,
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||||
t_total=num_train_optimization_steps, warmup=args.warmup_proportion)
|
schedule = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
@@ -89,11 +91,11 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", num_train_optimization_steps)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
optimizer.zero_grad()
|
model.zero_grad()
|
||||||
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
|
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
|
||||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||||
model.train()
|
model.train()
|
||||||
@@ -103,7 +105,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||||
'labels': batch[3]}
|
'labels': batch[3]}
|
||||||
ouputs = model(**inputs)
|
ouputs = model(**inputs)
|
||||||
loss = ouputs[0]
|
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||||
|
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
@@ -113,22 +115,25 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
if args.fp16:
|
if args.fp16:
|
||||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||||
|
|
||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
|
scheduler.step() # Update learning rate schedule
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
model.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank == -1: # Only evaluate on single GPU otherwise metrics may not average well
|
if args.local_rank == -1: # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results = evaluate(args, model, tokenizer, prefix=global_step)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
@@ -140,6 +145,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||||
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
break
|
break
|
||||||
@@ -162,20 +168,21 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||||
os.makedirs(eval_output_dir)
|
os.makedirs(eval_output_dir)
|
||||||
|
|
||||||
|
args.eval_batch_size = args.per_gpu_eval_batch_size * args.n_gpu
|
||||||
# Note that DistributedSampler samples randomly
|
# Note that DistributedSampler samples randomly
|
||||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
|
|
||||||
# Eval!
|
# Eval!
|
||||||
logger.info("***** Running evaluation *****")
|
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||||
logger.info(" Num examples = %d", len(eval_dataset))
|
logger.info(" Num examples = %d", len(eval_dataset))
|
||||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||||
model.eval()
|
|
||||||
eval_loss = 0
|
eval_loss = 0
|
||||||
nb_eval_steps = 0
|
nb_eval_steps = 0
|
||||||
preds = None
|
preds = None
|
||||||
out_label_ids = None
|
out_label_ids = None
|
||||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||||
|
model.eval()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -186,7 +193,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
|
|
||||||
eval_loss += tmp_eval_loss.mean().item()
|
eval_loss += tmp_eval_loss.mean().item()
|
||||||
nb_eval_steps += 1
|
nb_eval_steps += 1
|
||||||
if preds is None:
|
if preds is None:
|
||||||
preds = logits.detach().cpu().numpy()
|
preds = logits.detach().cpu().numpy()
|
||||||
@@ -213,7 +220,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False, overwrite_cache=False):
|
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||||
processor = processors[task]()
|
processor = processors[task]()
|
||||||
output_mode = output_modes[task]
|
output_mode = output_modes[task]
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
@@ -285,20 +292,22 @@ def main():
|
|||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||||
help="Batch size per GPU for training.")
|
help="Batch size per GPU for training.")
|
||||||
parser.add_argument("--eval_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||||
help="Total batch size for eval.")
|
help="Batch size per GPU for evaluation.")
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||||
help="The initial learning rate for Adam.")
|
help="The initial learning rate for Adam.")
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||||
help="Weight deay if we apply some.")
|
help="Weight deay if we apply some.")
|
||||||
|
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||||
|
help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||||
help="Total number of training epochs to perform.")
|
help="Total number of training epochs to perform.")
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
parser.add_argument("--max_steps", default=-1, type=int,
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||||
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
|
help="Linear warmup over warmup_steps.")
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument('--logging_steps', type=int, default=50,
|
||||||
help="Log every X updates steps.")
|
help="Log every X updates steps.")
|
||||||
@@ -409,6 +418,7 @@ def main():
|
|||||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||||
@@ -427,15 +437,18 @@ def main():
|
|||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
checkpoints = [args.output_dir + './' + WEIGHTS_NAME]
|
checkpoints = [args.output_dir + './' + WEIGHTS_NAME]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))
|
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||||
|
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
results = {}
|
results = {}
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = int(checkpoints.split('-')[-1])
|
global_step = int(checkpoint.split('-')[-1])
|
||||||
model = model_class.from_pretrained(checkpoints)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||||
result = dict(n + '_{}'.format())
|
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||||
|
results.update(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from .modeling_xlm import (XLMConfig, XLMModel,
|
|||||||
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
|
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
|
||||||
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
|
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
|
||||||
|
|
||||||
from .optimization import BertAdam
|
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
|
||||||
from .optimization_openai import OpenAIAdam
|
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||||
|
|
||||||
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
|
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class PretrainedConfig(object):
|
|||||||
for key in to_remove:
|
for key in to_remove:
|
||||||
kwargs.pop(key, None)
|
kwargs.pop(key, None)
|
||||||
|
|
||||||
logger.info("Model config {}".format(config))
|
logger.info("Model config %s", config)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -211,10 +211,6 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
layer_norm_eps=1e-12,
|
layer_norm_eps=1e-12,
|
||||||
|
|
||||||
dropout=0.1,
|
dropout=0.1,
|
||||||
dropatt=0.1,
|
|
||||||
init="normal",
|
|
||||||
init_range=0.1,
|
|
||||||
init_std=0.02,
|
|
||||||
mem_len=None,
|
mem_len=None,
|
||||||
reuse_len=None,
|
reuse_len=None,
|
||||||
bi_data=False,
|
bi_data=False,
|
||||||
@@ -258,11 +254,6 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
|
|
||||||
dropout: float, dropout rate.
|
dropout: float, dropout rate.
|
||||||
dropatt: float, dropout rate on attention probabilities.
|
dropatt: float, dropout rate on attention probabilities.
|
||||||
init: str, the initialization scheme, either "normal" or "uniform".
|
|
||||||
init_range: float, initialize the parameters with a uniform distribution
|
|
||||||
in [-init_range, init_range]. Only effective when init="uniform".
|
|
||||||
init_std: float, initialize the parameters with a normal distribution
|
|
||||||
with mean 0 and stddev init_std. Only effective when init="normal".
|
|
||||||
mem_len: int, the number of tokens to cache.
|
mem_len: int, the number of tokens to cache.
|
||||||
reuse_len: int, the number of tokens in the currect batch to be cached
|
reuse_len: int, the number of tokens in the currect batch to be cached
|
||||||
and reused in the future.
|
and reused in the future.
|
||||||
@@ -297,11 +288,7 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
|
||||||
self.init = init
|
|
||||||
self.init_range = init_range
|
|
||||||
self.init_std = init_std
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.dropatt = dropatt
|
|
||||||
self.mem_len = mem_len
|
self.mem_len = mem_len
|
||||||
self.reuse_len = reuse_len
|
self.reuse_len = reuse_len
|
||||||
self.bi_data = bi_data
|
self.bi_data = bi_data
|
||||||
|
|||||||
@@ -14,174 +14,92 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch optimization for BERT model."""
|
"""PyTorch optimization for BERT model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.optimizer import required
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
from torch.nn.utils import clip_grad_norm_
|
|
||||||
import logging
|
|
||||||
import abc
|
|
||||||
import sys
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ConstantLRSchedule(LambdaLR):
|
||||||
|
def __init__(self, optimizer, last_epoch=-1):
|
||||||
|
super(ConstantLR, self).__init__(optimizer, lambda x: x, last_epoch=last_epoch)
|
||||||
|
|
||||||
if sys.version_info >= (3, 4):
|
class WarmupCosineSchedule(LambdaLR):
|
||||||
ABC = abc.ABC
|
|
||||||
else:
|
|
||||||
ABC = abc.ABCMeta('ABC', (), {})
|
|
||||||
|
|
||||||
|
|
||||||
class _LRSchedule(ABC):
|
|
||||||
""" Parent of all LRSchedules here. """
|
|
||||||
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
|
|
||||||
def __init__(self, warmup=0.002, t_total=-1, **kw):
|
|
||||||
"""
|
|
||||||
:param warmup: what fraction of t_total steps will be used for linear warmup
|
|
||||||
:param t_total: how many training steps (updates) are planned
|
|
||||||
:param kw:
|
|
||||||
"""
|
|
||||||
super(_LRSchedule, self).__init__(**kw)
|
|
||||||
if t_total < 0:
|
|
||||||
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
|
||||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
|
||||||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
|
||||||
warmup = max(warmup, 0.)
|
|
||||||
self.warmup, self.t_total = float(warmup), float(t_total)
|
|
||||||
self.warned_for_t_total_at_progress = -1
|
|
||||||
|
|
||||||
def get_lr(self, step, nowarn=False):
|
|
||||||
"""
|
|
||||||
:param step: which of t_total steps we're on
|
|
||||||
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
|
|
||||||
:return: learning rate multiplier for current update
|
|
||||||
"""
|
|
||||||
if self.t_total < 0:
|
|
||||||
return 1.
|
|
||||||
progress = float(step) / self.t_total
|
|
||||||
ret = self.get_lr_(progress)
|
|
||||||
# warning for exceeding t_total (only active with warmup_linear
|
|
||||||
if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
|
|
||||||
logger.warning(
|
|
||||||
"Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
|
|
||||||
.format(ret, self.__class__.__name__))
|
|
||||||
self.warned_for_t_total_at_progress = progress
|
|
||||||
# end warning
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_lr_(self, progress):
|
|
||||||
"""
|
|
||||||
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
|
|
||||||
:return: learning rate multiplier for current update
|
|
||||||
"""
|
|
||||||
return 1.
|
|
||||||
|
|
||||||
|
|
||||||
class ConstantLR(_LRSchedule):
|
|
||||||
def get_lr_(self, progress):
|
|
||||||
return 1.
|
|
||||||
|
|
||||||
|
|
||||||
class WarmupCosineSchedule(_LRSchedule):
|
|
||||||
"""
|
"""
|
||||||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
Linearly increases learning rate from 0 to 1 over `warmup` training steps.
|
||||||
Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
|
Decreases learning rate from 1. to 0. over remaining `t_total - warmup` steps following a cosine curve.
|
||||||
If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
|
If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
|
||||||
|
:param warmup: see LRSchedule
|
||||||
|
:param t_total: see LRSchedule
|
||||||
|
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
|
||||||
|
:param kw:
|
||||||
"""
|
"""
|
||||||
warn_t_total = True
|
warn_t_total = True
|
||||||
def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
|
def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
|
||||||
"""
|
|
||||||
:param warmup: see LRSchedule
|
|
||||||
:param t_total: see LRSchedule
|
|
||||||
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
|
|
||||||
:param kw:
|
|
||||||
"""
|
|
||||||
super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
|
|
||||||
self.cycles = cycles
|
|
||||||
|
|
||||||
def get_lr_(self, progress):
|
def lr_lambda(step):
|
||||||
if progress < self.warmup:
|
if step < warmup_steps:
|
||||||
return progress / self.warmup
|
return step / max(1, warmup_steps)
|
||||||
else:
|
else:
|
||||||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
|
||||||
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
|
return 0.5 * (1. + math.cos(math.pi * cycles * 2 * progress))
|
||||||
|
|
||||||
|
super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||||
|
|
||||||
class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
|
class WarmupCosineWithHardRestartsSchedule(LambdaLR):
|
||||||
"""
|
"""
|
||||||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
||||||
If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
|
If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
|
||||||
learning rate (with hard restarts).
|
learning rate (with hard restarts).
|
||||||
"""
|
"""
|
||||||
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
|
||||||
super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
|
|
||||||
assert(cycles >= 1.)
|
|
||||||
|
|
||||||
def get_lr_(self, progress):
|
def lr_lambda(step):
|
||||||
if progress < self.warmup:
|
if step < warmup_steps:
|
||||||
return progress / self.warmup
|
return step / max(1, warmup_steps)
|
||||||
else:
|
else:
|
||||||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
|
||||||
ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
|
ret = 0.5 * (1. + math.cos(math.pi * ((cycles * progress) % 1)))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||||
|
|
||||||
|
|
||||||
class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
|
class WarmupConstantSchedule(LambdaLR):
|
||||||
"""
|
|
||||||
All training progress is divided in `cycles` (default=1.) parts of equal length.
|
|
||||||
Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
|
|
||||||
followed by a learning rate decreasing from 1. to 0. following a cosine curve.
|
|
||||||
"""
|
|
||||||
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
|
||||||
assert(warmup * cycles < 1.)
|
|
||||||
warmup = warmup * cycles if warmup >= 0 else warmup
|
|
||||||
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
|
|
||||||
|
|
||||||
def get_lr_(self, progress):
|
|
||||||
progress = progress * self.cycles % 1.
|
|
||||||
if progress < self.warmup:
|
|
||||||
return progress / self.warmup
|
|
||||||
else:
|
|
||||||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
|
||||||
ret = 0.5 * (1. + math.cos(math.pi * progress))
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
class WarmupConstantSchedule(_LRSchedule):
|
|
||||||
"""
|
"""
|
||||||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
||||||
Keeps learning rate equal to 1. after warmup.
|
Keeps learning rate equal to 1. after warmup.
|
||||||
"""
|
"""
|
||||||
def get_lr_(self, progress):
|
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
|
||||||
if progress < self.warmup:
|
|
||||||
return progress / self.warmup
|
def lr_lambda(step):
|
||||||
return 1.
|
if step < warmup_steps:
|
||||||
|
return step / warmup_steps
|
||||||
|
return 1.
|
||||||
|
|
||||||
|
super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||||
|
|
||||||
|
|
||||||
class WarmupLinearSchedule(_LRSchedule):
|
class WarmupLinearSchedule(LambdaLR):
|
||||||
"""
|
"""
|
||||||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
||||||
Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
|
Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
|
||||||
"""
|
"""
|
||||||
warn_t_total = True
|
def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
|
||||||
def get_lr_(self, progress):
|
|
||||||
if progress < self.warmup:
|
def lr_lambda(step):
|
||||||
return progress / self.warmup
|
if step < warmup_steps:
|
||||||
return max((progress - 1.) / (self.warmup - 1.), 0.)
|
return step / max(1, warmup_steps)
|
||||||
|
return (t_total - step) / max(1, t_total - warmup_steps)
|
||||||
|
|
||||||
|
super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||||
|
|
||||||
|
|
||||||
SCHEDULES = {
|
class AdamW(Optimizer):
|
||||||
None: ConstantLR,
|
""" Implements Adam algorithm with weight decay fix.
|
||||||
"none": ConstantLR,
|
|
||||||
"warmup_cosine": WarmupCosineSchedule,
|
|
||||||
"warmup_constant": WarmupConstantSchedule,
|
|
||||||
"warmup_linear": WarmupLinearSchedule
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class BertAdam(Optimizer):
|
|
||||||
"""Implements BERT version of Adam algorithm with weight decay fix.
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
lr: learning rate
|
lr: learning rate
|
||||||
@@ -197,46 +115,21 @@ class BertAdam(Optimizer):
|
|||||||
e: Adams epsilon. Default: 1e-6
|
e: Adams epsilon. Default: 1e-6
|
||||||
weight_decay: Weight decay. Default: 0.01
|
weight_decay: Weight decay. Default: 0.01
|
||||||
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
|
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
|
||||||
|
correct_bias: can be set to False to avoid correcting bias in Adam (e.g. like in Bert repository)
|
||||||
"""
|
"""
|
||||||
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
|
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, correct_bias=True):
|
||||||
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
|
if lr < 0.0:
|
||||||
if lr is not required and lr < 0.0:
|
|
||||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||||
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
|
||||||
if not 0.0 <= b1 < 1.0:
|
if not 0.0 <= betas[1] < 1.0:
|
||||||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1] ))
|
||||||
if not 0.0 <= b2 < 1.0:
|
if not 0.0 <= eps:
|
||||||
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
|
|
||||||
if not e >= 0.0:
|
|
||||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
||||||
# initialize schedule object
|
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
||||||
if not isinstance(schedule, _LRSchedule):
|
correct_bias=correct_bias)
|
||||||
schedule_type = SCHEDULES[schedule]
|
|
||||||
schedule = schedule_type(warmup=warmup, t_total=t_total)
|
|
||||||
else:
|
|
||||||
if warmup != -1 or t_total != -1:
|
|
||||||
logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
|
|
||||||
"Please specify custom warmup and t_total in _LRSchedule object.")
|
|
||||||
defaults = dict(lr=lr, schedule=schedule,
|
|
||||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
|
||||||
max_grad_norm=max_grad_norm)
|
|
||||||
super(BertAdam, self).__init__(params, defaults)
|
super(BertAdam, self).__init__(params, defaults)
|
||||||
|
|
||||||
def get_lr(self):
|
|
||||||
lr = []
|
|
||||||
for group in self.param_groups:
|
|
||||||
for p in group['params']:
|
|
||||||
if p.grad is None:
|
|
||||||
continue
|
|
||||||
state = self.state[p]
|
|
||||||
if len(state) == 0:
|
|
||||||
return [0]
|
|
||||||
lr_scheduled = group['lr']
|
|
||||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
|
||||||
lr.append(lr_scheduled)
|
|
||||||
return lr
|
|
||||||
|
|
||||||
def step(self, closure=None):
|
def step(self, closure=None):
|
||||||
"""Performs a single optimization step.
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
@@ -262,22 +155,28 @@ class BertAdam(Optimizer):
|
|||||||
if len(state) == 0:
|
if len(state) == 0:
|
||||||
state['step'] = 0
|
state['step'] = 0
|
||||||
# Exponential moving average of gradient values
|
# Exponential moving average of gradient values
|
||||||
state['next_m'] = torch.zeros_like(p.data)
|
state['exp_avg'] = torch.zeros_like(p.data)
|
||||||
# Exponential moving average of squared gradient values
|
# Exponential moving average of squared gradient values
|
||||||
state['next_v'] = torch.zeros_like(p.data)
|
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||||
|
|
||||||
next_m, next_v = state['next_m'], state['next_v']
|
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||||
beta1, beta2 = group['b1'], group['b2']
|
beta1, beta2 = group['betas']
|
||||||
|
|
||||||
# Add grad clipping
|
state['step'] += 1
|
||||||
if group['max_grad_norm'] > 0:
|
|
||||||
clip_grad_norm_(p, group['max_grad_norm'])
|
|
||||||
|
|
||||||
# Decay the first and second moment running average coefficient
|
# Decay the first and second moment running average coefficient
|
||||||
# In-place operations to update the averages at the same time
|
# In-place operations to update the averages at the same time
|
||||||
next_m.mul_(beta1).add_(1 - beta1, grad)
|
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||||
next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||||
update = next_m / (next_v.sqrt() + group['e'])
|
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||||
|
|
||||||
|
step_size = group['lr']
|
||||||
|
if group['correct_bias']: # No bias correction for Bert
|
||||||
|
bias_correction1 = 1 - beta1 ** state['step']
|
||||||
|
bias_correction2 = 1 - beta2 ** state['step']
|
||||||
|
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
|
||||||
|
|
||||||
|
p.data.addcdiv_(-step_size, exp_avg, denom)
|
||||||
|
|
||||||
# Just adding the square of the weights to the loss function is *not*
|
# Just adding the square of the weights to the loss function is *not*
|
||||||
# the correct way of using L2 regularization/weight decay with Adam,
|
# the correct way of using L2 regularization/weight decay with Adam,
|
||||||
@@ -286,20 +185,8 @@ class BertAdam(Optimizer):
|
|||||||
# Instead we want to decay the weights in a manner that doesn't interact
|
# Instead we want to decay the weights in a manner that doesn't interact
|
||||||
# with the m/v parameters. This is equivalent to adding the square
|
# with the m/v parameters. This is equivalent to adding the square
|
||||||
# of the weights to the loss with plain (non-momentum) SGD.
|
# of the weights to the loss with plain (non-momentum) SGD.
|
||||||
if group['weight_decay'] > 0.0:
|
# Add weight decay at the end (fixed version)
|
||||||
update += group['weight_decay'] * p.data
|
if group['weight_decay'] > 0:
|
||||||
|
p.data.add_(-group['lr'] * group['weight_decay'], p.data)
|
||||||
lr_scheduled = group['lr']
|
|
||||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
|
||||||
|
|
||||||
update_with_lr = lr_scheduled * update
|
|
||||||
p.data.add_(-update_with_lr)
|
|
||||||
|
|
||||||
state['step'] += 1
|
|
||||||
|
|
||||||
# step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
|
|
||||||
# No bias correction
|
|
||||||
# bias_correction1 = 1 - beta1 ** state['step']
|
|
||||||
# bias_correction2 = 1 - beta2 ** state['step']
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -1,127 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""PyTorch optimization for OpenAI GPT model."""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.optim.optimizer import required
|
|
||||||
from torch.nn.utils import clip_grad_norm_
|
|
||||||
import logging
|
|
||||||
from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
|
|
||||||
WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIAdam(Optimizer):
|
|
||||||
"""Implements Open AI version of Adam algorithm with weight decay fix.
|
|
||||||
"""
|
|
||||||
def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
|
|
||||||
b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
|
|
||||||
vector_l2=False, max_grad_norm=-1, **kwargs):
|
|
||||||
if lr is not required and lr < 0.0:
|
|
||||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
|
||||||
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
|
|
||||||
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
|
||||||
if not 0.0 <= b1 < 1.0:
|
|
||||||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
|
||||||
if not 0.0 <= b2 < 1.0:
|
|
||||||
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
|
|
||||||
if not e >= 0.0:
|
|
||||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
|
||||||
# initialize schedule object
|
|
||||||
if not isinstance(schedule, _LRSchedule):
|
|
||||||
schedule_type = SCHEDULES[schedule]
|
|
||||||
schedule = schedule_type(warmup=warmup, t_total=t_total)
|
|
||||||
else:
|
|
||||||
if warmup != -1 or t_total != -1:
|
|
||||||
logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
|
|
||||||
"Please specify custom warmup and t_total in _LRSchedule object.")
|
|
||||||
defaults = dict(lr=lr, schedule=schedule,
|
|
||||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
|
|
||||||
max_grad_norm=max_grad_norm)
|
|
||||||
super(OpenAIAdam, self).__init__(params, defaults)
|
|
||||||
|
|
||||||
def get_lr(self):
|
|
||||||
lr = []
|
|
||||||
for group in self.param_groups:
|
|
||||||
for p in group['params']:
|
|
||||||
state = self.state[p]
|
|
||||||
if len(state) == 0:
|
|
||||||
return [0]
|
|
||||||
lr_scheduled = group['lr']
|
|
||||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
|
||||||
lr.append(lr_scheduled)
|
|
||||||
return lr
|
|
||||||
|
|
||||||
def step(self, closure=None):
|
|
||||||
"""Performs a single optimization step.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
closure (callable, optional): A closure that reevaluates the model
|
|
||||||
and returns the loss.
|
|
||||||
"""
|
|
||||||
loss = None
|
|
||||||
if closure is not None:
|
|
||||||
loss = closure()
|
|
||||||
|
|
||||||
for group in self.param_groups:
|
|
||||||
for p in group['params']:
|
|
||||||
if p.grad is None:
|
|
||||||
continue
|
|
||||||
grad = p.grad.data
|
|
||||||
if grad.is_sparse:
|
|
||||||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
|
||||||
|
|
||||||
state = self.state[p]
|
|
||||||
|
|
||||||
# State initialization
|
|
||||||
if len(state) == 0:
|
|
||||||
state['step'] = 0
|
|
||||||
# Exponential moving average of gradient values
|
|
||||||
state['exp_avg'] = torch.zeros_like(p.data)
|
|
||||||
# Exponential moving average of squared gradient values
|
|
||||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
|
||||||
|
|
||||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
|
||||||
beta1, beta2 = group['b1'], group['b2']
|
|
||||||
|
|
||||||
state['step'] += 1
|
|
||||||
|
|
||||||
# Add grad clipping
|
|
||||||
if group['max_grad_norm'] > 0:
|
|
||||||
clip_grad_norm_(p, group['max_grad_norm'])
|
|
||||||
|
|
||||||
# Decay the first and second moment running average coefficient
|
|
||||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
|
||||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
|
||||||
denom = exp_avg_sq.sqrt().add_(group['e'])
|
|
||||||
|
|
||||||
bias_correction1 = 1 - beta1 ** state['step']
|
|
||||||
bias_correction2 = 1 - beta2 ** state['step']
|
|
||||||
|
|
||||||
lr_scheduled = group['lr']
|
|
||||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
|
||||||
|
|
||||||
step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
|
|
||||||
|
|
||||||
p.data.addcdiv_(-step_size, exp_avg, denom)
|
|
||||||
|
|
||||||
# Add weight decay at the end (fixed version)
|
|
||||||
if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
|
|
||||||
p.data.add_(-lr_scheduled * group['weight_decay'], p.data)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
@@ -20,10 +20,9 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_transformers import BertAdam
|
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
|
||||||
from pytorch_transformers import OpenAIAdam
|
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||||
from pytorch_transformers.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
|
|
||||||
WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@@ -34,12 +33,12 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
for a, b in zip(list1, list2):
|
for a, b in zip(list1, list2):
|
||||||
self.assertAlmostEqual(a, b, delta=tol)
|
self.assertAlmostEqual(a, b, delta=tol)
|
||||||
|
|
||||||
def test_adam(self):
|
def test_adam_w(self):
|
||||||
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
|
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
|
||||||
target = torch.tensor([0.4, 0.2, -0.5])
|
target = torch.tensor([0.4, 0.2, -0.5])
|
||||||
criterion = torch.nn.MSELoss()
|
criterion = torch.nn.MSELoss()
|
||||||
# No warmup, constant schedule, no gradient clipping
|
# No warmup, constant schedule, no gradient clipping
|
||||||
optimizer = BertAdam(params=[w], lr=2e-1,
|
optimizer = AdamW(params=[w], lr=2e-1,
|
||||||
weight_decay=0.0,
|
weight_decay=0.0,
|
||||||
max_grad_norm=-1)
|
max_grad_norm=-1)
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
@@ -52,23 +51,13 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class ScheduleInitTest(unittest.TestCase):
|
class ScheduleInitTest(unittest.TestCase):
|
||||||
def test_bert_sched_init(self):
|
def test_sched_init(self):
|
||||||
m = torch.nn.Linear(50, 50)
|
m = torch.nn.Linear(50, 50)
|
||||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
optim = AdamW(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
|
||||||
# shouldn't fail
|
|
||||||
|
|
||||||
def test_openai_sched_init(self):
|
|
||||||
m = torch.nn.Linear(50, 50)
|
|
||||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
|
||||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
|
||||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
|
||||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
|
||||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
|
||||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||||
# shouldn't fail
|
# shouldn't fail
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user