diff --git a/examples/distillation/train.py b/examples/distillation/train.py index f0255d08fe..311f0580ff 100644 --- a/examples/distillation/train.py +++ b/examples/distillation/train.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Training DistilBERT. +Training the distilled model. +Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2. """ import os import argparse @@ -23,68 +24,96 @@ import shutil import numpy as np import torch -from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM -from transformers import DistilBertForMaskedLM, DistilBertConfig +from transformers import BertConfig, BertForMaskedLM, BertTokenizer +from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer +from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer +from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer from distiller import Distiller from utils import git_log, logger, init_gpu_params, set_seed -from dataset import Dataset +from lm_seqs_dataset import LmSeqsDataset +MODEL_CLASSES = { + 'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer), + 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer), + 'bert': (BertConfig, BertForMaskedLM, BertTokenizer), + 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer) +} + +def sanity_checks(args): + """ + A bunch of args sanity checks to perform before even starting... + """ + assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.) + assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.) 
+ if args.mlm: + assert os.path.isfile(args.token_counts) + assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert']) + else: + assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2']) + + assert args.teacher_type == args.student_type or (args.student_type=='distilbert' and args.teacher_type=='bert') + assert os.path.isfile(args.student_config) + if args.student_pretrained_weights is not None: + assert os.path.isfile(args.student_pretrained_weights) + + if args.freeze_token_type_embds: assert args.student_type in ['roberta'] + + assert args.alpha_ce >= 0. + assert args.alpha_mlm >= 0. + assert args.alpha_clm >= 0. + assert args.alpha_mse >= 0. + assert args.alpha_cos >= 0. + assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0. + +def freeze_pos_embeddings(student, args): + if args.student_type == 'roberta': + student.roberta.embeddings.position_embeddings.weight.requires_grad = False + elif args.student_type == 'gpt2': + student.transformer.wpe.weight.requires_grad = False + +def freeze_token_type_embeddings(student, args): + if args.student_type == 'roberta': + student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False + def main(): parser = argparse.ArgumentParser(description="Training") + parser.add_argument("--force", action='store_true', + help="Overwrite dump_path if it already exists.") parser.add_argument("--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)") parser.add_argument("--data_file", type=str, required=True, help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.") - parser.add_argument("--token_counts", type=str, required=True, - help="The token counts in the data_file for MLM.") - parser.add_argument("--force", action='store_true', - help="Overwrite dump_path if it already exists.") - parser.add_argument("--vocab_size", default=30522, type=int, - 
help="The vocabulary size.") - parser.add_argument("--max_position_embeddings", default=512, type=int, - help="Maximum sequence length we can model (including [CLS] and [SEP]).") - parser.add_argument("--sinusoidal_pos_embds", action='store_false', - help="If true, the position embeddings are simply fixed with sinusoidal embeddings.") - parser.add_argument("--n_layers", default=6, type=int, - help="Number of Transformer blocks.") - parser.add_argument("--n_heads", default=12, type=int, - help="Number of heads in the self-attention module.") - parser.add_argument("--dim", default=768, type=int, - help="Dimension through the network. Must be divisible by n_heads") - parser.add_argument("--hidden_dim", default=3072, type=int, - help="Intermediate dimension in the FFN.") - parser.add_argument("--dropout", default=0.1, type=float, - help="Dropout.") - parser.add_argument("--attention_dropout", default=0.1, type=float, - help="Dropout in self-attention.") - parser.add_argument("--activation", default='gelu', type=str, - help="Activation to use in self-attention") - parser.add_argument("--tie_weights_", action='store_false', - help="If true, we tie the embeddings matrix with the projection over the vocabulary matrix. 
Default is true.") - - parser.add_argument("--from_pretrained_weights", default=None, type=str, + parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True, + help="The student type (DistilBERT, RoBERTa).") + parser.add_argument("--student_config", type=str, required=True, + help="Path to the student configuration.") + parser.add_argument("--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint.") - parser.add_argument("--from_pretrained_config", default=None, type=str, - help="Load student initialization architecture config.") - parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"], + + parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa).") - parser.add_argument("--teacher_name", default="bert-base-uncased", type=str, + parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.") parser.add_argument("--temperature", default=2., type=float, help="Temperature for the softmax temperature.") parser.add_argument("--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0.") - parser.add_argument("--alpha_mlm", default=0.5, type=float, - help="Linear weight for the MLM loss. Must be >=0.") + parser.add_argument("--alpha_mlm", default=0.0, type=float, + help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.") + parser.add_argument("--alpha_clm", default=0.5, type=float, + help="Linear weight for the CLM loss. Must be >=0.") parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.") parser.add_argument("--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0.") + + parser.add_argument("--mlm", action="store_true", + help="The LM step: MLM or CLM. 
If `mlm` is True, the MLM is used over CLM.") parser.add_argument("--mlm_mask_prop", default=0.15, type=float, help="Proportion of tokens for which we need to make a prediction.") parser.add_argument("--word_mask", default=0.8, type=float, @@ -95,17 +124,20 @@ def main(): help="Proportion of tokens to randomly replace.") parser.add_argument("--mlm_smoothing", default=0.7, type=float, help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).") + parser.add_argument("--token_counts", type=str, + help="The token counts in the data_file for MLM.") + parser.add_argument("--restrict_ce_to_mask", action='store_true', help="If true, compute the distilation loss only the [MLM] prediction distribution.") + parser.add_argument("--freeze_pos_embs", action="store_true", + help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.") + parser.add_argument("--freeze_token_type_embds", action="store_true", + help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.") parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.") parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).") - parser.add_argument("--tokens_per_batch", type=int, default=-1, - help="If specified, modify the batches so that they have approximately this number of tokens.") - parser.add_argument("--shuffle", action='store_false', - help="If true, shuffle the sequence order. Default is true.") parser.add_argument("--group_by_size", action='store_false', help="If true, group sequences that have similar length into the same batch. 
Default is true.") @@ -141,6 +173,7 @@ def main(): parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.") args = parser.parse_args() + sanity_checks(args) ## ARGS ## @@ -164,21 +197,19 @@ def main(): with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f: json.dump(vars(args), f, indent=4) git_log(args.dump_path) - assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \ - (args.from_pretrained_weights is not None and args.from_pretrained_config is not None) + student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type] + teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type] ### TOKENIZER ### - if args.teacher_type == 'bert': - tokenizer = BertTokenizer.from_pretrained(args.teacher_name) - elif args.teacher_type == 'roberta': - tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name) + tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name) special_tok_ids = {} for tok_name, tok_symbol in tokenizer.special_tokens_map.items(): idx = tokenizer.all_special_tokens.index(tok_symbol) special_tok_ids[tok_name] = tokenizer.all_special_ids[idx] logger.info(f'Special tokens {special_tok_ids}') args.special_tok_ids = special_tok_ids + args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name] ## DATA LOADER ## @@ -187,35 +218,34 @@ def main(): data = pickle.load(fp) - assert os.path.isfile(args.token_counts) - logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)') - with open(args.token_counts, 'rb') as fp: - counts = pickle.load(fp) - assert len(counts) == args.vocab_size - token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing - for idx in special_tok_ids.values(): - token_probs[idx] = 0. 
# do not predict special tokens - token_probs = torch.from_numpy(token_probs) + if args.mlm: + logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)') + with open(args.token_counts, 'rb') as fp: + counts = pickle.load(fp) + + token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing + for idx in special_tok_ids.values(): + token_probs[idx] = 0. # do not predict special tokens + token_probs = torch.from_numpy(token_probs) + else: + token_probs = None - train_dataloader = Dataset(params=args, data=data) + train_lm_seq_dataset = LmSeqsDataset(params=args, data=data) logger.info(f'Data loader created.') ## STUDENT ## - if args.from_pretrained_weights is not None: - assert os.path.isfile(args.from_pretrained_weights) - assert os.path.isfile(args.from_pretrained_config) - logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}') - logger.info(f'Loading pretrained config from {args.from_pretrained_config}') - stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config) - stu_architecture_config.output_hidden_states = True - student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights, - config=stu_architecture_config) + logger.info(f'Loading student config from {args.student_config}') + stu_architecture_config = student_config_class.from_pretrained(args.student_config) + stu_architecture_config.output_hidden_states = True + + if args.student_pretrained_weights is not None: + logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}') + student = student_model_class.from_pretrained(args.student_pretrained_weights, + config=stu_architecture_config) else: - args.vocab_size_or_config_json_file = args.vocab_size - stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True) - student = DistilBertForMaskedLM(stu_architecture_config) + student = student_model_class(stu_architecture_config) if args.n_gpu > 0: @@ -224,18 +254,31 @@ def main(): 
## TEACHER ## - if args.teacher_type == 'bert': - teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True) - elif args.teacher_type == 'roberta': - teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True) + teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True) if args.n_gpu > 0: teacher.to(f'cuda:{args.local_rank}') logger.info(f'Teacher loaded from {args.teacher_name}.') + + ## FREEZING ## + if args.freeze_pos_embs: + freeze_pos_embeddings(student, args) + if args.freeze_token_type_embds: + freeze_token_type_embeddings(student, args) + + + ## SANITY CHECKS ## + assert student.config.vocab_size == teacher.config.vocab_size + assert student.config.hidden_size == teacher.config.hidden_size + assert student.config.max_position_embeddings == teacher.config.max_position_embeddings + if args.mlm: + assert token_probs.size(0) == stu_architecture_config.vocab_size + + ## DISTILLER ## torch.cuda.empty_cache() distiller = Distiller(params=args, - dataloader=train_dataloader, + dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher)