Adds translation pipeline (#3419)
* fix merge conflicts * add t5 summarization example * change parameters for t5 summarization * make style * add first code snippet for translation * only add prefixes * add prefix patterns * make style * renaming * fix conflicts * remove unused patterns * solve conflicts * fix merge conflicts * remove translation example * remove summarization example * make sure tensors are in numpy for float comparsion * re-add t5 config * fix t5 import config typo * make style * remove unused numpy statements * update doctstring * import translation pipeline
This commit is contained in:
committed by
GitHub
parent
3c5c567507
commit
022e8fab97
@@ -116,6 +116,7 @@ from .pipelines import (
|
|||||||
SummarizationPipeline,
|
SummarizationPipeline,
|
||||||
TextClassificationPipeline,
|
TextClassificationPipeline,
|
||||||
TokenClassificationPipeline,
|
TokenClassificationPipeline,
|
||||||
|
TranslationPipeline,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from .tokenization_albert import AlbertTokenizer
|
from .tokenization_albert import AlbertTokenizer
|
||||||
|
|||||||
@@ -130,7 +130,9 @@ class PipelineDataFormat:
|
|||||||
|
|
||||||
SUPPORTED_FORMATS = ["json", "csv", "pipe"]
|
SUPPORTED_FORMATS = ["json", "csv", "pipe"]
|
||||||
|
|
||||||
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
|
def __init__(
|
||||||
|
self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
|
||||||
|
):
|
||||||
self.output_path = output_path
|
self.output_path = output_path
|
||||||
self.input_path = input_path
|
self.input_path = input_path
|
||||||
self.column = column.split(",") if column is not None else [""]
|
self.column = column.split(",") if column is not None else [""]
|
||||||
@@ -176,7 +178,7 @@ class PipelineDataFormat:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_str(
|
def from_str(
|
||||||
format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False
|
format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
|
||||||
):
|
):
|
||||||
if format == "json":
|
if format == "json":
|
||||||
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
|
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
|
||||||
@@ -189,7 +191,9 @@ class PipelineDataFormat:
|
|||||||
|
|
||||||
|
|
||||||
class CsvPipelineDataFormat(PipelineDataFormat):
|
class CsvPipelineDataFormat(PipelineDataFormat):
|
||||||
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
|
def __init__(
|
||||||
|
self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
|
||||||
|
):
|
||||||
super().__init__(output_path, input_path, column, overwrite=overwrite)
|
super().__init__(output_path, input_path, column, overwrite=overwrite)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
@@ -210,7 +214,9 @@ class CsvPipelineDataFormat(PipelineDataFormat):
|
|||||||
|
|
||||||
|
|
||||||
class JsonPipelineDataFormat(PipelineDataFormat):
|
class JsonPipelineDataFormat(PipelineDataFormat):
|
||||||
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
|
def __init__(
|
||||||
|
self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
|
||||||
|
):
|
||||||
super().__init__(output_path, input_path, column, overwrite=overwrite)
|
super().__init__(output_path, input_path, column, overwrite=overwrite)
|
||||||
|
|
||||||
with open(input_path, "r") as f:
|
with open(input_path, "r") as f:
|
||||||
@@ -1120,7 +1126,11 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
chars_idx += len(word) + 1
|
chars_idx += len(word) + 1
|
||||||
|
|
||||||
# Join text with spaces
|
# Join text with spaces
|
||||||
return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}
|
return {
|
||||||
|
"answer": " ".join(words),
|
||||||
|
"start": max(0, char_start_idx),
|
||||||
|
"end": min(len(text), char_end_idx),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class SummarizationPipeline(Pipeline):
|
class SummarizationPipeline(Pipeline):
|
||||||
@@ -1223,18 +1233,18 @@ class SummarizationPipeline(Pipeline):
|
|||||||
inputs = self.ensure_tensor_on_device(**inputs)
|
inputs = self.ensure_tensor_on_device(**inputs)
|
||||||
input_length = inputs["input_ids"].shape[-1]
|
input_length = inputs["input_ids"].shape[-1]
|
||||||
elif self.framework == "tf":
|
elif self.framework == "tf":
|
||||||
input_length = tf.shape(inputs["input_ids"])[-1]
|
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
||||||
|
|
||||||
if input_length < self.model.config.min_length // 2:
|
if input_length < self.model.config.min_length // 2:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length in config and insert config manually".format(
|
"Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
|
||||||
self.model.config.min_length, input_length
|
self.model.config.min_length, input_length
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if input_length < self.model.config.max_length:
|
if input_length < self.model.config.max_length:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length in config and insert config manually".format(
|
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
|
||||||
self.model.config.max_length, input_length
|
self.model.config.max_length, input_length
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -1250,7 +1260,115 @@ class SummarizationPipeline(Pipeline):
|
|||||||
record["summary_token_ids"] = summary
|
record["summary_token_ids"] = summary
|
||||||
if return_text:
|
if return_text:
|
||||||
record["summary_text"] = self.tokenizer.decode(
|
record["summary_text"] = self.tokenizer.decode(
|
||||||
summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces
|
summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
|
)
|
||||||
|
results.append(record)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationPipeline(Pipeline):
|
||||||
|
"""
|
||||||
|
Translates from one language to another.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
en_fr_translator = pipeline("translation_en_to_fr")
|
||||||
|
en_fr_translator("How old are you?")
|
||||||
|
|
||||||
|
Supported Models: "t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||||
|
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||||
|
checkpoint identifier or an actual pre-trained model inheriting from
|
||||||
|
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||||
|
TensorFlow.
|
||||||
|
If :obj:`None`, the default of the pipeline will be loaded.
|
||||||
|
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||||
|
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||||
|
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||||
|
:class:`~transformers.PreTrainedTokenizer`.
|
||||||
|
If :obj:`None`, the default of the pipeline will be loaded.
|
||||||
|
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||||
|
Model card attributed to the model for this pipeline.
|
||||||
|
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||||
|
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||||
|
installed.
|
||||||
|
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||||
|
and both frameworks are installed, will default to PyTorch.
|
||||||
|
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||||
|
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||||
|
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||||
|
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||||
|
on the associated CUDA device id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
*texts: (list of strings) articles to be summarized
|
||||||
|
return_text: (bool, default=True) whether to add a decoded "translation_text" to each result
|
||||||
|
return_tensors: (bool, default=False) whether to return the raw "translation_token_ids" to each result
|
||||||
|
|
||||||
|
**generate_kwargs: extra kwargs passed to `self.model.generate`_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of dicts with 'translation_text' and/or 'translation_token_ids' for each text_to_translate
|
||||||
|
.. _`self.model.generate`:
|
||||||
|
https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForConditionalGeneration.generate
|
||||||
|
"""
|
||||||
|
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||||
|
|
||||||
|
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||||
|
|
||||||
|
if isinstance(texts[0], list):
|
||||||
|
assert (
|
||||||
|
self.tokenizer.pad_token_id is not None
|
||||||
|
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
||||||
|
texts = ([prefix + text for text in texts[0]],)
|
||||||
|
pad_to_max_length = True
|
||||||
|
|
||||||
|
elif isinstance(texts[0], str):
|
||||||
|
texts = (prefix + texts[0],)
|
||||||
|
pad_to_max_length = False
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
" `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
|
||||||
|
texts[0]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.device_placement():
|
||||||
|
inputs = self._parse_and_tokenize(*texts, pad_to_max_length=pad_to_max_length)
|
||||||
|
|
||||||
|
if self.framework == "pt":
|
||||||
|
inputs = self.ensure_tensor_on_device(**inputs)
|
||||||
|
input_length = inputs["input_ids"].shape[-1]
|
||||||
|
|
||||||
|
elif self.framework == "tf":
|
||||||
|
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
||||||
|
|
||||||
|
if input_length > 0.9 * self.model.config.max_length:
|
||||||
|
logger.warning(
|
||||||
|
"Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
|
||||||
|
input_length, self.model.config.max_length
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
translations = self.model.generate(
|
||||||
|
inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for translation in translations:
|
||||||
|
record = {}
|
||||||
|
if return_tensors:
|
||||||
|
record["translation_token_ids"] = translation
|
||||||
|
if return_text:
|
||||||
|
record["translation_text"] = self.tokenizer.decode(
|
||||||
|
translation,
|
||||||
|
skip_special_tokens=True,
|
||||||
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
)
|
)
|
||||||
results.append(record)
|
results.append(record)
|
||||||
return results
|
return results
|
||||||
@@ -1324,6 +1442,36 @@ SUPPORTED_TASKS = {
|
|||||||
"tokenizer": ("bart-large-cnn", {"use_fast": False}),
|
"tokenizer": ("bart-large-cnn", {"use_fast": False}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"translation_en_to_fr": {
|
||||||
|
"impl": TranslationPipeline,
|
||||||
|
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||||
|
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||||
|
"default": {
|
||||||
|
"model": {"pt": "t5-base", "tf": "t5-base"},
|
||||||
|
"config": None,
|
||||||
|
"tokenizer": ("t5-base", {"use_fast": False}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"translation_en_to_de": {
|
||||||
|
"impl": TranslationPipeline,
|
||||||
|
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||||
|
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||||
|
"default": {
|
||||||
|
"model": {"pt": "t5-base", "tf": "t5-base"},
|
||||||
|
"config": None,
|
||||||
|
"tokenizer": ("t5-base", {"use_fast": False}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"translation_en_to_ro": {
|
||||||
|
"impl": TranslationPipeline,
|
||||||
|
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||||
|
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||||
|
"default": {
|
||||||
|
"model": {"pt": "t5-base", "tf": "t5-base"},
|
||||||
|
"config": None,
|
||||||
|
"tokenizer": ("t5-base", {"use_fast": False}),
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -1472,4 +1620,4 @@ def pipeline(
|
|||||||
)
|
)
|
||||||
model = model_class.from_pretrained(model, config=config, **model_kwargs)
|
model = model_class.from_pretrained(model, config=config, **model_kwargs)
|
||||||
|
|
||||||
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
|
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs,)
|
||||||
|
|||||||
@@ -81,6 +81,12 @@ TF_FILL_MASK_FINETUNED_MODELS = [
|
|||||||
SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
|
SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
|
||||||
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}
|
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}
|
||||||
|
|
||||||
|
TRANSLATION_FINETUNED_MODELS = {
|
||||||
|
("t5-small", "t5-small", "translation_en_to_de"),
|
||||||
|
("t5-small", "t5-small", "translation_en_to_ro"),
|
||||||
|
}
|
||||||
|
TF_TRANSLATION_FINETUNED_MODELS = {("t5-small", "t5-small", "translation_en_to_fr")}
|
||||||
|
|
||||||
|
|
||||||
class MonoColumnInputTestCase(unittest.TestCase):
|
class MonoColumnInputTestCase(unittest.TestCase):
|
||||||
def _test_mono_column_pipeline(
|
def _test_mono_column_pipeline(
|
||||||
@@ -272,6 +278,28 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_translation(self):
|
||||||
|
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||||
|
invalid_inputs = [4, "<mask>"]
|
||||||
|
mandatory_keys = ["translation_text"]
|
||||||
|
for model, tokenizer, task in TRANSLATION_FINETUNED_MODELS:
|
||||||
|
nlp = pipeline(task=task, model=model, tokenizer=tokenizer)
|
||||||
|
self._test_mono_column_pipeline(
|
||||||
|
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_tf_translation(self):
|
||||||
|
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||||
|
invalid_inputs = [4, "<mask>"]
|
||||||
|
mandatory_keys = ["translation_text"]
|
||||||
|
for model, tokenizer, task in TF_TRANSLATION_FINETUNED_MODELS:
|
||||||
|
nlp = pipeline(task=task, model=model, tokenizer=tokenizer, framework="tf")
|
||||||
|
self._test_mono_column_pipeline(
|
||||||
|
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiColumnInputTestCase(unittest.TestCase):
|
class MultiColumnInputTestCase(unittest.TestCase):
|
||||||
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
||||||
|
|||||||
Reference in New Issue
Block a user