From de8e14b6c0d1a9c573835972221c0bda883d7f2a Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 11 Sep 2019 10:21:18 +0200 Subject: [PATCH] Added DistilBERT to run_squad script --- examples/run_squad.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index cc4eda306c..affef90ca9 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -37,7 +37,8 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig, XLMConfig, XLMForQuestionAnswering, XLMTokenizer, XLNetConfig, XLNetForQuestionAnswering, - XLNetTokenizer) + XLNetTokenizer, + DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer) from pytorch_transformers import AdamW, WarmupLinearSchedule @@ -59,6 +60,7 @@ MODEL_CLASSES = { 'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer), 'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), + 'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer) } def set_seed(args):