From 69ed36063a732c37fdf72c605c65ebb5b2e85f44 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 13 Jan 2021 10:53:43 +0530 Subject: [PATCH] fix BlenderbotSmallTokenizer (#9538) * add model_input_names * fix test --- .../models/blenderbot_small/tokenization_blenderbot_small.py | 1 + tests/test_modeling_blenderbot_small.py | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py index c9e5f36d88..6f39da95d5 100644 --- a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py @@ -92,6 +92,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer): }, } max_model_input_sizes = {"facebook/blenderbot_small-90M": 512} + model_input_names = ["attention_mask"] def __init__( self, diff --git a/tests/test_modeling_blenderbot_small.py b/tests/test_modeling_blenderbot_small.py index f1a107a3dc..c3538c7be1 100644 --- a/tests/test_modeling_blenderbot_small.py +++ b/tests/test_modeling_blenderbot_small.py @@ -288,8 +288,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase): model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device) - # model does not have "token_type_ids" - model_inputs.pop("token_type_ids") assert isinstance(self.tokenizer, BlenderbotSmallTokenizer) generated_ids = self.model.generate(**model_inputs)[0] reply = self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) @@ -302,8 +300,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase): def test_90_generation_from_short_input(self): model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device) - # model does not have "token_type_ids" - model_inputs.pop("token_type_ids") generated_utterances = self.model.generate(**model_inputs) clean_txt = self.tokenizer.decode(