examples: add DistilBert support for NER fine-tuning
This commit is contained in:
@@ -36,16 +36,18 @@ from utils_ner import convert_examples_to_features, get_labels, read_examples_fr
|
|||||||
from transformers import AdamW, WarmupLinearSchedule
|
from transformers import AdamW, WarmupLinearSchedule
|
||||||
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
|
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
|
||||||
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
|
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
|
||||||
|
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
ALL_MODELS = sum(
|
||||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)),
|
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
|
||||||
())
|
())
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer)
|
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
|
||||||
|
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -121,9 +123,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {"input_ids": batch[0],
|
inputs = {"input_ids": batch[0],
|
||||||
"attention_mask": batch[1],
|
"attention_mask": batch[1],
|
||||||
"token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
|
|
||||||
# XLM and RoBERTa don"t use segment_ids
|
|
||||||
"labels": batch[3]}
|
"labels": batch[3]}
|
||||||
|
if args.model_type != "distilbert":
|
||||||
|
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||||
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||||
|
|
||||||
@@ -206,9 +209,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {"input_ids": batch[0],
|
inputs = {"input_ids": batch[0],
|
||||||
"attention_mask": batch[1],
|
"attention_mask": batch[1],
|
||||||
"token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
|
|
||||||
# XLM and RoBERTa don"t use segment_ids
|
|
||||||
"labels": batch[3]}
|
"labels": batch[3]}
|
||||||
|
if args.model_type != "distilbert":
|
||||||
|
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
|
|
||||||
@@ -520,3 +523,4 @@ def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user