Fixing mbart50 with return_tensors argument too. (#13301)
* Fixing mbart50 with `return_tensors` argument too. * Adding mbart50 tokenization tests.
This commit is contained in:
@@ -304,12 +304,14 @@ class MBart50Tokenizer(PreTrainedTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||
def _build_translation_inputs(
|
||||
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
||||
):
|
||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||
if src_lang is None or tgt_lang is None:
|
||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||
self.src_lang = src_lang
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
||||
@@ -245,12 +245,14 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
||||
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
||||
)
|
||||
|
||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||
def _build_translation_inputs(
|
||||
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
||||
):
|
||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||
if src_lang is None or tgt_lang is None:
|
||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||
self.src_lang = src_lang
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
||||
@@ -229,7 +229,9 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||
inputs = self.tokenizer._build_translation_inputs(
|
||||
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
|
||||
Reference in New Issue
Block a user