diff --git a/examples/run_ner.py b/examples/run_ner.py index 00eb039258..16fa89c3e7 100644 --- a/examples/run_ner.py +++ b/examples/run_ner.py @@ -35,15 +35,17 @@ from utils_ner import convert_examples_to_features, get_labels, read_examples_fr from transformers import AdamW, WarmupLinearSchedule from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer +from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer logger = logging.getLogger(__name__) ALL_MODELS = sum( - (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, )), + (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)), ()) MODEL_CLASSES = { "bert": (BertConfig, BertForTokenClassification, BertTokenizer), + "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer) }