From 88368c2a16d26bc2d00dc28f79196c81373d3a71 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Fri, 6 Sep 2019 18:05:56 -0400 Subject: [PATCH] Added DistilBERT to `run_lm_finetuning` --- examples/run_lm_finetuning.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 4d14fe7ebb..974a84c34e 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -39,7 +39,8 @@ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule, BertConfig, BertForMaskedLM, BertTokenizer, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, - RobertaConfig, RobertaForMaskedLM, RobertaTokenizer) + RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, + DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) logger = logging.getLogger(__name__) @@ -49,7 +50,8 @@ MODEL_CLASSES = { 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer), 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 'bert': (BertConfig, BertForMaskedLM, BertTokenizer), - 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer) + 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer), + 'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) } @@ -380,7 +382,7 @@ def main(): parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") args = parser.parse_args() - if args.model_type in ["bert", "roberta"] and not args.mlm: + if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm: raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm " "flag (masked language modeling).") if args.eval_data_file is None and args.do_eval: