From b880508440f43f80e35a78ccd2a32f3bde91cb23 Mon Sep 17 00:00:00 2001 From: Mehrad Moradshahi Date: Mon, 8 Mar 2021 05:14:31 -0800 Subject: [PATCH] tokenization_marian.py: use current_spm for decoding (#10357) * Fix Marian decoding Tokenizer's decode and batch_decode now accept a new argument (use_source_tokenizer) which indicates whether the source spm should be used to decode ids. This is useful for Marian models specifically when decoding source input ids. * Adapt docstrings Co-authored-by: Sylvain Gugger --- .../models/marian/tokenization_marian.py | 58 +++++++++++++++++-- .../models/wav2vec2/tokenization_wav2vec2.py | 1 + src/transformers/tokenization_utils.py | 5 ++ src/transformers/tokenization_utils_fast.py | 4 ++ 4 files changed, 64 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py index a12f8451a9..dadc9e2c64 100644 --- a/src/transformers/models/marian/tokenization_marian.py +++ b/src/transformers/models/marian/tokenization_marian.py @@ -159,7 +159,7 @@ class MarianTokenizer(PreTrainedTokenizer): return self.encoder.get(token, self.encoder[self.unk_token]) def remove_language_code(self, text: str): - """Remove language codes like <<fr>> before sentencepiece""" + """Remove language codes like >>fr<< before sentencepiece""" match = self.language_code_re.match(text) code: list = [match.group(0)] if match else [] return code, self.language_code_re.sub("", text) @@ -170,12 +170,62 @@ class MarianTokenizer(PreTrainedTokenizer): return code + pieces def _convert_id_to_token(self, index: int) -> str: - """Converts an index (integer) in a token (str) using the encoder.""" + """Converts an index (integer) in a token (str) using the decoder.""" return self.decoder.get(index, self.unk_token) + def batch_decode(self, sequences, **kwargs): + """ + Convert a list of lists of token ids into a list of strings by calling decode. 
+ + Args: + sequences (:obj:`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the ``__call__`` method. + skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to clean up the tokenization spaces. + use_source_tokenizer (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence + problems). + kwargs (additional keyword arguments, `optional`): + Will be passed to the underlying model specific decode method. + + Returns: + :obj:`List[str]`: The list of decoded sentences. + """ + return super().batch_decode(sequences, **kwargs) + + def decode(self, token_ids, **kwargs): + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. + + Args: + token_ids (:obj:`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the ``__call__`` method. + skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to clean up the tokenization spaces. + use_source_tokenizer (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence + problems). + kwargs (additional keyword arguments, `optional`): + Will be passed to the underlying model specific decode method. + + Returns: + :obj:`str`: The decoded sentence. 
+ """ + return super().decode(token_ids, **kwargs) + def convert_tokens_to_string(self, tokens: List[str]) -> str: - """Uses target language sentencepiece model""" - return self.spm_target.DecodePieces(tokens) + """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise """ + if self._decode_use_source_tokenizer: + return self.spm_source.DecodePieces(tokens) + else: + return self.spm_target.DecodePieces(tokens) def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: """Build model inputs from a sequence by appending eos_token_id.""" diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 3735215073..28c18b0934 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -486,6 +486,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer): token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True, + **kwargs ) -> str: """ special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index b5f55faf35..5ae55b80f2 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -122,6 +122,8 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): self.added_tokens_decoder: Dict[int, str] = {} self.unique_no_split_tokens: List[str] = [] + self._decode_use_source_tokenizer = False + @property def is_fast(self) -> bool: return False @@ -702,7 +704,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True, spaces_between_special_tokens: bool = True, + **kwargs ) -> str: + self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) + filtered_tokens = 
self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) # To avoid mixing byte-level and unicode for byte-level BPT diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py index 2d33aa7a4e..1f476585b0 100644 --- a/src/transformers/tokenization_utils_fast.py +++ b/src/transformers/tokenization_utils_fast.py @@ -106,6 +106,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): if slow_tokenizer is not None: kwargs.update(slow_tokenizer.init_kwargs) + self._decode_use_source_tokenizer = False + # We call this after having initialized the backend tokenizer because we update it. super().__init__(**kwargs) @@ -491,6 +493,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): clean_up_tokenization_spaces: bool = True, **kwargs ) -> str: + self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) + if isinstance(token_ids, int): token_ids = [token_ids] text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)