From 4ab742459735671189d774cfa336d52561655816 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 5 Jun 2020 18:45:19 -0400 Subject: [PATCH] [cleanup/marian] pipelines test and new kwarg (#4812) --- src/transformers/tokenization_marian.py | 7 +++---- tests/test_modeling_marian.py | 8 ++++++++ tests/test_tokenization_marian.py | 3 +-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py index a13fed09d0..57307615e5 100644 --- a/src/transformers/tokenization_marian.py +++ b/src/transformers/tokenization_marian.py @@ -48,13 +48,12 @@ class MarianTokenizer(PreTrainedTokenizer): unk_token="", eos_token="", pad_token="", - max_len=512, - **kwargs, + model_max_length=512, + **kwargs ): - super().__init__( # bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id - max_len=max_len, + model_max_length=model_max_length, eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index c1a1f4a96c..46377da348 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -38,6 +38,7 @@ if is_torch_available(): convert_opus_name_to_hf_name, ORG_NAME, ) + from transformers.pipelines import TranslationPipeline class ModelManagementTests(unittest.TestCase): @@ -189,6 +190,7 @@ class TestMarian_RU_FR(MarianIntegrationTest): src_text = ["Он показал мне рукопись своей новой пьесы."] expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."] + @slow def test_batch_generation_ru_fr(self): self._assert_generated_batch_equal_expected() @@ -199,6 +201,7 @@ class TestMarian_MT_EN(MarianIntegrationTest): src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."] expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."] + @slow def test_batch_generation_mt_en(self): self._assert_generated_batch_equal_expected() @@ -229,6 +232,11 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest): with self.assertRaises(ValueError): self.tokenizer.prepare_translation_batch([""]) + def test_pipeline(self): + pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt") + output = pipeline(self.src_text) + self.assertEqual(self.expected_text, [x["translation_text"] for x in output]) + @require_torch class TestConversionUtils(unittest.TestCase): diff --git a/tests/test_tokenization_marian.py b/tests/test_tokenization_marian.py index 688413af82..9f0e2342d3 100644 --- a/tests/test_tokenization_marian.py +++ b/tests/test_tokenization_marian.py @@ -52,8 +52,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer.save_pretrained(self.tmpdirname) def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer: - # overwrite max_len=512 default - return MarianTokenizer.from_pretrained(self.tmpdirname, max_len=max_len, **kwargs) + return MarianTokenizer.from_pretrained(self.tmpdirname, model_max_length=max_len, **kwargs) def get_input_output_texts(self): return (