diff --git a/examples/run_ner.py b/examples/run_ner.py
index b35d8298fe..1c5774df97 100644
--- a/examples/run_ner.py
+++ b/examples/run_ner.py
@@ -36,16 +36,18 @@ from utils_ner import convert_examples_to_features, get_labels, read_examples_fr
 from transformers import AdamW, WarmupLinearSchedule
 from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
 from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
+from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
 
 logger = logging.getLogger(__name__)
 
 ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)),
+    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
     ())
 
 MODEL_CLASSES = {
     "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
-    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer)
+    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
+    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
 }
 
 
@@ -121,9 +123,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {"input_ids": batch[0],
                       "attention_mask": batch[1],
-                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
-                      # XLM and RoBERTa don"t use segment_ids
                       "labels": batch[3]}
+            if args.model_type != "distilbert":
+                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don't use segment_ids
+
             outputs = model(**inputs)
             loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
 
@@ -206,9 +209,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
         with torch.no_grad():
             inputs = {"input_ids": batch[0],
                       "attention_mask": batch[1],
-                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
-                      # XLM and RoBERTa don"t use segment_ids
                       "labels": batch[3]}
+            if args.model_type != "distilbert":
+                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don't use segment_ids
             outputs = model(**inputs)
             tmp_eval_loss, logits = outputs[:2]
 
@@ -520,3 +523,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+