Added DistilBERT to run_squad script

This commit is contained in:
LysandreJik
2019-09-11 10:21:18 +02:00
parent 88368c2a16
commit de8e14b6c0

View File

@@ -37,7 +37,8 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
XLMConfig, XLMForQuestionAnswering, XLMConfig, XLMForQuestionAnswering,
XLMTokenizer, XLNetConfig, XLMTokenizer, XLNetConfig,
XLNetForQuestionAnswering, XLNetForQuestionAnswering,
XLNetTokenizer) XLNetTokenizer,
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
from pytorch_transformers import AdamW, WarmupLinearSchedule from pytorch_transformers import AdamW, WarmupLinearSchedule
@@ -59,6 +60,7 @@ MODEL_CLASSES = {
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer), 'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), 'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
} }
def set_seed(args): def set_seed(args):