seq2seq/run_eval.py can take decoder_start_token_id (#5949)

This commit is contained in:
Sam Shleifer
2020-07-21 16:58:45 -04:00
committed by GitHub
parent 5b193b39b0
commit 9dab39feea
3 changed files with 35 additions and 3 deletions

View File

@@ -2255,8 +2255,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return encoded_inputs
def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
return [self.decode(seq, **kwargs) for seq in sequences]
def batch_decode(
self, sequences: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
) -> List[str]:
"""
Convert a list of lists of token ids into a list of strings by calling decode.
Args:
token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
skip_special_tokens: if set to True, will replace special tokens.
clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
"""
return [
self.decode(
seq, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces
)
for seq in sequences
]
def decode(
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True