fix BlenderbotSmallTokenizer (#9538)
* add model_input_names * fix test
This commit is contained in:
@@ -92,6 +92,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
max_model_input_sizes = {"facebook/blenderbot_small-90M": 512}
|
max_model_input_sizes = {"facebook/blenderbot_small-90M": 512}
|
||||||
|
model_input_names = ["attention_mask"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -288,8 +288,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
# model does not have "token_type_ids"
|
|
||||||
model_inputs.pop("token_type_ids")
|
|
||||||
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
|
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
|
||||||
generated_ids = self.model.generate(**model_inputs)[0]
|
generated_ids = self.model.generate(**model_inputs)[0]
|
||||||
reply = self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
reply = self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||||
@@ -302,8 +300,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
|
|||||||
def test_90_generation_from_short_input(self):
|
def test_90_generation_from_short_input(self):
|
||||||
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
|
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
# model does not have "token_type_ids"
|
|
||||||
model_inputs.pop("token_type_ids")
|
|
||||||
generated_utterances = self.model.generate(**model_inputs)
|
generated_utterances = self.model.generate(**model_inputs)
|
||||||
|
|
||||||
clean_txt = self.tokenizer.decode(
|
clean_txt = self.tokenizer.decode(
|
||||||
|
|||||||
Reference in New Issue
Block a user