From 71b47505175111dd391a5b9de9514fbe50558bf0 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 16 Dec 2019 16:37:27 +0100 Subject: [PATCH] examples: add support for XLM-RoBERTa to run_ner script --- examples/run_ner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/run_ner.py b/examples/run_ner.py index 1ab1236d94..6426a6d1db 100644 --- a/examples/run_ner.py +++ b/examples/run_ner.py @@ -38,11 +38,13 @@ from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, B from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer +from transformers import XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer logger = logging.getLogger(__name__) ALL_MODELS = sum( - (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), + (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig, + CamembertConfig, XLMRobertaConfig)), ()) MODEL_CLASSES = { @@ -50,6 +52,7 @@ MODEL_CLASSES = { "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer), "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer), "camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer), + "xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer), }