From 12f043eaeaabfef6f6efea411d98e6f6d3c094b7 Mon Sep 17 00:00:00 2001 From: Tanay Mehta Date: Wed, 13 Sep 2023 01:23:31 +0530 Subject: [PATCH] Fix `MarianTokenizer` to remove metaspace character in `decode` (#26091) * add: check to remove metaspace from marian tokenizer * fix: metaspace character being removed from everywhere * fix: remove redundant check at top * add: test for marian tokenizer decode fix * fix: simplified the test --- src/transformers/models/marian/tokenization_marian.py | 3 +++ tests/models/marian/test_tokenization_marian.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py index 96a1f47bf7..2736b03a01 100644 --- a/src/transformers/models/marian/tokenization_marian.py +++ b/src/transformers/models/marian/tokenization_marian.py @@ -55,6 +55,8 @@ PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"Helsinki-NLP/opus-mt-en-de": 512} PRETRAINED_INIT_CONFIGURATION = {} +SPIECE_UNDERLINE = "▁" + # Example URL https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json @@ -278,6 +280,7 @@ class MarianTokenizer(PreTrainedTokenizer): else: current_sub_tokens.append(token) out_string += sp_model.decode_pieces(current_sub_tokens) + out_string = out_string.replace(SPIECE_UNDERLINE, " ") return out_string.strip() def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: diff --git a/tests/models/marian/test_tokenization_marian.py b/tests/models/marian/test_tokenization_marian.py index fae0edfa68..f32026be1a 100644 --- a/tests/models/marian/test_tokenization_marian.py +++ b/tests/models/marian/test_tokenization_marian.py @@ -149,3 +149,10 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): decoded = tokenizer.decode(target_ids, skip_special_tokens=True) self.assertEqual(decoded, target_text) + + def test_tokenizer_decode(self): + tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es") + source_text = "Hello World" + ids = tokenizer(source_text)["input_ids"] + output_text = tokenizer.decode(ids, skip_special_tokens=True) + self.assertEqual(source_text, output_text)