Fixing mbart50 with return_tensors argument too. (#13301)

* Fixing mbart50 with `return_tensors` argument too.

* Adding mbart50 tokenization tests.
This commit is contained in:
Nicolas Patry
2021-08-27 17:22:06 +02:00
committed by GitHub
parent b89a964d3f
commit 8aa67fc192
3 changed files with 11 additions and 5 deletions

View File

@@ -229,7 +229,9 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
@require_torch
def test_tokenizer_translation(self):
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
inputs = self.tokenizer._build_translation_inputs(
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR"
)
self.assertEqual(
nested_simplify(inputs),