fixing run_generation example - using torch.no_grad
This commit is contained in:
@@ -87,11 +87,11 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
|
|||||||
logger.info(
|
logger.info(
|
||||||
"WARNING! You are not starting your generation from a control code so you won't get good results"
|
"WARNING! You are not starting your generation from a control code so you won't get good results"
|
||||||
)
|
)
|
||||||
return prompt_text, {}
|
return prompt_text
|
||||||
|
|
||||||
|
|
||||||
def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||||
kwargs = {"language": None, "mask_token_id": None}
|
# kwargs = {"language": None, "mask_token_id": None}
|
||||||
|
|
||||||
# Set the language
|
# Set the language
|
||||||
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
|
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
|
||||||
@@ -107,14 +107,15 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
|||||||
+ str(list(available_languages))
|
+ str(list(available_languages))
|
||||||
+ " >>> "
|
+ " >>> "
|
||||||
)
|
)
|
||||||
kwargs["language"] = tokenizer.lang2id[language]
|
# kwargs["language"] = tokenizer.lang2id[language]
|
||||||
|
|
||||||
|
# TODO: fix mask_token_id setup once configurations are synchronized between models and tokenizers
|
||||||
# XLM masked-language modeling (MLM) models need masked token
|
# XLM masked-language modeling (MLM) models need masked token
|
||||||
is_xlm_mlm = "mlm" in args.model_name_or_path
|
# is_xlm_mlm = "mlm" in args.model_name_or_path
|
||||||
if is_xlm_mlm:
|
# if is_xlm_mlm:
|
||||||
kwargs["mask_token_id"] = tokenizer.mask_token_id
|
# kwargs["mask_token_id"] = tokenizer.mask_token_id
|
||||||
|
|
||||||
return prompt_text, kwargs
|
return prompt_text
|
||||||
|
|
||||||
|
|
||||||
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
|
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
|
||||||
@@ -179,8 +180,8 @@ def main():
|
|||||||
try:
|
try:
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
except KeyError as ke:
|
except KeyError:
|
||||||
raise ke(
|
raise KeyError(
|
||||||
"the model {} you specified is not supported. You are welcome to add it and open a PR :)"
|
"the model {} you specified is not supported. You are welcome to add it and open a PR :)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -197,10 +198,9 @@ def main():
|
|||||||
|
|
||||||
# Different models need different input formatting and/or extra arguments
|
# Different models need different input formatting and/or extra arguments
|
||||||
requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
|
requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
|
||||||
model_kwargs = {}
|
|
||||||
if requires_preprocessing:
|
if requires_preprocessing:
|
||||||
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||||
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
|
prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
||||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
|
||||||
|
|
||||||
output_sequences = model.generate(
|
output_sequences = model.generate(
|
||||||
@@ -210,14 +210,11 @@ def main():
|
|||||||
top_k=args.k,
|
top_k=args.k,
|
||||||
top_p=args.p,
|
top_p=args.p,
|
||||||
repetition_penalty=args.repetition_penalty,
|
repetition_penalty=args.repetition_penalty,
|
||||||
**model_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
generated_sequence = output_sequences.tolist()[
|
generated_sequence = output_sequences.tolist()
|
||||||
encoded_prompt.size(1) :
|
text = [tokenizer.decode(seq, clean_up_tokenization_spaces=True) for seq in generated_sequence]
|
||||||
] # adapted to case where num_samples > 1
|
# text = text[: text.find(args.stop_token) if args.stop_token else None]
|
||||||
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
|
||||||
text = text[: text.find(args.stop_token) if args.stop_token else None]
|
|
||||||
|
|
||||||
print(text)
|
print(text)
|
||||||
|
|
||||||
|
|||||||
@@ -113,8 +113,8 @@ class XLMConfig(PretrainedConfig):
|
|||||||
summary_first_dropout=0.1,
|
summary_first_dropout=0.1,
|
||||||
start_n_top=5,
|
start_n_top=5,
|
||||||
end_n_top=5,
|
end_n_top=5,
|
||||||
mask_token_id = 0,
|
mask_token_id=0,
|
||||||
lang_id = 0,
|
lang_id=0,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Constructs XLMConfig.
|
"""Constructs XLMConfig.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -489,7 +489,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
|
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
|
||||||
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
|
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
|
||||||
bos_token_id=None, pad_token_id=None, eos_token_ids=None,
|
bos_token_id=None, pad_token_id=None, eos_token_ids=None,
|
||||||
length_penalty=None, num_return_sequences=None, **model_kwargs):
|
length_penalty=None, num_return_sequences=None):
|
||||||
""" Sequence generator for models with a LM head.
|
""" Sequence generator for models with a LM head.
|
||||||
|
|
||||||
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||||
@@ -519,7 +519,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
# We cannot generate if the model does not have a LM head
|
# We cannot generate if the model does not have a LM head
|
||||||
if self.get_output_embeddings() is None:
|
if self.get_output_embeddings() is None:
|
||||||
raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")
|
raise AttributeError("You tried to generate sequences with a model that does not have a LM Head."
|
||||||
|
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`)")
|
||||||
|
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||||
@@ -544,7 +545,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
||||||
# assert temperature > 0, "`temperature` should be strictely positive."
|
# assert temperature >= 0, "`temperature` should be positive."
|
||||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||||
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
||||||
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
||||||
@@ -576,13 +577,11 @@ class PreTrainedModel(nn.Module):
|
|||||||
output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
|
output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
|
||||||
temperature, top_k, top_p, repetition_penalty,
|
temperature, top_k, top_p, repetition_penalty,
|
||||||
pad_token_id, eos_token_ids, effective_batch_size,
|
pad_token_id, eos_token_ids, effective_batch_size,
|
||||||
length_penalty, num_beams, vocab_size,
|
length_penalty, num_beams, vocab_size)
|
||||||
**model_kwargs)
|
|
||||||
else:
|
else:
|
||||||
output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
|
output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
|
||||||
temperature, top_k, top_p, repetition_penalty,
|
temperature, top_k, top_p, repetition_penalty,
|
||||||
pad_token_id, eos_token_ids, effective_batch_size,
|
pad_token_id, eos_token_ids, effective_batch_size)
|
||||||
**model_kwargs)
|
|
||||||
|
|
||||||
if num_return_sequences != 1:
|
if num_return_sequences != 1:
|
||||||
output = output.view(batch_size, num_return_sequences, -1)
|
output = output.view(batch_size, num_return_sequences, -1)
|
||||||
@@ -590,19 +589,18 @@ class PreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
|
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
|
||||||
temperature, top_k, top_p, repetition_penalty,
|
temperature, top_k, top_p, repetition_penalty,
|
||||||
pad_token_id, eos_token_ids, batch_size,
|
pad_token_id, eos_token_ids, batch_size):
|
||||||
**model_kwargs):
|
|
||||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||||
All returned sequences are generated independently.
|
All returned sequences are generated independently.
|
||||||
"""
|
"""
|
||||||
# current position / max lengths / length of generated sentences / unfinished sentences
|
# current position / max lengths / length of generated sentences / unfinished sentences
|
||||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||||
|
|
||||||
# cache compute states
|
# TODO: add cached compute states
|
||||||
pasts = None
|
pasts = None
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
|
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
|
||||||
outputs = self(**model_inputs)
|
outputs = self(**model_inputs)
|
||||||
next_token_logits = outputs[0][:, -1, :]
|
next_token_logits = outputs[0][:, -1, :]
|
||||||
|
|
||||||
@@ -614,7 +612,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
if temperature != 1.0:
|
if temperature > 0 and temperature != 1.0:
|
||||||
next_token_logits = next_token_logits / temperature
|
next_token_logits = next_token_logits / temperature
|
||||||
# Top-p/top-k filtering
|
# Top-p/top-k filtering
|
||||||
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||||
@@ -644,8 +642,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
|
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
|
||||||
temperature, top_k, top_p, repetition_penalty,
|
temperature, top_k, top_p, repetition_penalty,
|
||||||
pad_token_id, eos_token_ids, batch_size,
|
pad_token_id, eos_token_ids, batch_size,
|
||||||
length_penalty, num_beams, vocab_size,
|
length_penalty, num_beams, vocab_size):
|
||||||
**model_kwargs):
|
|
||||||
""" Generate sequences for each example with beam search.
|
""" Generate sequences for each example with beam search.
|
||||||
"""
|
"""
|
||||||
# Expand input to num beams
|
# Expand input to num beams
|
||||||
@@ -667,7 +664,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
done = [False for _ in range(batch_size)]
|
done = [False for _ in range(batch_size)]
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
|
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
|
||||||
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
|
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
|
||||||
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
|
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
@@ -679,7 +676,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
if temperature != 1.0:
|
if temperature > 0 and temperature != 1.0:
|
||||||
scores = scores / temperature
|
scores = scores / temperature
|
||||||
# Top-p/top-k filtering
|
# Top-p/top-k filtering
|
||||||
scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * num_beams, vocab_size)
|
scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * num_beams, vocab_size)
|
||||||
|
|||||||
@@ -639,9 +639,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.pred_layer.proj
|
return self.pred_layer.proj
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id
|
mask_token_id = self.config.mask_token_id
|
||||||
lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id
|
lang_id = self.config.lang_id
|
||||||
|
|
||||||
mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
|
mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
|
||||||
input_ids = torch.cat([input_ids, mask_token], dim=1)
|
input_ids = torch.cat([input_ids, mask_token], dim=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user