[M2M100Tokenizer] fix _build_translation_inputs (#14382)
* add return_tensors parameter * fix test * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * style Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -332,7 +332,7 @@ class M2M100Tokenizer(PreTrainedTokenizer):
|
|||||||
if src_lang is None or tgt_lang is None:
|
if src_lang is None or tgt_lang is None:
|
||||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
self.src_lang = src_lang
|
self.src_lang = src_lang
|
||||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
inputs = self(raw_inputs, add_special_tokens=True, **extra_kwargs)
|
||||||
tgt_lang_id = self.get_lang_id(tgt_lang)
|
tgt_lang_id = self.get_lang_id(tgt_lang)
|
||||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
return inputs
|
return inputs
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_tokenizer_translation(self):
|
def test_tokenizer_translation(self):
|
||||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar")
|
inputs = self.tokenizer._build_translation_inputs("A test", return_tensors="pt", src_lang="en", tgt_lang="ar")
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(inputs),
|
nested_simplify(inputs),
|
||||||
|
|||||||
Reference in New Issue
Block a user