From 852e032ca6505f8ddd9881a7ed67ea0dd9fc7603 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Sun, 1 Mar 2020 01:56:50 +0000 Subject: [PATCH] include roberta in run_squad_w_distillation - cc @graviraja --- examples/distillation/run_squad_w_distillation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/distillation/run_squad_w_distillation.py b/examples/distillation/run_squad_w_distillation.py index 4900f19ead..3bbfaf482d 100644 --- a/examples/distillation/run_squad_w_distillation.py +++ b/examples/distillation/run_squad_w_distillation.py @@ -45,6 +45,9 @@ from transformers import ( XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer, + RobertaConfig, + RobertaForQuestionAnswering, + RobertaTokenizer, get_linear_schedule_with_warmup, squad_convert_examples_to_features, ) @@ -73,6 +76,7 @@ MODEL_CLASSES = { "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer), + "roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer), }