From 8dfc8c7221d0b5febbff456913e2b7a1390e78eb Mon Sep 17 00:00:00 2001 From: Ethan Perez Date: Sat, 5 Dec 2020 07:52:16 -0700 Subject: [PATCH] Don't pass in token_type_ids to BART for GLUE (#8929) Without this fix, training a `BARTForSequenceClassification` model with `run_pl_glue.py` gives `TypeError: forward() got an unexpected keyword argument 'token_type_ids'`, because BART does not have token_type_ids. I've solved this issue in the same way as it's solved for the "distilbert" model, and I can train BART models on SNLI without errors now. --- examples/text-classification/run_pl_glue.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text-classification/run_pl_glue.py b/examples/text-classification/run_pl_glue.py index 500a0bd627..abb06bf526 100644 --- a/examples/text-classification/run_pl_glue.py +++ b/examples/text-classification/run_pl_glue.py @@ -38,7 +38,7 @@ class GLUETransformer(BaseTransformer): def training_step(self, batch, batch_idx): inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} - if self.config.model_type != "distilbert": + if self.config.model_type not in ["distilbert", "bart"]: inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None outputs = self(**inputs) @@ -101,7 +101,7 @@ class GLUETransformer(BaseTransformer): def validation_step(self, batch, batch_idx): inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} - if self.config.model_type != "distilbert": + if self.config.model_type not in ["distilbert", "bart"]: inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None outputs = self(**inputs)