Fixing mbart50 with return_tensors argument too. (#13301)
* Fixing mbart50 with `return_tensors` argument too. * Adding mbart50 tokenization tests.
This commit is contained in:
@@ -304,12 +304,14 @@ class MBart50Tokenizer(PreTrainedTokenizer):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
def _build_translation_inputs(
|
||||||
|
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
||||||
|
):
|
||||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
if src_lang is None or tgt_lang is None:
|
if src_lang is None or tgt_lang is None:
|
||||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
self.src_lang = src_lang
|
self.src_lang = src_lang
|
||||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
return inputs
|
return inputs
|
||||||
|
|||||||
@@ -245,12 +245,14 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
def _build_translation_inputs(
|
||||||
|
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
||||||
|
):
|
||||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
if src_lang is None or tgt_lang is None:
|
if src_lang is None or tgt_lang is None:
|
||||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
self.src_lang = src_lang
|
self.src_lang = src_lang
|
||||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
return inputs
|
return inputs
|
||||||
|
|||||||
@@ -229,7 +229,9 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_tokenizer_translation(self):
|
def test_tokenizer_translation(self):
|
||||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
inputs = self.tokenizer._build_translation_inputs(
|
||||||
|
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR"
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(inputs),
|
nested_simplify(inputs),
|
||||||
|
|||||||
Reference in New Issue
Block a user