From 8aa67fc192c1485f499e4c2dcb22bf8ad245160b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 27 Aug 2021 17:22:06 +0200 Subject: [PATCH] Fixing mbart50 with `return_tensors` argument too. (#13301) * Fixing mbart50 with `return_tensors` argument too. * Adding mbart50 tokenization tests. --- src/transformers/models/mbart50/tokenization_mbart50.py | 6 ++++-- .../models/mbart50/tokenization_mbart50_fast.py | 6 ++++-- tests/test_tokenization_mbart50.py | 4 +++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mbart50/tokenization_mbart50.py b/src/transformers/models/mbart50/tokenization_mbart50.py index 15a9551936..dbfd53a7fb 100644 --- a/src/transformers/models/mbart50/tokenization_mbart50.py +++ b/src/transformers/models/mbart50/tokenization_mbart50.py @@ -304,12 +304,14 @@ class MBart50Tokenizer(PreTrainedTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens - def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): """Used by translation pipeline, to prepare inputs for the generate function""" if src_lang is None or tgt_lang is None: raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") self.src_lang = src_lang - inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) inputs["forced_bos_token_id"] = tgt_lang_id return inputs diff --git a/src/transformers/models/mbart50/tokenization_mbart50_fast.py b/src/transformers/models/mbart50/tokenization_mbart50_fast.py index 01980769e5..b3966f9c0b 100644 --- a/src/transformers/models/mbart50/tokenization_mbart50_fast.py +++ b/src/transformers/models/mbart50/tokenization_mbart50_fast.py @@ -245,12 +245,14 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast): special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), ) - def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): """Used by translation pipeline, to prepare inputs for the generate function""" if src_lang is None or tgt_lang is None: raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") self.src_lang = src_lang - inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) inputs["forced_bos_token_id"] = tgt_lang_id return inputs diff --git a/tests/test_tokenization_mbart50.py b/tests/test_tokenization_mbart50.py index 88a0c62da9..1327822f03 100644 --- a/tests/test_tokenization_mbart50.py +++ b/tests/test_tokenization_mbart50.py @@ -229,7 +229,9 @@ class MBartOneToManyIntegrationTest(unittest.TestCase): @require_torch def test_tokenizer_translation(self): - inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR") + inputs = self.tokenizer._build_translation_inputs( + "A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR" + ) self.assertEqual( nested_simplify(inputs),