@@ -139,6 +139,29 @@ one for summarization with beam search). You must have the right Hub permissions
|
|||||||
['Les fichiers de configuration sont faciles à utiliser !']
|
['Les fichiers de configuration sont faciles à utiliser !']
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Streaming
|
||||||
|
|
||||||
|
The `generate()` method supports streaming, through its `streamer` input. The `streamer` input is compatible with any instance
|
||||||
|
from a class that has the following methods: `put()` and `end()`. Internally, `put()` is used to push new tokens and
|
||||||
|
`end()` is used to flag the end of text generation.
|
||||||
|
|
||||||
|
In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes
|
||||||
|
ready for you to use. For example, you can use the [`TextStreamer`] class to stream the output of `generate()` into
|
||||||
|
your screen, one word at a time:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
||||||
|
|
||||||
|
>>> tok = AutoTokenizer.from_pretrained("gpt2")
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||||
|
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
|
||||||
|
>>> streamer = TextStreamer(tok)
|
||||||
|
|
||||||
|
>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
|
||||||
|
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
|
||||||
|
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
|
||||||
|
```
|
||||||
|
|
||||||
## Decoding strategies
|
## Decoding strategies
|
||||||
|
|
||||||
Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
|
Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
|
||||||
|
|||||||
@@ -265,3 +265,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
|||||||
[[autodoc]] top_k_top_p_filtering
|
[[autodoc]] top_k_top_p_filtering
|
||||||
|
|
||||||
[[autodoc]] tf_top_k_top_p_filtering
|
[[autodoc]] tf_top_k_top_p_filtering
|
||||||
|
|
||||||
|
## Streamers
|
||||||
|
|
||||||
|
[[autodoc]] TextStreamer
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ of the generation method.
|
|||||||
|
|
||||||
To learn how to inspect a model's generation configuration, what are the defaults, how to change the parameters ad hoc,
|
To learn how to inspect a model's generation configuration, what are the defaults, how to change the parameters ad hoc,
|
||||||
and how to create and save a customized generation configuration, refer to the
|
and how to create and save a customized generation configuration, refer to the
|
||||||
[text generation strategies guide](../generation_strategies).
|
[text generation strategies guide](../generation_strategies). The guide also explains how to use related features,
|
||||||
|
like token streaming.
|
||||||
|
|
||||||
## GenerationConfig
|
## GenerationConfig
|
||||||
|
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ _import_structure = {
|
|||||||
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
|
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
|
||||||
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
|
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
|
||||||
"file_utils": [],
|
"file_utils": [],
|
||||||
"generation": ["GenerationConfig"],
|
"generation": ["GenerationConfig", "TextStreamer"],
|
||||||
"hf_argparser": ["HfArgumentParser"],
|
"hf_argparser": ["HfArgumentParser"],
|
||||||
"image_transforms": [],
|
"image_transforms": [],
|
||||||
"integrations": [
|
"integrations": [
|
||||||
@@ -3769,7 +3769,7 @@ if TYPE_CHECKING:
|
|||||||
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||||
|
|
||||||
# Generation
|
# Generation
|
||||||
from .generation import GenerationConfig
|
from .generation import GenerationConfig, TextStreamer
|
||||||
from .hf_argparser import HfArgumentParser
|
from .hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
# Integrations
|
# Integrations
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ from typing import TYPE_CHECKING
|
|||||||
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
|
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_utils": ["GenerationConfig"]}
|
_import_structure = {"configuration_utils": ["GenerationConfig"], "streamers": ["TextStreamer"]}
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
@@ -150,6 +149,7 @@ else:
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_utils import GenerationConfig
|
from .configuration_utils import GenerationConfig
|
||||||
|
from .streamers import TextStreamer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
|
|||||||
104
src/transformers/generation/streamers.py
Normal file
104
src/transformers/generation/streamers.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models.auto import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.

    Subclasses are expected to override both `put()` and `end()`; the implementations in this class only raise
    `NotImplementedError`.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()


class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to ignore the first `put()` call received after construction or after `end()`. `.generate()`
            pushes the prompt tokens in that first call, so enabling this flag prints only the newly generated text.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments forwarded to the tokenizer's `decode` method (for instance
            `skip_special_tokens=True`).

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # Token ids received since the last flush, and how many characters of their decoded text were printed.
        self.token_cache = []
        self.print_len = 0
        # True until the first `put()` of a generation has been seen; reset by `end()`.
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.

        Args:
            value: Array-like of token ids exposing `.shape` and `.tolist()`. Either one-dimensional, or
                two-dimensional with a single row.

        Raises:
            ValueError: If `value` is two-dimensional with more than one row (batch size > 1).
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # The first call of a generation carries the prompt: drop it when the user asked for that.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        self.next_tokens_are_prompt = False

        # Add the new tokens to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # Otherwise, print until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        print(printable_text, flush=True, end="")

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # The next `put()` call belongs to a new generation, whose first tokens are again the prompt.
        self.next_tokens_are_prompt = True

        # Print a newline (and the remaining text, if any)
        print(printable_text, flush=True)
|
||||||
@@ -18,7 +18,7 @@ import copy
|
|||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -72,6 +72,10 @@ from .stopping_criteria import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .streamers import BaseStreamer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -1116,6 +1120,7 @@ class GenerationMixin:
|
|||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
synced_gpus: Optional[bool] = None,
|
synced_gpus: Optional[bool] = None,
|
||||||
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1165,6 +1170,9 @@ class GenerationMixin:
|
|||||||
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
|
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
|
||||||
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
|
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
|
||||||
generating before other GPUs. Otherwise it'll be set to `False`.
|
generating before other GPUs. Otherwise it'll be set to `False`.
|
||||||
|
streamer (`BaseStreamer`, *optional*):
|
||||||
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||||
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||||
|
|
||||||
kwargs:
|
kwargs:
|
||||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||||
@@ -1295,6 +1303,9 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||||
|
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(input_ids.cpu())
|
||||||
|
|
||||||
# 6. Prepare `max_length` depending on other stopping criteria.
|
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||||
input_ids_seq_length = input_ids.shape[-1]
|
input_ids_seq_length = input_ids.shape[-1]
|
||||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||||
@@ -1335,7 +1346,8 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
is_contrastive_search_gen_mode = (
|
is_contrastive_search_gen_mode = (
|
||||||
generation_config.top_k is not None
|
(generation_config.num_beams == 1)
|
||||||
|
and generation_config.top_k is not None
|
||||||
and generation_config.top_k > 1
|
and generation_config.top_k > 1
|
||||||
and generation_config.do_sample is False
|
and generation_config.do_sample is False
|
||||||
and generation_config.penalty_alpha is not None
|
and generation_config.penalty_alpha is not None
|
||||||
@@ -1384,6 +1396,11 @@ class GenerationMixin:
|
|||||||
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
|
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if streamer is not None and (generation_config.num_beams > 1):
|
||||||
|
raise ValueError(
|
||||||
|
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
|
||||||
|
)
|
||||||
|
|
||||||
if self.device.type != input_ids.device.type:
|
if self.device.type != input_ids.device.type:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"You are calling .generate() with the `input_ids` being on a device type different"
|
"You are calling .generate() with the `input_ids` being on a device type different"
|
||||||
@@ -1426,6 +1443,7 @@ class GenerationMixin:
|
|||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
|
streamer=streamer,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1447,6 +1465,7 @@ class GenerationMixin:
|
|||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
|
streamer=streamer,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1473,6 +1492,7 @@ class GenerationMixin:
|
|||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
|
streamer=streamer,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1703,6 +1723,7 @@ class GenerationMixin:
|
|||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: Optional[bool] = False,
|
||||||
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
|
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1750,6 +1771,9 @@ class GenerationMixin:
|
|||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||||
|
streamer (`BaseStreamer`, *optional*):
|
||||||
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||||
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
||||||
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||||
@@ -2010,6 +2034,8 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# update generated ids, model inputs, and length for next step
|
# update generated ids, model inputs, and length for next step
|
||||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(next_tokens.cpu())
|
||||||
model_kwargs = self._update_model_kwargs_for_generation(
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
@@ -2027,6 +2053,9 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.end()
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return ContrastiveSearchEncoderDecoderOutput(
|
return ContrastiveSearchEncoderDecoderOutput(
|
||||||
@@ -2061,6 +2090,7 @@ class GenerationMixin:
|
|||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: Optional[bool] = False,
|
||||||
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[GreedySearchOutput, torch.LongTensor]:
|
) -> Union[GreedySearchOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -2105,6 +2135,9 @@ class GenerationMixin:
|
|||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||||
|
streamer (`BaseStreamer`, *optional*):
|
||||||
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||||
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
||||||
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||||
@@ -2256,6 +2289,8 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# update generated ids, model inputs, and length for next step
|
# update generated ids, model inputs, and length for next step
|
||||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(next_tokens.cpu())
|
||||||
model_kwargs = self._update_model_kwargs_for_generation(
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
@@ -2273,6 +2308,9 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.end()
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return GreedySearchEncoderDecoderOutput(
|
return GreedySearchEncoderDecoderOutput(
|
||||||
@@ -2308,6 +2346,7 @@ class GenerationMixin:
|
|||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: Optional[bool] = False,
|
||||||
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[SampleOutput, torch.LongTensor]:
|
) -> Union[SampleOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -2354,6 +2393,9 @@ class GenerationMixin:
|
|||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||||
|
streamer (`BaseStreamer`, *optional*):
|
||||||
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||||
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
||||||
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||||
@@ -2525,6 +2567,8 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# update generated ids, model inputs, and length for next step
|
# update generated ids, model inputs, and length for next step
|
||||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(next_tokens.cpu())
|
||||||
model_kwargs = self._update_model_kwargs_for_generation(
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
@@ -2542,6 +2586,9 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.end()
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return SampleEncoderDecoderOutput(
|
return SampleEncoderDecoderOutput(
|
||||||
|
|||||||
44
tests/generation/test_streamers.py
Normal file
44
tests/generation/test_streamers.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 The HuggingFace Team Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, TextStreamer, is_torch_available
|
||||||
|
from transformers.testing_utils import CaptureStdout, require_torch, torch_device
|
||||||
|
|
||||||
|
from ..test_modeling_common import ids_tensor
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
class StreamerTester(unittest.TestCase):
    """Tests for the streamer classes accepted by `generate()` through its `streamer` argument."""

    def test_text_streamer_stdout(self):
        """`TextStreamer` must print to stdout the same text that a plain greedy `generate()` call decodes to."""
        checkpoint = "hf-internal-testing/tiny-random-gpt2"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
        # NOTE(review): -1 is presumably an id that can never be sampled, so that EOS never cuts generation short.
        model.config.eos_token_id = -1

        prompt_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)

        # Reference: greedy generation without a streamer, decoded in one go.
        reference_ids = model.generate(prompt_ids, max_new_tokens=10, do_sample=False)
        reference_text = tokenizer.decode(reference_ids[0])

        # Same generation again, this time capturing what the streamer writes to stdout.
        with CaptureStdout() as captured:
            text_streamer = TextStreamer(tokenizer)
            model.generate(prompt_ids, max_new_tokens=10, do_sample=False, streamer=text_streamer)

        # The greedy text should be printed to stdout, except for the final "\n" in the streamer
        streamed_text = captured.out[:-1]
        self.assertEqual(streamed_text, reference_text)
|
||||||
Reference in New Issue
Block a user