From 2e60276b38b261ae5f37786955e59130f4822a02 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sat, 13 Nov 2021 20:57:12 +0530 Subject: [PATCH] [M2M100Tokenizer] fix _build_translation_inputs (#14382) * add return_tensors parameter * fix test * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * style Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/m2m_100/tokenization_m2m_100.py | 2 +- tests/test_tokenization_m2m_100.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/m2m_100/tokenization_m2m_100.py b/src/transformers/models/m2m_100/tokenization_m2m_100.py index f8df8fb70c..88ce4bd44d 100644 --- a/src/transformers/models/m2m_100/tokenization_m2m_100.py +++ b/src/transformers/models/m2m_100/tokenization_m2m_100.py @@ -332,7 +332,7 @@ class M2M100Tokenizer(PreTrainedTokenizer): if src_lang is None or tgt_lang is None: raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") self.src_lang = src_lang - inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + inputs = self(raw_inputs, add_special_tokens=True, **extra_kwargs) tgt_lang_id = self.get_lang_id(tgt_lang) inputs["forced_bos_token_id"] = tgt_lang_id return inputs diff --git a/tests/test_tokenization_m2m_100.py b/tests/test_tokenization_m2m_100.py index 1466a45e86..5c9e2083af 100644 --- a/tests/test_tokenization_m2m_100.py +++ b/tests/test_tokenization_m2m_100.py @@ -226,7 +226,7 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase): @require_torch def test_tokenizer_translation(self): - inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar") + inputs = self.tokenizer._build_translation_inputs("A test", return_tensors="pt", src_lang="en", tgt_lang="ar") self.assertEqual( nested_simplify(inputs),