diff --git a/examples/distillation/run_squad_w_distillation.py b/examples/distillation/run_squad_w_distillation.py index 4900f19ead..3bbfaf482d 100644 --- a/examples/distillation/run_squad_w_distillation.py +++ b/examples/distillation/run_squad_w_distillation.py @@ -45,6 +45,9 @@ from transformers import ( XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer, + RobertaConfig, + RobertaForQuestionAnswering, + RobertaTokenizer, get_linear_schedule_with_warmup, squad_convert_examples_to_features, ) @@ -73,6 +76,7 @@ MODEL_CLASSES = { "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer), + "roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer), }