tokenization_marian.py: use current_spm for decoding (#10357)
* Fix Marian decoding Tokenizer's decode and batch_decode now accepts a new argument (use_source_tokenizer) which indicates whether the source spm should be used to decode ids. This is useful for Marian models specificallly when decoding source input ids. * Adapt docstrings Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
committed by
GitHub
parent
8fd7eb34e2
commit
b880508440
@@ -159,7 +159,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
return self.encoder.get(token, self.encoder[self.unk_token])
|
||||
|
||||
def remove_language_code(self, text: str):
|
||||
"""Remove language codes like <<fr>> before sentencepiece"""
|
||||
"""Remove language codes like >>fr<< before sentencepiece"""
|
||||
match = self.language_code_re.match(text)
|
||||
code: list = [match.group(0)] if match else []
|
||||
return code, self.language_code_re.sub("", text)
|
||||
@@ -170,12 +170,62 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
return code + pieces
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
"""Converts an index (integer) in a token (str) using the encoder."""
|
||||
"""Converts an index (integer) in a token (str) using the decoder."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def batch_decode(self, sequences, **kwargs):
|
||||
"""
|
||||
Convert a list of lists of token ids into a list of strings by calling decode.
|
||||
|
||||
Args:
|
||||
sequences (:obj:`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
|
||||
List of tokenized input ids. Can be obtained using the ``__call__`` method.
|
||||
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to remove special tokens in the decoding.
|
||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to clean up the tokenization spaces.
|
||||
use_source_tokenizer (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
|
||||
problems).
|
||||
kwargs (additional keyword arguments, `optional`):
|
||||
Will be passed to the underlying model specific decode method.
|
||||
|
||||
Returns:
|
||||
:obj:`List[str]`: The list of decoded sentences.
|
||||
"""
|
||||
return super().batch_decode(sequences, **kwargs)
|
||||
|
||||
def decode(self, token_ids, **kwargs):
|
||||
"""
|
||||
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
|
||||
tokens and clean up tokenization spaces.
|
||||
|
||||
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
|
||||
|
||||
Args:
|
||||
token_ids (:obj:`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
||||
List of tokenized input ids. Can be obtained using the ``__call__`` method.
|
||||
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to remove special tokens in the decoding.
|
||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to clean up the tokenization spaces.
|
||||
use_source_tokenizer (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
|
||||
problems).
|
||||
kwargs (additional keyword arguments, `optional`):
|
||||
Will be passed to the underlying model specific decode method.
|
||||
|
||||
Returns:
|
||||
:obj:`str`: The decoded sentence.
|
||||
"""
|
||||
return super().decode(token_ids, **kwargs)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
"""Uses target language sentencepiece model"""
|
||||
return self.spm_target.DecodePieces(tokens)
|
||||
"""Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise """
|
||||
if self._decode_use_source_tokenizer:
|
||||
return self.spm_source.DecodePieces(tokens)
|
||||
else:
|
||||
return self.spm_target.DecodePieces(tokens)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
||||
"""Build model inputs from a sequence by appending eos_token_id."""
|
||||
|
||||
@@ -486,6 +486,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
token_ids: List[int],
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: bool = True,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the
|
||||
|
||||
Reference in New Issue
Block a user