Removing duplicated code for Translation,Summarization and Text2TextGeneration pipelines (#9433)
* Merging all duplicated codes for Text2TextPipeline while preserving backward compat. * Fixing TranslationPipeline Hierarchy + return_name * torch import guard. * Update isort version. * Remove code from other PR disentanglement. * Removed named example to something more agnostic.
This commit is contained in:
@@ -6,7 +6,7 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
|
||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
@@ -14,242 +14,6 @@ if is_torch_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class SummarizationPipeline(Pipeline):
|
||||
"""
|
||||
Summarize news articles and other documents.
|
||||
|
||||
This summarizing pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
|
||||
identifier: :obj:`"summarization"`.
|
||||
|
||||
The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
|
||||
currently, '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`', '`t5-11b`'. See the up-to-date
|
||||
list of available models on `huggingface.co/models <https://huggingface.co/models?filter=summarization>`__.
|
||||
|
||||
Usage::
|
||||
|
||||
# use bart in pytorch
|
||||
summarizer = pipeline("summarization")
|
||||
summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
|
||||
|
||||
# use t5 in tf
|
||||
summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
|
||||
summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.update(task="summarization")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.check_model_type(
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||
):
|
||||
r"""
|
||||
Summarize the text(s) given as inputs.
|
||||
|
||||
Args:
|
||||
documents (`str` or :obj:`List[str]`):
|
||||
One or several articles (or one list of articles) to summarize.
|
||||
return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to include the decoded texts in the outputs
|
||||
return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to include the tensors of predictions (as token indices) in the outputs.
|
||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
generate_kwargs:
|
||||
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
||||
corresponding to your framework `here <./model.html#generative-models>`__).
|
||||
|
||||
Return:
|
||||
A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:
|
||||
|
||||
- **summary_text** (:obj:`str`, present when ``return_text=True``) -- The summary of the corresponding
|
||||
input.
|
||||
- **summary_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``) --
|
||||
The token ids of the summary.
|
||||
"""
|
||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||
assert len(documents) > 0, "Please provide a document to summarize"
|
||||
|
||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||
|
||||
if isinstance(documents[0], list):
|
||||
assert (
|
||||
self.tokenizer.pad_token_id is not None
|
||||
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
||||
|
||||
documents = ([prefix + document for document in documents[0]],)
|
||||
padding = True
|
||||
|
||||
elif isinstance(documents[0], str):
|
||||
documents = (prefix + documents[0],)
|
||||
padding = False
|
||||
else:
|
||||
raise ValueError(
|
||||
" `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
|
||||
documents[0]
|
||||
)
|
||||
)
|
||||
|
||||
with self.device_placement():
|
||||
inputs = self._parse_and_tokenize(*documents, padding=padding)
|
||||
|
||||
if self.framework == "pt":
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
input_length = inputs["input_ids"].shape[-1]
|
||||
elif self.framework == "tf":
|
||||
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
||||
|
||||
min_length = generate_kwargs.get("min_length", self.model.config.min_length)
|
||||
if input_length < min_length // 2:
|
||||
logger.warning(
|
||||
"Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
|
||||
min_length, input_length
|
||||
)
|
||||
)
|
||||
|
||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||
if input_length < max_length:
|
||||
logger.warning(
|
||||
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
|
||||
max_length, input_length
|
||||
)
|
||||
)
|
||||
|
||||
summaries = self.model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
results = []
|
||||
for summary in summaries:
|
||||
record = {}
|
||||
if return_tensors:
|
||||
record["summary_token_ids"] = summary
|
||||
if return_text:
|
||||
record["summary_text"] = self.tokenizer.decode(
|
||||
summary,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
results.append(record)
|
||||
return results
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class TranslationPipeline(Pipeline):
|
||||
"""
|
||||
Translates from one language to another.
|
||||
|
||||
This translation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
|
||||
identifier: :obj:`"translation_xx_to_yy"`.
|
||||
|
||||
The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
|
||||
up-to-date list of available models on `huggingface.co/models
|
||||
<https://huggingface.co/models?filter=translation>`__.
|
||||
|
||||
Usage::
|
||||
en_fr_translator = pipeline("translation_en_to_fr")
|
||||
en_fr_translator("How old are you?")
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.check_model_type(
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||
):
|
||||
r"""
|
||||
Translate the text(s) given as inputs.
|
||||
|
||||
Args:
|
||||
args (:obj:`str` or :obj:`List[str]`):
|
||||
Texts to be translated.
|
||||
return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to include the tensors of predictions (as token indices) in the outputs.
|
||||
return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to include the decoded texts in the outputs.
|
||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
generate_kwargs:
|
||||
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
||||
corresponding to your framework `here <./model.html#generative-models>`__).
|
||||
|
||||
Return:
|
||||
A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:
|
||||
|
||||
- **translation_text** (:obj:`str`, present when ``return_text=True``) -- The translation.
|
||||
- **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
||||
-- The token ids of the translation.
|
||||
"""
|
||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||
|
||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||
|
||||
if isinstance(args[0], list):
|
||||
assert (
|
||||
self.tokenizer.pad_token_id is not None
|
||||
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
||||
args = ([prefix + text for text in args[0]],)
|
||||
padding = True
|
||||
|
||||
elif isinstance(args[0], str):
|
||||
args = (prefix + args[0],)
|
||||
padding = False
|
||||
else:
|
||||
raise ValueError(
|
||||
" `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
|
||||
args[0]
|
||||
)
|
||||
)
|
||||
|
||||
with self.device_placement():
|
||||
inputs = self._parse_and_tokenize(*args, padding=padding)
|
||||
|
||||
if self.framework == "pt":
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
input_length = inputs["input_ids"].shape[-1]
|
||||
|
||||
elif self.framework == "tf":
|
||||
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
||||
|
||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||
if input_length > 0.9 * max_length:
|
||||
logger.warning(
|
||||
"Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
|
||||
input_length, max_length
|
||||
)
|
||||
)
|
||||
|
||||
translations = self.model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
**generate_kwargs,
|
||||
)
|
||||
results = []
|
||||
for translation in translations:
|
||||
record = {}
|
||||
if return_tensors:
|
||||
record["translation_token_ids"] = translation
|
||||
if return_text:
|
||||
record["translation_text"] = self.tokenizer.decode(
|
||||
translation,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
results.append(record)
|
||||
return results
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class Text2TextGenerationPipeline(Pipeline):
|
||||
"""
|
||||
@@ -267,6 +31,9 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
text2text_generator("question: What is 42 ? context: 42 is the answer to life, the universe and everything")
|
||||
"""
|
||||
|
||||
# Used in the return key of the pipeline.
|
||||
return_name = "generated"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -276,6 +43,12 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
)
|
||||
|
||||
def check_inputs(self, input_length: int, min_length: int, max_length: int):
|
||||
"""
|
||||
Checks wether there might be something wrong with given input with regard to the model.
|
||||
"""
|
||||
return True
|
||||
|
||||
def __call__(
|
||||
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||
):
|
||||
@@ -304,26 +77,39 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
"""
|
||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||
|
||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||
if isinstance(args[0], list):
|
||||
assert (
|
||||
self.tokenizer.pad_token_id is not None
|
||||
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
||||
args = ([prefix + arg for arg in args[0]],)
|
||||
padding = True
|
||||
|
||||
elif isinstance(args[0], str):
|
||||
args = (prefix + args[0],)
|
||||
padding = False
|
||||
else:
|
||||
raise ValueError(
|
||||
" `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
|
||||
" `args[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
|
||||
args[0]
|
||||
)
|
||||
)
|
||||
|
||||
with self.device_placement():
|
||||
inputs = self._parse_and_tokenize(*args, padding=padding)
|
||||
inputs = self._parse_and_tokenize(*args, padding=padding, **generate_kwargs)
|
||||
|
||||
if self.framework == "pt":
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
input_length = inputs["input_ids"].shape[-1]
|
||||
elif self.framework == "tf":
|
||||
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
||||
|
||||
min_length = generate_kwargs.get("min_length", self.model.config.min_length)
|
||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||
self.check_inputs(input_length, min_length, max_length)
|
||||
|
||||
# truncation should be used by _parse_and_tokenize
|
||||
generate_kwargs.pop("truncation", None)
|
||||
|
||||
generations = self.model.generate(
|
||||
inputs["input_ids"],
|
||||
@@ -334,12 +120,139 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
for generation in generations:
|
||||
record = {}
|
||||
if return_tensors:
|
||||
record["generated_token_ids"] = generation
|
||||
record[f"{self.return_name}_token_ids"] = generation
|
||||
if return_text:
|
||||
record["generated_text"] = self.tokenizer.decode(
|
||||
record[f"{self.return_name}_text"] = self.tokenizer.decode(
|
||||
generation,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
results.append(record)
|
||||
return results
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class SummarizationPipeline(Text2TextGenerationPipeline):
|
||||
"""
|
||||
Summarize news articles and other documents.
|
||||
|
||||
This summarizing pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
|
||||
identifier: :obj:`"summarization"`.
|
||||
|
||||
The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
|
||||
currently, '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`', '`t5-11b`'. See the up-to-date
|
||||
list of available models on `huggingface.co/models <https://huggingface.co/models?filter=summarization>`__.
|
||||
|
||||
Usage::
|
||||
|
||||
# use bart in pytorch
|
||||
summarizer = pipeline("summarization")
|
||||
summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)
|
||||
|
||||
# use t5 in tf
|
||||
summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
|
||||
summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)
|
||||
"""
|
||||
|
||||
# Used in the return key of the pipeline.
|
||||
return_name = "summary"
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
r"""
|
||||
Summarize the text(s) given as inputs.
|
||||
|
||||
Args:
|
||||
documents (`str` or :obj:`List[str]`):
|
||||
One or several articles (or one list of articles) to summarize.
|
||||
return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to include the decoded texts in the outputs
|
||||
return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to include the tensors of predictions (as token indices) in the outputs.
|
||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
generate_kwargs:
|
||||
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
||||
corresponding to your framework `here <./model.html#generative-models>`__).
|
||||
|
||||
Return:
|
||||
A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:
|
||||
|
||||
- **summary_text** (:obj:`str`, present when ``return_text=True``) -- The summary of the corresponding
|
||||
input.
|
||||
- **summary_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``) --
|
||||
The token ids of the summary.
|
||||
"""
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
def check_inputs(self, input_length: int, min_length: int, max_length: int) -> bool:
|
||||
"""
|
||||
Checks wether there might be something wrong with given input with regard to the model.
|
||||
"""
|
||||
if input_length < min_length // 2:
|
||||
logger.warning(
|
||||
"Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
|
||||
min_length, input_length
|
||||
)
|
||||
)
|
||||
|
||||
if input_length < max_length:
|
||||
logger.warning(
|
||||
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
|
||||
max_length, input_length
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class TranslationPipeline(Text2TextGenerationPipeline):
|
||||
"""
|
||||
Translates from one language to another.
|
||||
|
||||
This translation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
|
||||
identifier: :obj:`"translation_xx_to_yy"`.
|
||||
|
||||
The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
|
||||
up-to-date list of available models on `huggingface.co/models
|
||||
<https://huggingface.co/models?filter=translation>`__.
|
||||
|
||||
Usage::
|
||||
en_fr_translator = pipeline("translation_en_to_fr")
|
||||
en_fr_translator("How old are you?")
|
||||
"""
|
||||
|
||||
# Used in the return key of the pipeline.
|
||||
return_name = "translation"
|
||||
|
||||
def check_inputs(self, input_length: int, min_length: int, max_length: int):
|
||||
if input_length > 0.9 * max_length:
|
||||
logger.warning(
|
||||
"Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
|
||||
input_length, max_length
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
r"""
|
||||
Translate the text(s) given as inputs.
|
||||
|
||||
Args:
|
||||
args (:obj:`str` or :obj:`List[str]`):
|
||||
Texts to be translated.
|
||||
return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to include the tensors of predictions (as token indices) in the outputs.
|
||||
return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to include the decoded texts in the outputs.
|
||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
generate_kwargs:
|
||||
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
||||
corresponding to your framework `here <./model.html#generative-models>`__).
|
||||
|
||||
Return:
|
||||
A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:
|
||||
|
||||
- **translation_text** (:obj:`str`, present when ``return_text=True``) -- The translation.
|
||||
- **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
||||
-- The token ids of the translation.
|
||||
"""
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user