Merge pull request #2290 from patrickvonplaten/fix_typo_in_doc_for_language_generation
duplicated line for repeating_words_penalty_for_language_generation
This commit is contained in:
@@ -556,48 +556,89 @@ class PreTrainedModel(nn.Module):
|
||||
length_penalty=None,
|
||||
num_return_sequences=None,
|
||||
):
|
||||
""" Sequence generator for models with a LM head.
|
||||
|
||||
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||
and beam-search.
|
||||
|
||||
Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM
|
||||
Adapted in part from `Facebook's XLM beam search code`_.
|
||||
|
||||
Params:
|
||||
**input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
|
||||
.. _`Facebook's XLM beam search code`:
|
||||
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
|
||||
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
|
||||
The sequence used as a prompt for the generation. If `None` the method initializes
|
||||
it as an empty `torch.LongTensor` of shape (1,)
|
||||
**max_length**: (`optional`) int
|
||||
it as an empty `torch.LongTensor` of shape `(1,)`.
|
||||
|
||||
max_length: (`optional`) int
|
||||
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
|
||||
**do_sample**: (`optional`) bool
|
||||
If set to `False` we use greedy decoding; otherwise sampling. Default to greedy sampling.
|
||||
**num_beams**: (`optional`) int
|
||||
Number of beams for beam search. 1 means no beam serach. Default to 1.
|
||||
**temperature**: (`optional`) float
|
||||
The value used to module the next token probabilities.
|
||||
**top_k**: (`optional`) int
|
||||
|
||||
do_sample: (`optional`) bool
|
||||
If set to `False` greedy decoding is used. Otherwise sampling is used. Default to greedy sampling.
|
||||
|
||||
num_beams: (`optional`) int
|
||||
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
|
||||
|
||||
temperature: (`optional`) float
|
||||
The value used to module the next token probabilities. Must be strictely positive. Default to 1.0.
|
||||
|
||||
top_k: (`optional`) int
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
||||
**top_p**: (`optional`) float
|
||||
|
||||
top_p: (`optional`) float
|
||||
The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
||||
**repetition_penalty**: (`optional`) float
|
||||
The parameter for repetition penalty. Between 1.0 and + infinity. 1.0 means no penalty. Default to 1.
|
||||
**bos_token_id**: (`optional`) int
|
||||
|
||||
repetition_penalty: (`optional`) float
|
||||
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||
|
||||
bos_token_id: (`optional`) int
|
||||
Beginning of sentence token if no prompt is provided. Default to 0.
|
||||
**eos_token_ids**: (`optional`) int or list of int
|
||||
|
||||
eos_token_ids: (`optional`) int or list of int
|
||||
End of sequence token or list of tokens to stop the generation. Default to 0.
|
||||
**length_penalty**: (`optional`) int
|
||||
Exponential penalty to the length. Default to 0.
|
||||
**length_penalty**: (`optional`) float
|
||||
length_penalty: (`optional`) float
|
||||
Exponential penalty to the length. Default to 1.
|
||||
**num_return_sequences**: (`optional`) int
|
||||
The number of independantly computed returned sequences for each element in the batch. Default to 1.
|
||||
|
||||
num_return_sequences: (`optional`) int
|
||||
The number of independently computed returned sequences for each element in the batch. Default to 1.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
||||
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
||||
outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id) # do greedy decoding without beam search
|
||||
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
|
||||
model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
|
||||
input_context = 'The dog'
|
||||
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||
outputs = model.generate(input_ids=input_ids, do_sample=True, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
|
||||
for i in range(3): # 3 output sequences were generated
|
||||
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[0][i], skip_special_tokens=True)))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
||||
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
||||
input_context = 'The dog'
|
||||
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using greedy beam search decoding (3 beams)
|
||||
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
|
||||
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
|
||||
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
|
||||
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences using using greedy search
|
||||
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||
|
||||
"""
|
||||
|
||||
# We cannot generate if the model does not have a LM head
|
||||
if self.get_output_embeddings() is None:
|
||||
raise AttributeError(
|
||||
"You tried to generate sequences with a model that does not have a LM Head."
|
||||
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`)"
|
||||
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
|
||||
)
|
||||
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
@@ -625,7 +666,7 @@ class PreTrainedModel(nn.Module):
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
||||
# assert temperature >= 0, "`temperature` should be positive."
|
||||
assert temperature > 0, "`temperature` should be strictely positive."
|
||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
||||
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
||||
@@ -727,16 +768,16 @@ class PreTrainedModel(nn.Module):
|
||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(batch_size):
|
||||
for previous_tokens in set(input_ids[i].tolist()):
|
||||
for previous_token in set(input_ids[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if next_token_logits[i, previous_tokens] < 0:
|
||||
next_token_logits[i, previous_tokens] *= repetition_penalty
|
||||
if next_token_logits[i, previous_token] < 0:
|
||||
next_token_logits[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
next_token_logits[i, previous_tokens] /= repetition_penalty
|
||||
next_token_logits[i, previous_token] /= repetition_penalty
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
if temperature > 0 and temperature != 1.0:
|
||||
if temperature != 1.0:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
# Top-p/top-k filtering
|
||||
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||
@@ -810,16 +851,16 @@ class PreTrainedModel(nn.Module):
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(batch_size * num_beams):
|
||||
for previous_tokens in set(input_ids[i].tolist()):
|
||||
for previous_token in set(input_ids[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if scores[i, previous_tokens] < 0:
|
||||
scores[i, previous_tokens] *= repetition_penalty
|
||||
if scores[i, previous_token] < 0:
|
||||
scores[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
scores[i, previous_tokens] /= repetition_penalty
|
||||
scores[i, previous_token] /= repetition_penalty
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
if temperature > 0 and temperature != 1.0:
|
||||
if temperature != 1.0:
|
||||
scores = scores / temperature
|
||||
# Top-p/top-k filtering
|
||||
scores = top_k_top_p_filtering(
|
||||
|
||||
Reference in New Issue
Block a user