From 300ec3003c282c5e3f06b33509af10dd0336d0ba Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 21 Dec 2019 14:02:19 +0100 Subject: [PATCH] fixing run_generation example - using torch.no_grad --- examples/run_generation.py | 31 ++++++++++++++----------------- transformers/configuration_xlm.py | 4 ++-- transformers/modeling_utils.py | 29 +++++++++++++---------------- transformers/modeling_xlm.py | 6 +++--- 4 files changed, 32 insertions(+), 38 deletions(-) diff --git a/examples/run_generation.py b/examples/run_generation.py index 67e1da7413..ade85f0269 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -87,11 +87,11 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text): logger.info( "WARNING! You are not starting your generation from a control code so you won't get good results" ) - return prompt_text, {} + return prompt_text def prepare_xlm_input(args, model, tokenizer, prompt_text): - kwargs = {"language": None, "mask_token_id": None} + # kwargs = {"language": None, "mask_token_id": None} # Set the language use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb @@ -107,14 +107,15 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text): + str(list(available_languages)) + " >>> " ) - kwargs["language"] = tokenizer.lang2id[language] + # kwargs["language"] = tokenizer.lang2id[language] + # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers # XLM masked-language modeling (MLM) models need masked token - is_xlm_mlm = "mlm" in args.model_name_or_path - if is_xlm_mlm: - kwargs["mask_token_id"] = tokenizer.mask_token_id + # is_xlm_mlm = "mlm" in args.model_name_or_path + # if is_xlm_mlm: + # kwargs["mask_token_id"] = tokenizer.mask_token_id - return prompt_text, kwargs + return prompt_text def prepare_xlnet_input(args, _, tokenizer, prompt_text): @@ -179,8 +180,8 @@ def main(): try: args.model_type = args.model_type.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - except KeyError as ke: - raise ke( + except KeyError: + raise KeyError( "the model {} you specified is not supported. You are welcome to add it and open a PR :)" ) @@ -197,10 +198,9 @@ def main(): # Different models need different input formatting and/or extra arguments requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() - model_kwargs = {} if requires_preprocessing: prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) - prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text) + prompt_text = prepare_input(args, model, tokenizer, prompt_text) encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt') output_sequences = model.generate( @@ -210,14 +210,11 @@ def main(): top_k=args.k, top_p=args.p, repetition_penalty=args.repetition_penalty, - **model_kwargs, ) - generated_sequence = output_sequences.tolist()[ - encoded_prompt.size(1) : - ] # adapted to case where num_samples > 1 - text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) - text = text[: text.find(args.stop_token) if args.stop_token else None] + generated_sequence = output_sequences.tolist() + text = [tokenizer.decode(seq, clean_up_tokenization_spaces=True) for seq in generated_sequence] + # text = text[: text.find(args.stop_token) if args.stop_token else None] print(text) diff --git a/transformers/configuration_xlm.py b/transformers/configuration_xlm.py index 1938b85741..1134c7ab61 100644 --- a/transformers/configuration_xlm.py +++ b/transformers/configuration_xlm.py @@ -113,8 +113,8 @@ class XLMConfig(PretrainedConfig): summary_first_dropout=0.1, start_n_top=5, end_n_top=5, - mask_token_id = 0, - lang_id = 0, + mask_token_id=0, + lang_id=0, **kwargs): """Constructs XLMConfig. """ diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 5b28d5b755..005252c141 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -489,7 +489,7 @@ class PreTrainedModel(nn.Module): def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None, bos_token_id=None, pad_token_id=None, eos_token_ids=None, - length_penalty=None, num_return_sequences=None, **model_kwargs): + length_penalty=None, num_return_sequences=None): """ Sequence generator for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling @@ -519,7 +519,8 @@ class PreTrainedModel(nn.Module): # We cannot generate if the model does not have a LM head if self.get_output_embeddings() is None: - raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.") + raise AttributeError("You tried to generate sequences with a model that does not have a LM Head." + "Please use another model class (e.g. `OpenAIGPTLMHeadModel`)") max_length = max_length if max_length is not None else self.config.max_length do_sample = do_sample if do_sample is not None else self.config.do_sample @@ -544,7 +545,7 @@ class PreTrainedModel(nn.Module): assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer." assert isinstance(do_sample, bool), "`do_sample` should be a boolean." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer." - # assert temperature > 0, "`temperature` should be strictely positive." + # assert temperature >= 0, "`temperature` should be positive." assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." @@ -576,13 +577,11 @@ class PreTrainedModel(nn.Module): output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids, effective_batch_size, - length_penalty, num_beams, vocab_size, - **model_kwargs) + length_penalty, num_beams, vocab_size) else: output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, effective_batch_size, - **model_kwargs) + pad_token_id, eos_token_ids, effective_batch_size) if num_return_sequences != 1: output = output.view(batch_size, num_return_sequences, -1) @@ -590,19 +589,18 @@ class PreTrainedModel(nn.Module): def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, batch_size, - **model_kwargs): + pad_token_id, eos_token_ids, batch_size): """ Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. """ # current position / max lengths / length of generated sentences / unfinished sentences unfinished_sents = input_ids.new(batch_size).fill_(1) - # cache compute states + # TODO: add cached compute states pasts = None while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] @@ -614,7 +612,7 @@ class PreTrainedModel(nn.Module): if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) - if temperature != 1.0: + if temperature > 0 and temperature != 1.0: next_token_logits = next_token_logits / temperature # Top-p/top-k filtering next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) @@ -644,8 +642,7 @@ class PreTrainedModel(nn.Module): def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids, batch_size, - length_penalty, num_beams, vocab_size, - **model_kwargs): + length_penalty, num_beams, vocab_size): """ Generate sequences for each example with beam search. """ # Expand input to num beams @@ -667,7 +664,7 @@ class PreTrainedModel(nn.Module): done = [False for _ in range(batch_size)] while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size) scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size) @@ -679,7 +676,7 @@ class PreTrainedModel(nn.Module): if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) - if temperature != 1.0: + if temperature > 0 and temperature != 1.0: scores = scores / temperature # Top-p/top-k filtering scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * num_beams, vocab_size) diff --git a/transformers/modeling_xlm.py b/transformers/modeling_xlm.py index 6691b0f60b..35bada92af 100644 --- a/transformers/modeling_xlm.py +++ b/transformers/modeling_xlm.py @@ -639,9 +639,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): def get_output_embeddings(self): return self.pred_layer.proj - def prepare_inputs_for_generation(self, input_ids, **model_kwargs): - mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id - lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id + def prepare_inputs_for_generation(self, input_ids, **kwargs): + mask_token_id = self.config.mask_token_id + lang_id = self.config.lang_id mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device) input_ids = torch.cat([input_ids, mask_token], dim=1)