Adds translation pipeline (#3419)
* fix merge conflicts * add t5 summarization example * change parameters for t5 summarization * make style * add first code snippet for translation * only add prefixes * add prefix patterns * make style * renaming * fix conflicts * remove unused patterns * solve conflicts * fix merge conflicts * remove translation example * remove summarization example * make sure tensors are in numpy for float comparison * re-add t5 config * fix t5 import config typo * make style * remove unused numpy statements * update docstring * import translation pipeline
This commit is contained in:
committed by
GitHub
parent
3c5c567507
commit
022e8fab97
@@ -81,6 +81,12 @@ TF_FILL_MASK_FINETUNED_MODELS = [
|
||||
SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
|
||||
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}
|
||||
|
||||
TRANSLATION_FINETUNED_MODELS = {
|
||||
("t5-small", "t5-small", "translation_en_to_de"),
|
||||
("t5-small", "t5-small", "translation_en_to_ro"),
|
||||
}
|
||||
TF_TRANSLATION_FINETUNED_MODELS = {("t5-small", "t5-small", "translation_en_to_fr")}
|
||||
|
||||
|
||||
class MonoColumnInputTestCase(unittest.TestCase):
|
||||
def _test_mono_column_pipeline(
|
||||
@@ -272,6 +278,28 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_translation(self):
|
||||
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["translation_text"]
|
||||
for model, tokenizer, task in TRANSLATION_FINETUNED_MODELS:
|
||||
nlp = pipeline(task=task, model=model, tokenizer=tokenizer)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_tf_translation(self):
|
||||
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["translation_text"]
|
||||
for model, tokenizer, task in TF_TRANSLATION_FINETUNED_MODELS:
|
||||
nlp = pipeline(task=task, model=model, tokenizer=tokenizer, framework="tf")
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
|
||||
|
||||
class MultiColumnInputTestCase(unittest.TestCase):
|
||||
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
||||
|
||||
Reference in New Issue
Block a user