diff --git a/examples/run_squad.py b/examples/run_squad.py index a8ac1d1b05..2df29014ef 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -44,7 +44,9 @@ from transformers import (WEIGHTS_NAME, BertConfig, XLNetForQuestionAnswering, XLNetTokenizer, DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer, - AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer) + AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer, + XLMConfig, XLMForQuestionAnswering, XLMTokenizer, + ) from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features @@ -58,7 +60,8 @@ MODEL_CLASSES = { 'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), 'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer), - 'albert': (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer) + 'albert': (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer), + 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer) } def set_seed(args):