Fix MarianTokenizer to remove metaspace character in decode (#26091)
* add: check to remove the metaspace character in the Marian tokenizer
* fix: metaspace character being removed from everywhere
* fix: remove redundant check at top
* add: test for the Marian tokenizer decode fix
* fix: simplify the test
This commit is contained in:
@@ -55,6 +55,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"Helsinki-NLP/opus-mt-en-de": 512}
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"Helsinki-NLP/opus-mt-en-de": 512}
|
||||||
PRETRAINED_INIT_CONFIGURATION = {}
|
PRETRAINED_INIT_CONFIGURATION = {}
|
||||||
|
|
||||||
|
SPIECE_UNDERLINE = "▁"
|
||||||
|
|
||||||
# Example URL https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json
|
# Example URL https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json
|
||||||
|
|
||||||
|
|
||||||
@@ -278,6 +280,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
else:
|
else:
|
||||||
current_sub_tokens.append(token)
|
current_sub_tokens.append(token)
|
||||||
out_string += sp_model.decode_pieces(current_sub_tokens)
|
out_string += sp_model.decode_pieces(current_sub_tokens)
|
||||||
|
out_string = out_string.replace(SPIECE_UNDERLINE, " ")
|
||||||
return out_string.strip()
|
return out_string.strip()
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
||||||
|
|||||||
@@ -149,3 +149,10 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
|
decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
|
||||||
self.assertEqual(decoded, target_text)
|
self.assertEqual(decoded, target_text)
|
||||||
|
|
||||||
|
def test_tokenizer_decode(self):
|
||||||
|
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
|
||||||
|
source_text = "Hello World"
|
||||||
|
ids = tokenizer(source_text)["input_ids"]
|
||||||
|
output_text = tokenizer.decode(ids, skip_special_tokens=True)
|
||||||
|
self.assertEqual(source_text, output_text)
|
||||||
|
|||||||
Reference in New Issue
Block a user