* first draft * show design proposition for new generate method * up * make better readable * make first version * gpt2 tests pass * make beam search for gpt2 work * add first encoder-decoder code * delete typo * make t5 work * save indermediate * make bart work with beam search * finish beam search bart / t5 * add default kwargs * make more tests pass * fix no bad words sampler * some fixes and tests for all distribution processors * fix test * fix rag slow tests * merge to master * add nograd to generate * make all slow tests pass * speed up generate * fix edge case bug * small fix * correct typo * add type hints and docstrings * fix typos in tests * add beam search tests * add tests for beam scorer * fix test rag * finish beam search tests * move generation tests in seperate file * fix generation tests * more tests * add aggressive generation tests * fix tests * add gpt2 sample test * add more docstring * add more docs * finish doc strings * apply some more of sylvains and sams comments * fix some typos * make fix copies * apply lysandres and sylvains comments * final corrections on examples * small fix for reformer
1216 lines
57 KiB
Python
1216 lines
57 KiB
Python
# coding=utf-8
|
|
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
|
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
from .file_utils import ModelOutput
|
|
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
|
from .generation_logits_process import (
|
|
LogitsProcessorList,
|
|
MinLengthLogitsProcessor,
|
|
NoBadWordsLogitsProcessor,
|
|
NoRepeatNGramLogitsProcessor,
|
|
RepetitionPenaltyLogitsProcessor,
|
|
TemperatureLogitsWarper,
|
|
TopKLogitsWarper,
|
|
TopPLogitsWarper,
|
|
)
|
|
from .utils import logging
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class GenerationMixin:
|
|
"""
|
|
A class containing all of the functions supporting generation, to be used as a mixin in
|
|
:class:`~transformers.PreTrainedModel`.
|
|
"""
|
|
|
|
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
    """
    Build the keyword arguments handed to the model's forward call at a generation step.

    This default implementation forwards only :obj:`input_ids` and ignores every additional keyword argument.
    Override it in subclasses of :class:`~transformers.PreTrainedModel` to customise how inputs are prepared in
    the generate method.
    """
    model_inputs = dict(input_ids=input_ids)
    return model_inputs
|
|
|
|
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
    """
    Hook for model-specific adjustments of the logits in the generate method.

    This default implementation returns :obj:`logits` untouched and ignores every additional keyword argument.
    Override it in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior.
    """
    adjusted_logits = logits
    return adjusted_logits
|
|
|
|
def _prepare_input_ids_for_generation(self, bos_token_id: int) -> torch.LongTensor:
|
|
if bos_token_id is None:
|
|
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
|
|
return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
|
|
|
|
def _prepare_attention_mask_for_generation(
|
|
self, input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int
|
|
) -> torch.LongTensor:
|
|
is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids)
|
|
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
|
|
(eos_token_id is not None) and (pad_token_id != eos_token_id)
|
|
)
|
|
if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
|
|
return input_ids.ne(pad_token_id).long()
|
|
return input_ids.new_ones(input_ids.shape)
|
|
|
|
def _prepare_encoder_decoder_kwargs_for_generation(
|
|
self, input_ids: torch.LongTensor, model_kwargs
|
|
) -> Dict[str, Any]:
|
|
# retrieve encoder hidden states
|
|
encoder = self.get_encoder()
|
|
encoder_kwargs = {
|
|
argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
|
|
}
|
|
model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
|
|
return model_kwargs
|
|
|
|
def _prepare_decoder_input_ids_for_generation(
|
|
self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None, **model_kwargs
|
|
) -> torch.LongTensor:
|
|
|
|
if "decoder_input_ids" in model_kwargs:
|
|
return model_kwargs["decoder_input_ids"]
|
|
|
|
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
|
decoder_input_ids = (
|
|
torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
|
|
* decoder_start_token_id
|
|
)
|
|
return decoder_input_ids
|
|
|
|
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
|
|
if pad_token_id is None and eos_token_id is not None:
|
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
|
pad_token_id = eos_token_id
|
|
return pad_token_id
|
|
|
|
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
|
decoder_start_token_id = (
|
|
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
|
)
|
|
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
|
|
|
if decoder_start_token_id is not None:
|
|
return decoder_start_token_id
|
|
elif (
|
|
hasattr(self.config, "decoder")
|
|
and hasattr(self.config.decoder, "decoder_start_token_id")
|
|
and self.config.decoder.decoder_start_token_id is not None
|
|
):
|
|
return self.config.decoder.decoder_start_token_id
|
|
elif bos_token_id is not None:
|
|
return bos_token_id
|
|
elif (
|
|
hasattr(self.config, "decoder")
|
|
and hasattr(self.config.decoder, "bos_token_id")
|
|
and self.config.decoder.bos_token_id is not None
|
|
):
|
|
return self.config.decoder.bos_token_id
|
|
raise ValueError(
|
|
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
|
)
|
|
|
|
@staticmethod
def _expand_inputs_for_generation(
    input_ids: torch.LongTensor,
    expand_size: int = 1,
    is_encoder_decoder: bool = False,
    attention_mask: torch.LongTensor = None,
    encoder_outputs: ModelOutput = None,
    **model_kwargs
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
    """
    Repeat every row of the generation inputs :obj:`expand_size` times along the first dimension.

    Copies of the same row are adjacent in the result (row order ``0, 0, ..., 1, 1, ...``). The same expansion is
    applied to :obj:`attention_mask` (when given) and, if :obj:`is_encoder_decoder` is set, to
    ``encoder_outputs.last_hidden_state``, which is written back into :obj:`encoder_outputs` in place.

    Returns the expanded :obj:`input_ids` and the updated :obj:`model_kwargs`. ``attention_mask`` is only part of
    the returned kwargs when it was not :obj:`None`, and ``encoder_outputs`` only when :obj:`is_encoder_decoder`
    is set.
    """
    num_rows = input_ids.shape[0]
    row_ids = torch.arange(num_rows).view(-1, 1)
    gather_idx = row_ids.repeat(1, expand_size).view(-1).to(input_ids.device)

    def _repeat_rows(tensor):
        return tensor.index_select(0, gather_idx)

    input_ids = _repeat_rows(input_ids)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = _repeat_rows(attention_mask)

    if is_encoder_decoder:
        assert encoder_outputs is not None
        encoder_outputs["last_hidden_state"] = _repeat_rows(encoder_outputs.last_hidden_state)
        model_kwargs["encoder_outputs"] = encoder_outputs

    return input_ids, model_kwargs
|
|
|
|
@staticmethod
|
|
def _init_sequence_length_for_generation(
|
|
input_ids: torch.LongTensor, max_length: int
|
|
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
|
sequence_lengths = input_ids.new(input_ids.shape[0]).fill_(max_length)
|
|
|
|
cur_len = input_ids.shape[-1]
|
|
return sequence_lengths, unfinished_sequences, cur_len
|
|
|
|
@staticmethod
|
|
def _update_seq_length_for_generation(
|
|
sequence_lengths: torch.LongTensor,
|
|
unfinished_sequences: torch.LongTensor,
|
|
cur_len: int,
|
|
is_eos_in_next_token: torch.BoolTensor,
|
|
) -> Tuple[torch.LongTensor, torch.LongTensor]:
|
|
# check if sentence is not finished yet
|
|
is_sent_unfinished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
|
|
|
|
# update sentence length
|
|
sequence_lengths = sequence_lengths.masked_fill(is_sent_unfinished, cur_len)
|
|
unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())
|
|
return sequence_lengths, unfinished_sequences
|
|
|
|
@staticmethod
def _update_model_kwargs_for_generation(
    outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
    """
    Carry the state produced by one forward pass over to the next generation step.

    ``model_kwargs["past"]`` is set to the first of ``past_key_values``, ``mems`` or ``past_buckets_states``
    that is contained in :obj:`outputs`, or to :obj:`None` if none of them is. Unless :obj:`is_encoder_decoder`
    is set, an ``attention_mask`` present in :obj:`model_kwargs` is extended by one column of ones.
    :obj:`model_kwargs` is updated in place and also returned.
    """
    # the cache is looked up under several output names; the first one found wins
    for cache_name in ("past_key_values", "mems", "past_buckets_states"):
        if cache_name in outputs:
            model_kwargs["past"] = getattr(outputs, cache_name)
            break
    else:
        model_kwargs["past"] = None

    if not is_encoder_decoder and "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        extra_column = mask.new_ones((mask.shape[0], 1))
        model_kwargs["attention_mask"] = torch.cat([mask, extra_column], dim=-1)

    return model_kwargs
|
|
|
|
@staticmethod
|
|
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
|
|
"""
|
|
This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
|
|
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
|
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
|
|
generation step.
|
|
|
|
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
|
|
subclasses of :class:`~transformers.PreTrainedModel`.
|
|
"""
|
|
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
|
|
|
|
def _get_logits_warper(
    self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
) -> LogitsProcessorList:
    """
    This method returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
    :obj:`~transformers.LogitsWarper` instances used for multinomial sampling.

    Every argument left as :obj:`None` falls back to the attribute of the same name on ``self.config``. Warpers
    are appended in the order top-k, top-p, temperature, and only when the corresponding value is active
    (``top_k`` different from 0, ``top_p`` below 1.0, ``temperature`` different from 1.0). With more than one
    beam the top-k and top-p warpers keep at least two tokens, otherwise at least one.
    """

    # init warp parameters
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    temperature = temperature if temperature is not None else self.config.temperature
    # `num_beams` defaults to `None` in the signature. Comparing `None > 1` raised a `TypeError` as soon as
    # top-k or top-p filtering was active, so resolve it like the other parameters and stay `None`-safe.
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    uses_multiple_beams = num_beams is not None and num_beams > 1
    min_tokens_to_keep = 2 if uses_multiple_beams else 1

    # instantiate warpers list
    warpers = LogitsProcessorList()

    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    # the warper classes used here are imported from `generation_logits_process`
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=min_tokens_to_keep))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep))
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    return warpers
|
|
|
|
def _get_logits_processor(
    self,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    bad_words_ids: List[List[int]],
    min_length: int,
    eos_token_id: int,
) -> LogitsProcessorList:
    """
    Assemble the :obj:`~transformers.LogitsProcessorList` holding every :obj:`~transformers.LogitsProcessor`
    that should modify the scores of the language model head.

    Every argument passed as :obj:`None` falls back to the attribute of the same name on ``self.config``.
    Processors are appended in the order repetition penalty, n-gram blocking, bad words, minimum length, and only
    when the corresponding setting is active.
    """
    # resolve unset parameters from the config
    if repetition_penalty is None:
        repetition_penalty = self.config.repetition_penalty
    if no_repeat_ngram_size is None:
        no_repeat_ngram_size = self.config.no_repeat_ngram_size
    if bad_words_ids is None:
        bad_words_ids = self.config.bad_words_ids
    if min_length is None:
        min_length = self.config.min_length
    if eos_token_id is None:
        eos_token_id = self.config.eos_token_id

    processors = LogitsProcessorList()

    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    # all samplers can be found in `generation_utils_samplers.py`
    penalizes_repetition = repetition_penalty is not None and repetition_penalty != 1.0
    if penalizes_repetition:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

    blocks_ngrams = no_repeat_ngram_size is not None and no_repeat_ngram_size > 0
    if blocks_ngrams:
        processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))

    if bad_words_ids is not None:
        processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))

    # a minimum length can only be enforced when there is an EOS token to act on
    enforces_min_length = min_length is not None and eos_token_id is not None and min_length > -1
    if enforces_min_length:
        processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))

    return processors
|
|
|
|
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[List[List[int]]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    **model_kwargs
) -> torch.LongTensor:
    r"""
    Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
    multinomial sampling, beam-search decoding, and beam-search multinomial sampling.

    Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
    attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
    indicated are the default values of that config.

    Most of these parameters are explained in more detail in `this blog post
    <https://huggingface.co/blog/how-to-generate>`__.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as a
            :obj:`torch.LongTensor` of shape :obj:`(1, 1)` containing :obj:`bos_token_id`.
        max_length (:obj:`int`, `optional`, defaults to 20):
            The maximum length of the sequence to be generated.
        min_length (:obj:`int`, `optional`, defaults to 10):
            The minimum length of the sequence to be generated.
        do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use sampling ; use greedy decoding otherwise.
        early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
        num_beams (:obj:`int`, `optional`, defaults to 1):
            Number of beams for beam search. 1 means no beam search.
        temperature (:obj:`float`, `optional`, defaults to 1.0):
            The value used to modulate the next token probabilities.
        top_k (:obj:`int`, `optional`, defaults to 50):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        top_p (:obj:`float`, `optional`, defaults to 1.0):
            If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
            higher are kept for generation.
        repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token.
        bos_token_id (:obj:`int`, `optional`):
            The id of the `beginning-of-sequence` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.
        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
            sequences.
        no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
            If set to int > 0, all ngrams of that size can only occur once.
        bad_words_ids(:obj:`List[List[int]]`, `optional`):
            List of token ids that are not allowed to be generated. In order to get the tokens of the words that
            should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        num_return_sequences(:obj:`int`, `optional`, defaults to 1):
            The number of independently computed returned sequences for each element in the batch.
        attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
            tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same
            shape as :obj:`input_ids` that masks the pad token. `What are attention masks?
            <../glossary.html#attention-mask>`__
        decoder_start_token_id (:obj:`int`, `optional`):
            If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
        use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should use the past last key/values attentions (if applicable to the model) to
            speed up decoding.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
            model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
            kwargs should be prefixed with `decoder_`.

    Return:
        :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
        sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
        batches finished early due to the :obj:`eos_token_id`.

    Raises:
        ValueError: if :obj:`num_return_sequences` is incompatible with the selected decoding mode, if the encoder
            outputs of an encoder-decoder model are missing or not a :obj:`ModelOutput`, or if no decoding mode
            matches :obj:`num_beams` and :obj:`do_sample`.

    Examples::

        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> # do greedy decoding without providing a prompt
        >>> outputs = model.generate(max_length=40)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> document = (
        ... "at least two people were killed in a suspected bomb attack on a passenger bus "
        ... "in the strife-torn southern philippines on monday , the military said."
        ... )
        >>> # encode input context
        >>> input_ids = tokenizer(document, return_tensors="pt").input_ids
        >>> # generate 3 independent sequences using beam search decoding (5 beams)
        >>> # with T5 encoder-decoder model conditioned on short news article.
        >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> input_context = "The dog"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate 3 candidates using sampling
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
        >>> model = AutoModelForCausalLM.from_pretrained("ctrl")
        >>> # "Legal" is one of the control codes for ctrl
        >>> input_context = "Legal My neighbor is"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> input_context = "My cute dog"
        >>> # get tokens of words that should not be generated
        >>> bad_words_ids = [tokenizer(bad_word, add_prefix_space=True).input_ids for bad_word in ["idiot", "stupid", "shut up"]]
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate sequences without allowing bad_words to be generated
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
    """

    # set init values
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    max_length = max_length if max_length is not None else self.config.max_length
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
    )

    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    if input_ids is None:
        # init `input_ids` with bos_token_id
        input_ids = self._prepare_input_ids_for_generation(bos_token_id)

    if model_kwargs.get("attention_mask", None) is None:
        # init `attention_mask` depending on `pad_token_id`; this has to happen before the pad token falls
        # back to the EOS token below, otherwise EOS tokens of the prompt could not be told apart from padding
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            input_ids, pad_token_id, eos_token_id
        )

    # special case if pad_token_id is not defined: fall back to `eos_token_id` (logs a warning)
    pad_token_id = self._get_pad_token_id(pad_token_id, eos_token_id)

    if self.config.is_encoder_decoder:
        # add encoder_outputs to model_kwargs
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)

        # set input_ids as decoder_input_ids
        input_ids = self._prepare_decoder_input_ids_for_generation(
            input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs
        )

        if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
            raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")

    # determine generation mode
    is_greedy_gen_mode = (num_beams == 1) and do_sample is False
    is_sample_gen_mode = (num_beams == 1) and do_sample is True
    is_beam_gen_mode = (num_beams > 1) and do_sample is False
    is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True

    # set model_kwargs
    model_kwargs["use_cache"] = use_cache

    # get distribution pre_processing samplers
    logits_processor = self._get_logits_processor(
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        bad_words_ids=bad_words_ids,
        min_length=min_length,
        eos_token_id=eos_token_id,
    )

    if is_greedy_gen_mode:
        if num_return_sequences > 1:
            raise ValueError(
                f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
            )

        # greedy search
        return self.greedy_search(
            input_ids,
            logits_processor=logits_processor,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    elif is_sample_gen_mode:
        # get probability distribution warper
        logits_warper = self._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
        )

        # expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids,
            expand_size=num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # sample
        return self.sample(
            input_ids,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    elif is_beam_gen_mode:
        batch_size = input_ids.shape[0]

        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        if num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
            num_beam_hyps_to_keep=num_return_sequences,
        )
        # interleave with `num_beams`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
        )
        return self.beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    elif is_beam_sample_gen_mode:
        logits_warper = self._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
        )

        # every returned sequence gets its own, independently sampled group of beams
        batch_size = input_ids.shape[0] * num_return_sequences

        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        # fall back to the config exactly like the beam search branch does; without this an unset
        # `early_stopping` reached the scorer as `None` and the config value was silently ignored
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
        )

        # interleave with `num_beams * num_return_sequences`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids,
            expand_size=num_beams * num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        return self.beam_sample(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    # none of the four modes matched (e.g. `num_beams` < 1 or `do_sample` is not a real bool): fail loudly
    # instead of implicitly returning `None`
    raise ValueError(
        f"Could not determine a generation mode for `num_beams`={num_beams} and `do_sample`={do_sample}. "
        "`num_beams` has to be an integer >= 1 and `do_sample` has to be `True` or `False`."
    )
|
|
|
|
def greedy_search(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    **model_kwargs
):
    r"""
    Generates sequences for models with a language modeling head using greedy decoding: at every step the token
    with the highest processed score is appended to each sequence.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step. An empty list is used if :obj:`None`.
        max_length (:obj:`int`, `optional`):
            The maximum length of the sequence to be generated. Falls back to ``self.config.max_length``.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. Falls back to ``self.config.pad_token_id``. Has to be defined
            whenever an `end-of-sequence` token id is defined.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token. Falls back to ``self.config.eos_token_id``.
        model_kwargs:
            Additional model specific keyword arguments will be forwarded to the :obj:`forward` function of the
            model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        :obj:`torch.LongTensor`: :obj:`input_ids` extended by the generated tokens. The second dimension
        (sequence_length) is either equal to :obj:`max_length` or shorter if all sequences finished early due to
        the :obj:`eos_token_id`. Sequences that finished early are filled up with :obj:`pad_token_id`.

    Examples::

        >>> from transformers import (
        ... AutoTokenizer,
        ... AutoModelForCausalLM,
        ... LogitsProcessorList,
        ... MinLengthLogitsProcessor,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id so that finished sequences can be padded
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ... ])

        >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """

    # resolve unset arguments
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    if max_length is None:
        max_length = self.config.max_length
    if pad_token_id is None:
        pad_token_id = self.config.pad_token_id
    if eos_token_id is None:
        eos_token_id = self.config.eos_token_id

    # `eos_token_id` does not change inside the loop, so the check is done once
    has_eos_token = eos_token_id is not None

    # per-sequence bookkeeping: final lengths and "still generating" flags
    sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
        input_ids, max_length
    )

    while cur_len < max_length:
        # forward pass; only the scores of the last position are needed to pick the next token
        step_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(**step_inputs, return_dict=True)
        last_position_logits = outputs.logits[:, -1, :]

        # let the processors modify the scores, then take the best-scoring token per sequence
        processed_scores = logits_processor(input_ids, last_position_logits)
        next_tokens = torch.argmax(processed_scores, dim=-1)

        if has_eos_token:
            assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
            # sequences that already finished receive the pad token instead of the predicted one
            finished_sequences = 1 - unfinished_sequences
            next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * finished_sequences

        # append the chosen tokens as a new last column
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if has_eos_token:
            # record the length of sequences that just produced EOS and flag them as finished
            sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
                sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
            )

        # carry cache / attention mask over to the next step
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # stop as soon as every sequence is finished; the loop condition handles the maximum length
        if unfinished_sequences.max() == 0:
            break

        cur_len += 1

    return input_ids
|
|
|
|
def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    **model_kwargs
):
    r"""
    Generates sequences for models with a language modeling head using multinomial sampling.

    At every step the logits of the last position are passed through :obj:`logits_processor`, then through
    :obj:`logits_warper`, turned into probabilities with a softmax and one token per sequence is drawn from them.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step. An empty list is used if not provided.
        logits_warper (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
            modeling head applied before multinomial sampling at each generation step. An empty list is used if
            not provided.
        max_length (:obj:`int`, `optional`):
            The maximum length of the sequence to be generated. Falls back to :obj:`self.config.max_length`.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. Falls back to :obj:`self.config.pad_token_id`. Must be defined
            whenever an :obj:`eos_token_id` is in effect.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token. Falls back to :obj:`self.config.eos_token_id`.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
            model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        :obj:`torch.LongTensor`: :obj:`input_ids` extended along the last dimension by one token per generation
        step. Generation stops once the length reaches :obj:`max_length`, or earlier if every sequence has
        finished due to the :obj:`eos_token_id`.

    Examples::

        >>> from transformers import (
        ...    AutoTokenizer,
        ...    AutoModelForCausalLM,
        ...    LogitsProcessorList,
        ...    MinLengthLogitsProcessor,
        ...    TopKLogitsWarper,
        ...    TemperatureLogitsWarper,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ...     MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ... ])
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList([
        ...     TopKLogitsWarper(50),
        ...     TemperatureLogitsWarper(0.7),
        ... ])

        >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """

    # anything not passed explicitly falls back to an empty processor list or to the model config
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    if logits_warper is None:
        logits_warper = LogitsProcessorList()
    if max_length is None:
        max_length = self.config.max_length
    if pad_token_id is None:
        pad_token_id = self.config.pad_token_id
    if eos_token_id is None:
        eos_token_id = self.config.eos_token_id

    # per-sequence bookkeeping used to detect when every sequence has finished
    sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
        input_ids, max_length
    )

    # auto-regressive generation: one new token per sequence and iteration
    while cur_len < max_length:
        # run the model on the current sequences and keep only the logits of the last position
        step_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(**step_inputs, return_dict=True)
        last_position_logits = outputs.logits[:, -1, :]

        # processors first, warpers second
        processed_scores = logits_processor(input_ids, last_position_logits)
        warped_scores = logits_warper(input_ids, processed_scores)

        # draw one token per sequence from the resulting distribution
        token_probs = F.softmax(warped_scores, dim=-1)
        next_tokens = torch.multinomial(token_probs, num_samples=1).squeeze(1)

        # sequences flagged as finished (0 in `unfinished_sequences`) receive the pad token instead
        if eos_token_id is not None:
            assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
            sampled_part = next_tokens * unfinished_sequences
            padding_part = (pad_token_id) * (1 - unfinished_sequences)
            next_tokens = sampled_part + padding_part

        # append the new tokens and advance the length counter
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        cur_len = cur_len + 1

        # register the sequences that just produced the eos token
        if eos_token_id is not None:
            sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
                sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
            )

        # every sequence has hit eos: nothing left to generate
        if unfinished_sequences.max() == 0:
            break

        # carry model state (e.g. cached values) over to the next step
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

    return input_ids
|
|
|
|
def beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    **model_kwargs
):
    r"""
    Generates sequences for models with a language modeling head using beam search decoding.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            The sequence used as a prompt for the generation. The first dimension has to equal the number of
            beam hypotheses held by :obj:`beam_scorer` times :obj:`beam_scorer.num_beams`.
        beam_scorer (:obj:`BeamScorer`):
            A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
            constructed, stored and sorted during generation. For more information, the documentation of
            :class:`~transformers.BeamScorer` should be read.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step. An empty list is used if not provided.
        max_length (:obj:`int`, `optional`):
            The maximum length of the sequence to be generated. Falls back to :obj:`self.config.max_length`.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. Falls back to :obj:`self.config.pad_token_id`.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token. Falls back to :obj:`self.config.eos_token_id`.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
            model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        The value returned by :obj:`beam_scorer.finalize` for the final beams, i.e. the generated sequences.
        Generation stops once the length reaches :obj:`max_length` or as soon as :obj:`beam_scorer.is_done`.

    Raises:
        AssertionError: If the first dimension of :obj:`input_ids` is not :obj:`batch_size * num_beams`.

    Examples::

        >>> from transformers import (
        ...    AutoTokenizer,
        ...    AutoModelForSeq2SeqLM,
        ...    LogitsProcessorList,
        ...    MinLengthLogitsProcessor,
        ...    BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ... ])

        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """

    # init values: anything not passed explicitly falls back to an empty list or to the model config
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    # the scorer keeps one hypotheses container per batch entry
    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    # the message must be an f-string, otherwise the placeholders are printed literally
    assert (
        num_beams * batch_size == batch_beam_size
    ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

    # only the first beam of every batch entry starts with score 0; all others start at -1e9
    # (presumably because all beams hold the same prompt at this point, so only one should be expanded first)
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    # NOTE(review): if the prompt is already `max_length` long the loop body never runs and `next_tokens` /
    # `next_indices` are unbound when `finalize` is called below.
    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]

        # adjust tokens for Bart, *e.g.*
        next_token_logits = self.adjust_logits_during_generation(
            next_token_logits, cur_len=cur_len, max_length=max_length
        )

        next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

        # processors run on the per-step log-probabilities, then the running beam score is added
        next_token_scores = logits_processor(input_ids, next_token_scores)
        next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

        # reshape for beam search: all (beam, token) continuations of one batch entry share a row
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # keep twice as many candidates as there are beams
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )

        # split the flat index back into (beam index within the batch entry, token id)
        next_indices = next_tokens // vocab_size
        next_tokens = next_tokens % vocab_size

        # let the scorer pick the beams to continue
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # reorder the sequences to follow the selected beams and append the chosen tokens
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        cur_len = cur_len + 1

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # cached state has to follow the same beam reordering as `input_ids`
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        if beam_scorer.is_done:
            break

    decoded = beam_scorer.finalize(
        input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
    )

    return decoded
|
|
|
|
def beam_sample(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    **model_kwargs
):
    r"""
    Generates sequences for models with a language modeling head using beam search with multinomial sampling.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            The sequence used as a prompt for the generation. The first dimension has to equal the number of
            beam hypotheses held by :obj:`beam_scorer` times :obj:`beam_scorer.num_beams`.
        beam_scorer (:obj:`BeamScorer`):
            A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
            constructed, stored and sorted during generation. For more information, the documentation of
            :class:`~transformers.BeamScorer` should be read.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step. An empty list is used if not provided.
        logits_warper (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
            modeling head applied before multinomial sampling at each generation step. An empty list is used if
            not provided.
        max_length (:obj:`int`, `optional`):
            The maximum length of the sequence to be generated. Falls back to :obj:`self.config.max_length`.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. Falls back to :obj:`self.config.pad_token_id`.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token. Falls back to :obj:`self.config.eos_token_id`.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
            model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        The value returned by :obj:`beam_scorer.finalize` for the final beams, i.e. the generated sequences.
        Generation stops once the length reaches :obj:`max_length` or as soon as :obj:`beam_scorer.is_done`.

    Raises:
        AssertionError: If the first dimension of :obj:`input_ids` is not :obj:`batch_size * num_beams`.

    Examples::

        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)
        ... ])
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList([
        ...     TopKLogitsWarper(50),
        ...     TemperatureLogitsWarper(0.7),
        ... ])

        >>> outputs = model.beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """

    # init values: anything not passed explicitly falls back to an empty list or to the model config
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    # `logits_warper` is optional in the signature, so it needs a callable default as well; without it the
    # call inside the loop fails with "'NoneType' object is not callable"
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    # the scorer keeps one hypotheses container per batch entry
    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    # fail early with a readable message; a mismatch would otherwise only surface as a shape error in the loop
    assert (
        num_beams * batch_size == batch_beam_size
    ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

    # unlike in `beam_search`, every beam starts with the same score of 0
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores = beam_scores.view((batch_size * num_beams,))

    # NOTE(review): if the prompt is already `max_length` long the loop body never runs and `next_tokens` /
    # `next_indices` are unbound when `finalize` is called below.
    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]

        # adjust token scores (a no-op by default)
        next_token_logits = self.adjust_logits_during_generation(
            next_token_logits, cur_len=cur_len, max_length=max_length
        )

        next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

        # processors run on the per-step log-probabilities; warpers run after the beam score was added
        next_token_scores = logits_processor(input_ids, next_token_scores)
        next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # reshape for beam search: all (beam, token) continuations of one batch entry share a row
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # sample twice as many candidates as there are beams and look up their scores
        probs = F.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

        # order the sampled candidates by descending score
        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = torch.gather(next_tokens, -1, _indices)

        # split the flat index back into (beam index within the batch entry, token id)
        next_indices = next_tokens // vocab_size
        next_tokens = next_tokens % vocab_size

        # let the scorer pick the beams to continue
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # reorder the sequences to follow the selected beams and append the chosen tokens
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        cur_len = cur_len + 1

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # cached state has to follow the same beam reordering as `input_ids`
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        if beam_scorer.is_done:
            break

    decoded = beam_scorer.finalize(
        input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
    )

    return decoded
|
|
|
|
|
|
def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k: if > 0, keep only top k tokens with highest probability (top-k filtering).
        top_p: if within ``[0, 1]``, keep the top tokens with cumulative probability >= top_p (nucleus
            filtering). Note that this includes the default of 1.0; values outside of ``[0, 1]`` skip nucleus
            filtering. Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        filter_value: value handed to both warpers as the replacement for filtered-out logits.
        min_tokens_to_keep: make sure we keep at least min_tokens_to_keep per batch example in the output

    Returns:
        The filtered logits; the input object itself if neither filter applies.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        # forward `filter_value` here as well, so a custom fill value is honoured by both filters
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits
|