diff --git a/src/transformers/models/mbart50/tokenization_mbart50.py b/src/transformers/models/mbart50/tokenization_mbart50.py index 15a9551936..dbfd53a7fb 100644 --- a/src/transformers/models/mbart50/tokenization_mbart50.py +++ b/src/transformers/models/mbart50/tokenization_mbart50.py @@ -304,12 +304,14 @@ class MBart50Tokenizer(PreTrainedTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens - def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): """Used by translation pipeline, to prepare inputs for the generate function""" if src_lang is None or tgt_lang is None: raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") self.src_lang = src_lang - inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) inputs["forced_bos_token_id"] = tgt_lang_id return inputs diff --git a/src/transformers/models/mbart50/tokenization_mbart50_fast.py b/src/transformers/models/mbart50/tokenization_mbart50_fast.py index 01980769e5..b3966f9c0b 100644 --- a/src/transformers/models/mbart50/tokenization_mbart50_fast.py +++ b/src/transformers/models/mbart50/tokenization_mbart50_fast.py @@ -245,12 +245,14 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast): special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), ) - def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): """Used by translation pipeline, to prepare inputs for the generate function""" if src_lang is None or tgt_lang is None: raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") self.src_lang = src_lang - inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) inputs["forced_bos_token_id"] = tgt_lang_id return inputs diff --git a/tests/test_tokenization_mbart50.py b/tests/test_tokenization_mbart50.py index 88a0c62da9..1327822f03 100644 --- a/tests/test_tokenization_mbart50.py +++ b/tests/test_tokenization_mbart50.py @@ -229,7 +229,9 @@ class MBartOneToManyIntegrationTest(unittest.TestCase): @require_torch def test_tokenizer_translation(self): - inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR") + inputs = self.tokenizer._build_translation_inputs( + "A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR" + ) self.assertEqual( nested_simplify(inputs),