From d39da5a2abf8d7db206343647cee0c4608c302b9 Mon Sep 17 00:00:00 2001 From: Ethan Perez Date: Fri, 23 Oct 2020 07:34:06 -0700 Subject: [PATCH] Handling longformer model_type (#7990) Updating the run_squad training script to handle the "longformer" `model_type`. The longformer is trained in the same way as RoBERTa, so I've added the "longformer" `model_type` (that's the right Hugging Face name for the LongFormer model, right?) everywhere there was a "roberta" `model_type` reference. The longformer (like RoBERTa) doesn't use `token_type_ids` (as I understand from looking at the [longformer notebook](https://github.com/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb)), and the handling of `token_type_ids` is what this change updates. This fix might be related to [this issue](https://github.com/huggingface/transformers/issues/7249) with SQuAD training when using run_squad.py. --- examples/question-answering/run_squad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/question-answering/run_squad.py b/examples/question-answering/run_squad.py index 70fc04f9d3..59550347c2 100644 --- a/examples/question-answering/run_squad.py +++ b/examples/question-answering/run_squad.py @@ -187,7 +187,7 @@ def train(args, train_dataset, model, tokenizer): "end_positions": batch[4], } - if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]: + if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]: del inputs["token_type_ids"] if args.model_type in ["xlnet", "xlm"]: @@ -300,7 +300,7 @@ def evaluate(args, model, tokenizer, prefix=""): "token_type_ids": batch[2], } - if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]: + if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]: del inputs["token_type_ids"] feature_indices = batch[3]