Added Camembert to available models
This commit is contained in:
committed by
Julien Chaumond
parent
ecf15ebf3b
commit
b0ee7c7df3
@@ -47,7 +47,8 @@ from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
|
|||||||
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
||||||
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
||||||
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
|
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
|
||||||
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
|
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer,
|
||||||
|
CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -58,7 +59,8 @@ MODEL_CLASSES = {
|
|||||||
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||||
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||||
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||||
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
|
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||||
|
'camembert': (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -432,7 +434,7 @@ def main():
|
|||||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
|
if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
|
||||||
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
||||||
"flag (masked language modeling).")
|
"flag (masked language modeling).")
|
||||||
if args.eval_data_file is None and args.do_eval:
|
if args.eval_data_file is None and args.do_eval:
|
||||||
|
|||||||
Reference in New Issue
Block a user