From b247b0d880fe10e8e1a873d0b710f95f246af8ea Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Wed, 28 Aug 2019 02:12:47 +0000 Subject: [PATCH] add `train.py` for distillation --- examples/distillation/train.py | 230 +++++++++++++++++++++++++++++++++ 1 file changed, 230 insertions(+) create mode 100644 examples/distillation/train.py diff --git a/examples/distillation/train.py b/examples/distillation/train.py new file mode 100644 index 0000000000..824eeac046 --- /dev/null +++ b/examples/distillation/train.py @@ -0,0 +1,230 @@ +import os +import argparse +import pickle +import json +import shutil +import numpy as np +import torch + +from pytorch_transformers import BertTokenizer, BertForMaskedLM +from pytorch_transformers import DilBertForMaskedLM, DilBertConfig + +from distiller import Distiller +from utils import git_log, logger, init_gpu_params, set_seed +from dataset import Dataset + + +def main(): + parser = argparse.ArgumentParser(description="Training") + + parser.add_argument("--dump_path", type=str, required=True, + help="The output directory (log, checkpoints, parameters, etc.)") + parser.add_argument("--data_file", type=str, required=True, + help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.") + parser.add_argument("--token_counts", type=str, required=True, + help="The token counts in the data_file for MLM.") + parser.add_argument("--force", action='store_true', + help="Overwrite dump_path if it already exists.") + + parser.add_argument("--vocab_size", default=30522, type=int, + help="The vocabulary size.") + parser.add_argument("--max_position_embeddings", default=512, type=int, + help="Maximum sequence length we can model (including [CLS] and [SEP]).") + parser.add_argument("--sinusoidal_pos_embds", action='store_false', + help="If true, the position embeddings are simply fixed with sinusoidal embeddings.") + parser.add_argument("--n_layers", default=6, type=int, + help="Number of Transformer blocks.") + parser.add_argument("--n_heads", default=12, type=int, + help="Number of heads in the self-attention module.") + parser.add_argument("--dim", default=768, type=int, + help="Dimension through the network. Must be divisible by n_heads") + parser.add_argument("--hidden_dim", default=3072, type=int, + help="Intermediate dimension in the FFN.") + parser.add_argument("--dropout", default=0.1, type=float, + help="Dropout.") + parser.add_argument("--attention_dropout", default=0.1, type=float, + help="Dropout in self-attention.") + parser.add_argument("--activation", default='gelu', type=str, + help="Activation to use in self-attention") + parser.add_argument("--tie_weights_", action='store_false', + help="If true, we tie the embeddings matrix with the projection over the vocabulary matrix. Default is true.") + + parser.add_argument("--from_pretrained_weights", default=None, type=str, + help="Load student initialization checkpoint.") + parser.add_argument("--from_pretrained_config", default=None, type=str, + help="Load student initialization architecture config.") + parser.add_argument("--bert_model", default='bert-base-uncased', type=str, + help="The teacher BERT model.") + + parser.add_argument("--temperature", default=2., type=float, + help="Temperature for the softmax temperature.") + parser.add_argument("--alpha_ce", default=0.5, type=float, + help="Linear weight for the distillation loss. Must be >=0.") + parser.add_argument("--alpha_mlm", default=0.5, type=float, + help="Linear weight for the MLM loss. Must be >=0.") + parser.add_argument("--alpha_mse", default=0.0, type=float, + help="Linear weight of the MSE loss. Must be >=0.") + parser.add_argument("--mlm_mask_prop", default=0.15, type=float, + help="Proportion of tokens for which we need to make a prediction.") + parser.add_argument("--word_mask", default=0.8, type=float, + help="Proportion of tokens to mask out.") + parser.add_argument("--word_keep", default=0.1, type=float, + help="Proportion of tokens to keep.") + parser.add_argument("--word_rand", default=0.1, type=float, + help="Proportion of tokens to randomly replace.") + parser.add_argument("--mlm_smoothing", default=0.7, type=float, + help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).") + parser.add_argument("--restrict_ce_to_mask", action='store_true', + help="If true, compute the distilation loss only the [MLM] prediction distribution.") + + parser.add_argument("--n_epoch", type=int, default=3, + help="Number of pass on the whole dataset.") + parser.add_argument("--batch_size", type=int, default=5, + help="Batch size (for each process).") + parser.add_argument("--tokens_per_batch", type=int, default=-1, + help="If specified, modify the batches so that they have approximately this number of tokens.") + parser.add_argument("--shuffle", action='store_false', + help="If true, shuffle the sequence order. Default is true.") + parser.add_argument("--group_by_size", action='store_false', + help="If true, group sequences that have similar length into the same batch. Default is true.") + + parser.add_argument("--gradient_accumulation_steps", type=int, default=50, + help="Gradient accumulation for larger training batches.") + parser.add_argument("--warmup_prop", default=0.05, type=float, + help="Linear warmup proportion.") + parser.add_argument("--weight_decay", default=0.0, type=float, + help="Weight deay if we apply some.") + parser.add_argument("--learning_rate", default=5e-4, type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--adam_epsilon", default=1e-6, type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument("--max_grad_norm", default=5.0, type=float, + help="Max gradient norm.") + parser.add_argument("--initializer_range", default=0.02, type=float, + help="Random initialization range.") + + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") + parser.add_argument('--fp16_opt_level', type=str, default='O1', + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." + "See details at https://nvidia.github.io/apex/amp.html") + parser.add_argument("--n_gpu", type=int, default=1, + help="Number of GPUs in the node.") + parser.add_argument("--local_rank", type=int, default=-1, + help="Distributed training - Local rank") + parser.add_argument("--seed", type=int, default=56, + help="Random seed") + + parser.add_argument("--log_interval", type=int, default=500, + help="Tensorboard logging interval.") + parser.add_argument("--checkpoint_interval", type=int, default=4000, + help="Checkpoint interval.") + args = parser.parse_args() + + + ## ARGS ## + init_gpu_params(args) + set_seed(args) + if args.is_master: + if os.path.exists(args.dump_path): + if not args.force: + raise ValueError(f'Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it' + 'Use `--force` if you want to overwrite it') + else: + shutil.rmtree(args.dump_path) + + if not os.path.exists(args.dump_path): + os.makedirs(args.dump_path) + logger.info(f'Experiment will be dumped and logged in {args.dump_path}') + + + ### SAVE PARAMS ### + logger.info(f'Param: {args}') + with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f: + json.dump(vars(args), f, indent=4) + git_log(args.dump_path) + + + ### TOKENIZER ### + bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model) + special_tok_ids = {} + for tok_name, tok_symbol in bert_tokenizer.special_tokens_map.items(): + idx = bert_tokenizer.all_special_tokens.index(tok_symbol) + special_tok_ids[tok_name] = bert_tokenizer.all_special_ids[idx] + logger.info(f'Special tokens {special_tok_ids}') + args.special_tok_ids = special_tok_ids + + + ## DATA LOADER ## + logger.info(f'Loading data from {args.data_file}') + with open(args.data_file, 'rb') as fp: + data = pickle.load(fp) + + + assert os.path.isfile(args.token_counts) + logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)') + with open(args.token_counts, 'rb') as fp: + counts = pickle.load(fp) + assert len(counts) == args.vocab_size + token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing + for idx in special_tok_ids.values(): + token_probs[idx] = 0. # do not predict special tokens + token_probs = torch.from_numpy(token_probs) + + + train_dataloader = Dataset(params=args, data=data) + logger.info(f'Data loader created.') + + + ## STUDENT ## + assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \ + (args.from_pretrained_weights is not None and args.from_pretrained_config is not None) + if args.from_pretrained_weights is not None: + assert os.path.isfile(os.path.join(args.from_pretrained, 'config.json')) + assert os.path.isfile(os.path.join(args.from_pretrained, 'config.json')) + logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}') + logger.info(f'Loading pretrained config from {args.from_pretrained_config}') + stu_architecture_config = DilBertConfig.from_json_file(args.from_pretrained_config) + student = DilBertForMaskedLM.from_pretrained(args.from_pretrained_weights, + config=stu_architecture_config) + else: + + stu_architecture_config = DilBertConfig(args) + student = DilBertForMaskedLM(stu_architecture_config) + # student = Model(vocab_size=args.vocab_size, + # max_position_embeddings=args.max_position_embeddings, + # sinusoidal_pos_embds=args.sinusoidal_pos_embds, + # n_layers=args.n_layers, + # n_heads=args.n_heads, + # dim=args.dim, + # dropout=args.dropout, + # attention_dropout=args.attention_dropout, + # activation=args.activation, + # initializer_range=args.initializer_range, + # tie_weights=args.tie_weights) + + + if args.n_gpu > 0: + student.to(f'cuda:{args.local_rank}') + logger.info(f'Student loaded.') + + + ## TEACHER ## + teacher = BertForMaskedLM.from_pretrained(args.bert_model) + if args.n_gpu > 0: + teacher.to(f'cuda:{args.local_rank}') + logger.info(f'Teacher loaded from {args.bert_model}.') + + ## DISTILLER ## + torch.cuda.empty_cache() + distiller = Distiller(params=args, + dataloader=train_dataloader, + token_probs=token_probs, + student=student, + teacher=teacher) + distiller.train() + logger.info("Let's go get some drinks.") + + +if __name__ == "__main__": + main()