Merge pull request #2290 from patrickvonplaten/fix_typo_in_doc_for_language_generation
duplicated line for repeating_words_penalty_for_language_generation
This commit is contained in:
@@ -556,48 +556,89 @@ class PreTrainedModel(nn.Module):
|
|||||||
length_penalty=None,
|
length_penalty=None,
|
||||||
num_return_sequences=None,
|
num_return_sequences=None,
|
||||||
):
|
):
|
||||||
""" Sequence generator for models with a LM head.
|
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||||
|
|
||||||
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
|
||||||
and beam-search.
|
and beam-search.
|
||||||
|
|
||||||
Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM
|
Adapted in part from `Facebook's XLM beam search code`_.
|
||||||
|
|
||||||
Params:
|
.. _`Facebook's XLM beam search code`:
|
||||||
**input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
|
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
|
||||||
|
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
|
||||||
|
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
|
||||||
The sequence used as a prompt for the generation. If `None` the method initializes
|
The sequence used as a prompt for the generation. If `None` the method initializes
|
||||||
it as an empty `torch.LongTensor` of shape (1,)
|
it as an empty `torch.LongTensor` of shape `(1,)`.
|
||||||
**max_length**: (`optional`) int
|
|
||||||
|
max_length: (`optional`) int
|
||||||
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
|
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
|
||||||
**do_sample**: (`optional`) bool
|
|
||||||
If set to `False` we use greedy decoding; otherwise sampling. Default to greedy sampling.
|
do_sample: (`optional`) bool
|
||||||
**num_beams**: (`optional`) int
|
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to greedy decoding.
|
||||||
Number of beams for beam search. 1 means no beam serach. Default to 1.
|
|
||||||
**temperature**: (`optional`) float
|
num_beams: (`optional`) int
|
||||||
The value used to module the next token probabilities.
|
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
|
||||||
**top_k**: (`optional`) int
|
|
||||||
|
temperature: (`optional`) float
|
||||||
|
The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
|
||||||
|
|
||||||
|
top_k: (`optional`) int
|
||||||
The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
||||||
**top_p**: (`optional`) float
|
|
||||||
|
top_p: (`optional`) float
|
||||||
The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
||||||
**repetition_penalty**: (`optional`) float
|
|
||||||
The parameter for repetition penalty. Between 1.0 and + infinity. 1.0 means no penalty. Default to 1.
|
repetition_penalty: (`optional`) float
|
||||||
**bos_token_id**: (`optional`) int
|
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||||
|
|
||||||
|
bos_token_id: (`optional`) int
|
||||||
Beginning of sentence token if no prompt is provided. Default to 0.
|
Beginning of sentence token if no prompt is provided. Default to 0.
|
||||||
**eos_token_ids**: (`optional`) int or list of int
|
|
||||||
|
eos_token_ids: (`optional`) int or list of int
|
||||||
End of sequence token or list of tokens to stop the generation. Default to 0.
|
End of sequence token or list of tokens to stop the generation. Default to 0.
|
||||||
**length_penalty**: (`optional`) int
|
length_penalty: (`optional`) float
|
||||||
Exponential penalty to the length. Default to 0.
|
|
||||||
**length_penalty**: (`optional`) float
|
|
||||||
Exponential penalty to the length. Default to 1.
|
Exponential penalty to the length. Default to 1.
|
||||||
**num_return_sequences**: (`optional`) int
|
|
||||||
The number of independantly computed returned sequences for each element in the batch. Default to 1.
|
num_return_sequences: (`optional`) int
|
||||||
|
The number of independently computed returned sequences for each element in the batch. Default to 1.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
||||||
|
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
||||||
|
outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id) # do greedy decoding without beam search
|
||||||
|
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
|
||||||
|
model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
|
||||||
|
input_context = 'The dog'
|
||||||
|
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||||
|
outputs = model.generate(input_ids=input_ids, do_sample=True, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
|
||||||
|
for i in range(3): # 3 output sequences were generated
|
||||||
|
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[0][i], skip_special_tokens=True)))
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
||||||
|
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
||||||
|
input_context = 'The dog'
|
||||||
|
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||||
|
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using greedy beam search decoding (3 beams)
|
||||||
|
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
|
||||||
|
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
|
||||||
|
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
|
||||||
|
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||||
|
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences using greedy search
|
||||||
|
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We cannot generate if the model does not have a LM head
|
# We cannot generate if the model does not have a LM head
|
||||||
if self.get_output_embeddings() is None:
|
if self.get_output_embeddings() is None:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
"You tried to generate sequences with a model that does not have a LM Head."
|
"You tried to generate sequences with a model that does not have a LM Head."
|
||||||
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`)"
|
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
@@ -625,7 +666,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
||||||
# assert temperature >= 0, "`temperature` should be positive."
|
assert temperature > 0, "`temperature` should be strictely positive."
|
||||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||||
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
||||||
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
||||||
@@ -727,16 +768,16 @@ class PreTrainedModel(nn.Module):
|
|||||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||||
if repetition_penalty != 1.0:
|
if repetition_penalty != 1.0:
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
for previous_tokens in set(input_ids[i].tolist()):
|
for previous_token in set(input_ids[i].tolist()):
|
||||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||||
if next_token_logits[i, previous_tokens] < 0:
|
if next_token_logits[i, previous_token] < 0:
|
||||||
next_token_logits[i, previous_tokens] *= repetition_penalty
|
next_token_logits[i, previous_token] *= repetition_penalty
|
||||||
else:
|
else:
|
||||||
next_token_logits[i, previous_tokens] /= repetition_penalty
|
next_token_logits[i, previous_token] /= repetition_penalty
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
if temperature > 0 and temperature != 1.0:
|
if temperature != 1.0:
|
||||||
next_token_logits = next_token_logits / temperature
|
next_token_logits = next_token_logits / temperature
|
||||||
# Top-p/top-k filtering
|
# Top-p/top-k filtering
|
||||||
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||||
@@ -810,16 +851,16 @@ class PreTrainedModel(nn.Module):
|
|||||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||||
if repetition_penalty != 1.0:
|
if repetition_penalty != 1.0:
|
||||||
for i in range(batch_size * num_beams):
|
for i in range(batch_size * num_beams):
|
||||||
for previous_tokens in set(input_ids[i].tolist()):
|
for previous_token in set(input_ids[i].tolist()):
|
||||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||||
if scores[i, previous_tokens] < 0:
|
if scores[i, previous_token] < 0:
|
||||||
scores[i, previous_tokens] *= repetition_penalty
|
scores[i, previous_token] *= repetition_penalty
|
||||||
else:
|
else:
|
||||||
scores[i, previous_tokens] /= repetition_penalty
|
scores[i, previous_token] /= repetition_penalty
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
if temperature > 0 and temperature != 1.0:
|
if temperature != 1.0:
|
||||||
scores = scores / temperature
|
scores = scores / temperature
|
||||||
# Top-p/top-k filtering
|
# Top-p/top-k filtering
|
||||||
scores = top_k_top_p_filtering(
|
scores = top_k_top_p_filtering(
|
||||||
|
|||||||
Reference in New Issue
Block a user