seq2seq/run_eval.py can take decoder_start_token_id (#5949)
This commit is contained in:
@@ -2255,8 +2255,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
    """
    Decode each id sequence in ``sequences`` by calling ``self.decode`` on it.

    Args:
        sequences: list of token-id lists, one per output entry.
        **kwargs: forwarded unchanged to every ``self.decode`` call.

    Returns:
        The results of ``self.decode``, in the same order as ``sequences``.
    """
    decoded = []
    for ids in sequences:
        decoded.append(self.decode(ids, **kwargs))
    return decoded
|
||||
def batch_decode(
    self, sequences: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
) -> List[str]:
    """
    Convert a list of lists of token ids into a list of strings by calling decode.

    Args:
        sequences: list of tokenized input id lists, one per output string. Each can be obtained
            using the `encode` or `encode_plus` methods.
        skip_special_tokens: forwarded to `decode`; if set to True, special tokens are skipped
            during decoding.
        clean_up_tokenization_spaces: forwarded to `decode`; if set to True, will clean up the
            tokenization spaces.

    Returns:
        One decoded entry per element of ``sequences``, in the same order.
    """
    # Both options are passed by keyword so every sequence is decoded with identical settings.
    return [
        self.decode(
            seq, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces
        )
        for seq in sequences
    ]
|
||||
|
||||
def decode(
|
||||
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
|
||||
|
||||
Reference in New Issue
Block a user