[pipeline] Tokenizer should not add special tokens for text generation (#4686)
* allow not adding special tokens
* remove print
This commit is contained in:
committed by
GitHub
parent
f6d5046af1
commit
47a551d17b
@@ -454,14 +454,17 @@ class Pipeline(_ScikitCompat):
|
|||||||
"""
|
"""
|
||||||
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
||||||
|
|
||||||
def _parse_and_tokenize(self, *args, pad_to_max_length=True, **kwargs):
|
def _parse_and_tokenize(self, *args, pad_to_max_length=True, add_special_tokens=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Parse arguments and tokenize
|
Parse arguments and tokenize
|
||||||
"""
|
"""
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
inputs = self._args_parser(*args, **kwargs)
|
inputs = self._args_parser(*args, **kwargs)
|
||||||
inputs = self.tokenizer.batch_encode_plus(
|
inputs = self.tokenizer.batch_encode_plus(
|
||||||
inputs, add_special_tokens=True, return_tensors=self.framework, pad_to_max_length=pad_to_max_length,
|
inputs,
|
||||||
|
add_special_tokens=add_special_tokens,
|
||||||
|
return_tensors=self.framework,
|
||||||
|
pad_to_max_length=pad_to_max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
@@ -617,9 +620,11 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
# Manage correct placement of the tensors
|
# Manage correct placement of the tensors
|
||||||
with self.device_placement():
|
with self.device_placement():
|
||||||
if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
|
if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
|
||||||
inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text, pad_to_max_length=False)
|
inputs = self._parse_and_tokenize(
|
||||||
|
self.PADDING_TEXT + prompt_text, pad_to_max_length=False, add_special_tokens=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
inputs = self._parse_and_tokenize(prompt_text, pad_to_max_length=False)
|
inputs = self._parse_and_tokenize(prompt_text, pad_to_max_length=False, add_special_tokens=False)
|
||||||
|
|
||||||
# set input_ids to None to allow empty prompt
|
# set input_ids to None to allow empty prompt
|
||||||
if inputs["input_ids"].shape[-1] == 0:
|
if inputs["input_ids"].shape[-1] == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user