From d891fd0ae0d43ae71723c1df9622ffef33ce84aa Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 21:19:27 +0100 Subject: [PATCH] add past hidden key states for more efficient language generation & add prepare_inputs for gpt2 and ctrl model --- src/transformers/modeling_ctrl.py | 8 ++++++++ src/transformers/modeling_gpt2.py | 8 ++++++++ src/transformers/modeling_utils.py | 29 ++++++++++++++++++++++------- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index b4906b7aa6..69e3a1a8da 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -490,6 +490,14 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): def get_output_embeddings(self): return self.lm_head + def prepare_inputs_for_generation(self, input_ids, **kwargs): + # inputs_ids contain only last token if past is in kwargs and defined + input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids + + inputs = {"input_ids": input_ids} + inputs.update(kwargs) + return inputs + def forward( self, input_ids=None, diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index c8d5040f2f..d962259d28 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -559,6 +559,14 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): def get_output_embeddings(self): return self.lm_head + def prepare_inputs_for_generation(self, input_ids, **kwargs): + # inputs_ids contain only last token if past is in kwargs and defined + input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids + + inputs = {"input_ids": input_ids} + inputs.update(kwargs) + return inputs + def forward( self, input_ids=None, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8722d578fd..2c8d30e85f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -18,6 +18,7 @@ import logging import os +import ipdb import torch from torch import nn @@ -539,6 +540,14 @@ class PreTrainedModel(nn.Module): def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids} + def _has_past(self, outputs): + # TODO: might be better to write a self.has_past method for each individual class as is done for + # prepare_inputs_for_generation + if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1: + return True + # TODO: Add cases for (xlnet, transfo_xl) using mem_len + return False + @torch.no_grad() def generate( self, @@ -716,14 +725,16 @@ class PreTrainedModel(nn.Module): # current position / max lengths / length of generated sentences / unfinished sentences unfinished_sents = input_ids.new(batch_size).fill_(1) - # TODO: add cached compute states - pasts = None + past = None while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) + model_inputs = self.prepare_inputs_for_generation(input_ids, past=past) outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] + if self._has_past(outputs): + past = outputs[1] + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) if repetition_penalty != 1.0: for i in range(batch_size): @@ -782,6 +793,7 @@ class PreTrainedModel(nn.Module): ): """ Generate sequences for each example with beam search. """ + ipdb.set_trace() # Expand input to num beams input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len) input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len) @@ -797,15 +809,18 @@ class PreTrainedModel(nn.Module): beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) # cache compute states - pasts = None # self.prepare_pasts() + past = None # done sentences done = [False for _ in range(batch_size)] while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) - scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size) - scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size) + model_inputs = self.prepare_inputs_for_generation(input_ids, past=past) + outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) + scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) + + if self._has_past(outputs): + past = outputs[1] # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) if repetition_penalty != 1.0: