From 578d23e06114bbd63cf5e931e0fdef9b8b6ac8c4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Thu, 17 Oct 2019 14:02:27 +0200
Subject: [PATCH] add training pipeline (formatting temporary)

---
Review notes (this area is ignored by `git am`):

* Purpose: replace the `train()` stub with a fine-tuning loop (data loader,
  AdamW with linear warmup/decay, gradient accumulation and clipping) and
  add the matching command line options to `main()`.
* The line structure of this patch was restored from a whitespace-collapsed
  copy. Indentation and blank context lines were inferred from Python syntax
  and from the hunk line counts; confirm with `git apply --check` against
  blob 3e3cc34cb8 before applying.
* Fixes relative to the original submission:
    - `args.train_bach_size` typo (the logging code reads `train_batch_size`).
    - `DataLoader`, `AdamW` and `WarmupLinearSchedule` are now imported.
    - `t_total` is computed before it is handed to the scheduler and logger.
    - `gradient_accumulation_steps`, `local_rank` and `device` are read with
      defaults because no visible parser hunk defines them.
    - The batch is stacked into tensors before `.to(device)`; a Python list
      has no `.to` method.
    - The average loss no longer divides by zero when no step was taken.
    - `load_and_cache_examples` is called with both tokenizers, matching its
      `(args, tokenizer_src, tokenizer_tgt)` signature.
    - Help text typo "Weight deay"; unused local `logging_loss` dropped.
* NOTE(review): the batch unpacking assumes every element yielded by the
  data loader is a (story, summary) pair of equally sized tensors - TODO
  confirm against `TextDataset.__getitem__`, which this patch does not show.
* NOTE(review): `train()` does not use its `tokenizer` parameter; the encoder
  tokenizer is passed only because no `tokenizer` name is visible in `main()`.
* The post-image blob id on the `index` line predates these fixes.

 examples/run_seq2seq_finetuning.py | 160 +++++++++++++++++++++++++++--
 1 file changed, 150 insertions(+), 10 deletions(-)

diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py
index 3e3cc34cb8..ad6d126165 100644
--- a/examples/run_seq2seq_finetuning.py
+++ b/examples/run_seq2seq_finetuning.py
@@ -23,8 +23,9 @@
 import random
 import os
 import numpy as np
+from tqdm import tqdm, trange
 import torch
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset, RandomSampler
 
-from transformers import AutoTokenizer, Model2Model
+from transformers import AdamW, AutoTokenizer, Model2Model, WarmupLinearSchedule
 
@@ -90,10 +91,14 @@ class TextDataset(Dataset):
                     except IndexError:  # skip ill-formed stories
                         continue
 
-                story = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))
+                story = tokenizer_src.convert_tokens_to_ids(
+                    tokenizer_src.tokenize(story)
+                )
                 story_seq = _fit_to_block_size(story, block_size)
 
-                summary = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary))
+                summary = tokenizer_tgt.convert_tokens_to_ids(
+                    tokenizer_tgt.tokenize(summary)
+                )
                 summary_seq = _fit_to_block_size(summary, block_size)
 
                 self.examples.append((story_seq, summary_seq))
@@ -179,7 +184,106 @@ def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
 
 def train(args, train_dataset, model, tokenizer):
     """ Fine-tune the pretrained model on the corpus. """
-    raise NotImplementedError
+
+    # Prepare the data loading
+    args.train_batch_size = 1
+    train_sampler = RandomSampler(train_dataset)
+    train_dataloader = DataLoader(
+        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
+    )
+
+    # The argument parser in main() may not define these settings yet; fall back
+    # to single-process training on CPU without gradient accumulation.
+    gradient_accumulation_steps = getattr(args, "gradient_accumulation_steps", 1)
+    local_rank = getattr(args, "local_rank", -1)
+    device = getattr(args, "device", torch.device("cpu"))
+
+    # Total number of optimizer updates, required by the warmup/decay schedule
+    steps_per_epoch = max(len(train_dataloader) // gradient_accumulation_steps, 1)
+    if args.max_steps > 0:
+        t_total = args.max_steps
+        args.num_train_epochs = args.max_steps // steps_per_epoch + 1
+    else:
+        t_total = steps_per_epoch * args.num_train_epochs
+
+    # Prepare the optimizer and schedule (linear warmup and decay)
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [
+                p
+                for n, p in model.named_parameters()
+                if not any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [
+                p
+                for n, p in model.named_parameters()
+                if any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": 0.0,
+        },
+    ]
+    optimizer = AdamW(
+        optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon
+    )
+    scheduler = WarmupLinearSchedule(
+        optimizer, warmup_steps=args.warmup_steps, t_total=t_total
+    )
+
+    # Train
+    logger.info("***** Running training *****")
+    logger.info(" Num examples = %d", len(train_dataset))
+    logger.info(" Num Epochs = %d", args.num_train_epochs)
+    logger.info(
+        " Instantaneous batch size per GPU = %d", args.train_batch_size
+    )
+    logger.info(
+        " Total train batch size (w. parallel, distributed & accumulation) = %d",
+        args.train_batch_size
+        * gradient_accumulation_steps
+        * (torch.distributed.get_world_size() if local_rank != -1 else 1),
+    )
+    logger.info(" Gradient Accumulation steps = %d", gradient_accumulation_steps)
+    logger.info(" Total optimization steps = %d", t_total)
+
+    global_step = 0
+    tr_loss = 0.0
+    model.zero_grad()
+    train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)
+    set_seed(args)
+    for _ in train_iterator:
+        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
+        for step, batch in enumerate(epoch_iterator):
+            # NOTE(review): assumes each batch element is a (story, summary)
+            # pair of equally sized tensors - TODO confirm against TextDataset.
+            source = torch.stack([s for s, _ in batch]).to(device)
+            target = torch.stack([t for _, t in batch]).to(device)
+            model.train()
+            outputs = model(source, target)
+            loss = outputs[0]
+            loss.backward()
+
+            tr_loss += loss.item()
+            if (step + 1) % gradient_accumulation_steps == 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+                optimizer.step()
+                scheduler.step()
+                model.zero_grad()
+                global_step += 1
+
+            if args.max_steps > 0 and global_step > args.max_steps:
+                epoch_iterator.close()
+                break
+
+        if args.max_steps > 0 and global_step > args.max_steps:
+            train_iterator.close()
+            break
+
+    # Guard against a run in which no optimizer step was taken
+    return global_step, tr_loss / max(global_step, 1)
 
 
 def main():
@@ -202,6 +306,9 @@ def main():
     )
 
     # Optional parameters
+    parser.add_argument(
+        "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
+    )
     parser.add_argument(
         "--decoder_name_or_path",
         default="bert-base-cased",
@@ -226,11 +333,40 @@ def main():
         type=str,
         help="The encoder architecture to be fine-tuned.",
     )
+    parser.add_argument(
+        "--learning_rate",
+        default=5e-5,
+        type=float,
+        help="The initial learning rate for Adam.",
+    )
+    parser.add_argument(
+        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
+    )
+    parser.add_argument(
+        "--max_steps",
+        default=-1,
+        type=int,
+        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
+    )
+    parser.add_argument(
+        "--num_train_epochs",
+        default=1,
+        type=int,
+        help="Total number of training epochs to perform.",
+    )
     parser.add_argument("--seed", default=42, type=int)
+    parser.add_argument(
+        "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps."
+    )
+    parser.add_argument(
+        "--weight_decay", default=0.0, type=float, help="Weight decay if we apply some."
+    )
     args = parser.parse_args()
 
-    if args.encoder_type != 'bert' or args.decoder_type != 'bert':
-        raise ValueError("Only the BERT architecture is currently supported for seq2seq.")
+    if args.encoder_type != "bert" or args.decoder_type != "bert":
+        raise ValueError(
+            "Only the BERT architecture is currently supported for seq2seq."
+        )
 
     # Set up training device
     # device = torch.device("cpu")
@@ -241,14 +377,18 @@ def main():
     # Load pretrained model and tokenizer
     encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
     decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
-    model = Model2Model.from_pretrained(args.encoder_name_or_path, args.decoder_name_or_path)
+    model = Model2Model.from_pretrained(
+        args.encoder_name_or_path, args.decoder_name_or_path
+    )
     # model.to(device)
 
     logger.info("Training/evaluation parameters %s", args)
 
     # Training
-    source, target = load_and_cache_examples(args, tokenizer)
-    # global_step, tr_loss = train(args, train_dataset, model, tokenizer)
+    train_dataset = load_and_cache_examples(
+        args, encoder_tokenizer_class, decoder_tokenizer_class
+    )
+    global_step, tr_loss = train(args, train_dataset, model, encoder_tokenizer_class)
     # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
 