Fixing mbart50 with return_tensors argument too. (#13301)

* Fixing mbart50 with `return_tensors` argument too.

* Adding mbart50 tokenization tests.
This commit is contained in:
Nicolas Patry
2021-08-27 17:22:06 +02:00
committed by GitHub
parent b89a964d3f
commit 8aa67fc192
3 changed files with 11 additions and 5 deletions

View File

@@ -304,12 +304,14 @@ class MBart50Tokenizer(PreTrainedTokenizer):
# We don't expect to process pairs, but leave the pair logic for API consistency # We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): def _build_translation_inputs(
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
):
"""Used by translation pipeline, to prepare inputs for the generate function""" """Used by translation pipeline, to prepare inputs for the generate function"""
if src_lang is None or tgt_lang is None: if src_lang is None or tgt_lang is None:
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
self.src_lang = src_lang self.src_lang = src_lang
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
inputs["forced_bos_token_id"] = tgt_lang_id inputs["forced_bos_token_id"] = tgt_lang_id
return inputs return inputs

View File

@@ -245,12 +245,14 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
) )
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): def _build_translation_inputs(
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
):
"""Used by translation pipeline, to prepare inputs for the generate function""" """Used by translation pipeline, to prepare inputs for the generate function"""
if src_lang is None or tgt_lang is None: if src_lang is None or tgt_lang is None:
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
self.src_lang = src_lang self.src_lang = src_lang
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
inputs["forced_bos_token_id"] = tgt_lang_id inputs["forced_bos_token_id"] = tgt_lang_id
return inputs return inputs

View File

@@ -229,7 +229,9 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
@require_torch @require_torch
def test_tokenizer_translation(self): def test_tokenizer_translation(self):
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR") inputs = self.tokenizer._build_translation_inputs(
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR"
)
self.assertEqual( self.assertEqual(
nested_simplify(inputs), nested_simplify(inputs),