tokenization_marian.py: use current_spm for decoding (#10357)
* Fix Marian decoding Tokenizer's decode and batch_decode now accepts a new argument (use_source_tokenizer) which indicates whether the source spm should be used to decode ids. This is useful for Marian models specificallly when decoding source input ids. * Adapt docstrings Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
committed by
GitHub
parent
8fd7eb34e2
commit
b880508440
@@ -159,7 +159,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
return self.encoder.get(token, self.encoder[self.unk_token])
|
return self.encoder.get(token, self.encoder[self.unk_token])
|
||||||
|
|
||||||
def remove_language_code(self, text: str):
|
def remove_language_code(self, text: str):
|
||||||
"""Remove language codes like <<fr>> before sentencepiece"""
|
"""Remove language codes like >>fr<< before sentencepiece"""
|
||||||
match = self.language_code_re.match(text)
|
match = self.language_code_re.match(text)
|
||||||
code: list = [match.group(0)] if match else []
|
code: list = [match.group(0)] if match else []
|
||||||
return code, self.language_code_re.sub("", text)
|
return code, self.language_code_re.sub("", text)
|
||||||
@@ -170,12 +170,62 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
return code + pieces
|
return code + pieces
|
||||||
|
|
||||||
def _convert_id_to_token(self, index: int) -> str:
|
def _convert_id_to_token(self, index: int) -> str:
|
||||||
"""Converts an index (integer) in a token (str) using the encoder."""
|
"""Converts an index (integer) in a token (str) using the decoder."""
|
||||||
return self.decoder.get(index, self.unk_token)
|
return self.decoder.get(index, self.unk_token)
|
||||||
|
|
||||||
|
def batch_decode(self, sequences, **kwargs):
|
||||||
|
"""
|
||||||
|
Convert a list of lists of token ids into a list of strings by calling decode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequences (:obj:`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
|
||||||
|
List of tokenized input ids. Can be obtained using the ``__call__`` method.
|
||||||
|
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to remove special tokens in the decoding.
|
||||||
|
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to clean up the tokenization spaces.
|
||||||
|
use_source_tokenizer (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
|
||||||
|
problems).
|
||||||
|
kwargs (additional keyword arguments, `optional`):
|
||||||
|
Will be passed to the underlying model specific decode method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[str]`: The list of decoded sentences.
|
||||||
|
"""
|
||||||
|
return super().batch_decode(sequences, **kwargs)
|
||||||
|
|
||||||
|
def decode(self, token_ids, **kwargs):
|
||||||
|
"""
|
||||||
|
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
|
||||||
|
tokens and clean up tokenization spaces.
|
||||||
|
|
||||||
|
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids (:obj:`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
||||||
|
List of tokenized input ids. Can be obtained using the ``__call__`` method.
|
||||||
|
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to remove special tokens in the decoding.
|
||||||
|
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to clean up the tokenization spaces.
|
||||||
|
use_source_tokenizer (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
|
||||||
|
problems).
|
||||||
|
kwargs (additional keyword arguments, `optional`):
|
||||||
|
Will be passed to the underlying model specific decode method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`str`: The decoded sentence.
|
||||||
|
"""
|
||||||
|
return super().decode(token_ids, **kwargs)
|
||||||
|
|
||||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||||
"""Uses target language sentencepiece model"""
|
"""Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise """
|
||||||
return self.spm_target.DecodePieces(tokens)
|
if self._decode_use_source_tokenizer:
|
||||||
|
return self.spm_source.DecodePieces(tokens)
|
||||||
|
else:
|
||||||
|
return self.spm_target.DecodePieces(tokens)
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
||||||
"""Build model inputs from a sequence by appending eos_token_id."""
|
"""Build model inputs from a sequence by appending eos_token_id."""
|
||||||
|
|||||||
@@ -486,6 +486,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
skip_special_tokens: bool = False,
|
skip_special_tokens: bool = False,
|
||||||
clean_up_tokenization_spaces: bool = True,
|
clean_up_tokenization_spaces: bool = True,
|
||||||
|
**kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the
|
special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the
|
||||||
|
|||||||
@@ -122,6 +122,8 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
self.added_tokens_decoder: Dict[int, str] = {}
|
self.added_tokens_decoder: Dict[int, str] = {}
|
||||||
self.unique_no_split_tokens: List[str] = []
|
self.unique_no_split_tokens: List[str] = []
|
||||||
|
|
||||||
|
self._decode_use_source_tokenizer = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_fast(self) -> bool:
|
def is_fast(self) -> bool:
|
||||||
return False
|
return False
|
||||||
@@ -702,7 +704,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
skip_special_tokens: bool = False,
|
skip_special_tokens: bool = False,
|
||||||
clean_up_tokenization_spaces: bool = True,
|
clean_up_tokenization_spaces: bool = True,
|
||||||
spaces_between_special_tokens: bool = True,
|
spaces_between_special_tokens: bool = True,
|
||||||
|
**kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
|
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
||||||
|
|
||||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
# To avoid mixing byte-level and unicode for byte-level BPT
|
# To avoid mixing byte-level and unicode for byte-level BPT
|
||||||
|
|||||||
@@ -106,6 +106,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
if slow_tokenizer is not None:
|
if slow_tokenizer is not None:
|
||||||
kwargs.update(slow_tokenizer.init_kwargs)
|
kwargs.update(slow_tokenizer.init_kwargs)
|
||||||
|
|
||||||
|
self._decode_use_source_tokenizer = False
|
||||||
|
|
||||||
# We call this after having initialized the backend tokenizer because we update it.
|
# We call this after having initialized the backend tokenizer because we update it.
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@@ -491,6 +493,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
clean_up_tokenization_spaces: bool = True,
|
clean_up_tokenization_spaces: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
|
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
||||||
|
|
||||||
if isinstance(token_ids, int):
|
if isinstance(token_ids, int):
|
||||||
token_ids = [token_ids]
|
token_ids = [token_ids]
|
||||||
text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user