[cleanup/marian] pipelines test and new kwarg (#4812)
This commit is contained in:
@@ -48,13 +48,12 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
unk_token="<unk>",
|
unk_token="<unk>",
|
||||||
eos_token="</s>",
|
eos_token="</s>",
|
||||||
pad_token="<pad>",
|
pad_token="<pad>",
|
||||||
max_len=512,
|
model_max_length=512,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
# bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
|
# bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
|
||||||
max_len=max_len,
|
model_max_length=model_max_length,
|
||||||
eos_token=eos_token,
|
eos_token=eos_token,
|
||||||
unk_token=unk_token,
|
unk_token=unk_token,
|
||||||
pad_token=pad_token,
|
pad_token=pad_token,
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
|||||||
convert_opus_name_to_hf_name,
|
convert_opus_name_to_hf_name,
|
||||||
ORG_NAME,
|
ORG_NAME,
|
||||||
)
|
)
|
||||||
|
from transformers.pipelines import TranslationPipeline
|
||||||
|
|
||||||
|
|
||||||
class ModelManagementTests(unittest.TestCase):
|
class ModelManagementTests(unittest.TestCase):
|
||||||
@@ -189,6 +190,7 @@ class TestMarian_RU_FR(MarianIntegrationTest):
|
|||||||
src_text = ["Он показал мне рукопись своей новой пьесы."]
|
src_text = ["Он показал мне рукопись своей новой пьесы."]
|
||||||
expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."]
|
expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."]
|
||||||
|
|
||||||
|
@slow
|
||||||
def test_batch_generation_ru_fr(self):
|
def test_batch_generation_ru_fr(self):
|
||||||
self._assert_generated_batch_equal_expected()
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
@@ -199,6 +201,7 @@ class TestMarian_MT_EN(MarianIntegrationTest):
|
|||||||
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
|
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
|
||||||
expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]
|
expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]
|
||||||
|
|
||||||
|
@slow
|
||||||
def test_batch_generation_mt_en(self):
|
def test_batch_generation_mt_en(self):
|
||||||
self._assert_generated_batch_equal_expected()
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
@@ -229,6 +232,11 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.tokenizer.prepare_translation_batch([""])
|
self.tokenizer.prepare_translation_batch([""])
|
||||||
|
|
||||||
|
def test_pipeline(self):
|
||||||
|
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt")
|
||||||
|
output = pipeline(self.src_text)
|
||||||
|
self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TestConversionUtils(unittest.TestCase):
|
class TestConversionUtils(unittest.TestCase):
|
||||||
|
|||||||
@@ -52,8 +52,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
tokenizer.save_pretrained(self.tmpdirname)
|
tokenizer.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer:
|
def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer:
|
||||||
# overwrite max_len=512 default
|
return MarianTokenizer.from_pretrained(self.tmpdirname, model_max_length=max_len, **kwargs)
|
||||||
return MarianTokenizer.from_pretrained(self.tmpdirname, max_len=max_len, **kwargs)
|
|
||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
return (
|
return (
|
||||||
|
|||||||
Reference in New Issue
Block a user