diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index c33aa94a32..4acea00c55 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -47,7 +47,8 @@ from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, - DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) + DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer, + CamembertConfig, CamembertForMaskedLM, CamembertTokenizer) logger = logging.getLogger(__name__) @@ -58,7 +59,8 @@ MODEL_CLASSES = { 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 'bert': (BertConfig, BertForMaskedLM, BertTokenizer), 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer), - 'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) + 'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer), + 'camembert': (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer) } @@ -432,7 +434,7 @@ def main(): parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") args = parser.parse_args() - if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm: + if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm: raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm " "flag (masked language modeling).") if args.eval_data_file is None and args.do_eval: