add past hidden key states for more efficient language generation & add prepare_inputs for gpt2 and ctrl model
This commit is contained in:
@@ -490,6 +490,14 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
|
# inputs_ids contain only last token if past is in kwargs and defined
|
||||||
|
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
|
||||||
|
|
||||||
|
inputs = {"input_ids": input_ids}
|
||||||
|
inputs.update(kwargs)
|
||||||
|
return inputs
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
|
|||||||
@@ -559,6 +559,14 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
|
# inputs_ids contain only last token if past is in kwargs and defined
|
||||||
|
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
|
||||||
|
|
||||||
|
inputs = {"input_ids": input_ids}
|
||||||
|
inputs.update(kwargs)
|
||||||
|
return inputs
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import ipdb
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -539,6 +540,14 @@ class PreTrainedModel(nn.Module):
|
|||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
return {"input_ids": input_ids}
|
return {"input_ids": input_ids}
|
||||||
|
|
||||||
|
def _has_past(self, outputs):
|
||||||
|
# TODO: might be better to write a self.has_past method for each individual class as is done for
|
||||||
|
# prepare_inputs_for_generation
|
||||||
|
if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1:
|
||||||
|
return True
|
||||||
|
# TODO: Add cases for (xlnet, transfo_xl) using mem_len
|
||||||
|
return False
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
@@ -716,14 +725,16 @@ class PreTrainedModel(nn.Module):
|
|||||||
# current position / max lengths / length of generated sentences / unfinished sentences
|
# current position / max lengths / length of generated sentences / unfinished sentences
|
||||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||||
|
|
||||||
# TODO: add cached compute states
|
past = None
|
||||||
pasts = None
|
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||||
outputs = self(**model_inputs)
|
outputs = self(**model_inputs)
|
||||||
next_token_logits = outputs[0][:, -1, :]
|
next_token_logits = outputs[0][:, -1, :]
|
||||||
|
|
||||||
|
if self._has_past(outputs):
|
||||||
|
past = outputs[1]
|
||||||
|
|
||||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||||
if repetition_penalty != 1.0:
|
if repetition_penalty != 1.0:
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
@@ -782,6 +793,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
):
|
):
|
||||||
""" Generate sequences for each example with beam search.
|
""" Generate sequences for each example with beam search.
|
||||||
"""
|
"""
|
||||||
|
ipdb.set_trace()
|
||||||
# Expand input to num beams
|
# Expand input to num beams
|
||||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
||||||
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
|
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
|
||||||
@@ -797,15 +809,18 @@ class PreTrainedModel(nn.Module):
|
|||||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||||
|
|
||||||
# cache compute states
|
# cache compute states
|
||||||
pasts = None # self.prepare_pasts()
|
past = None
|
||||||
|
|
||||||
# done sentences
|
# done sentences
|
||||||
done = [False for _ in range(batch_size)]
|
done = [False for _ in range(batch_size)]
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||||
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
|
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||||
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
|
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
|
if self._has_past(outputs):
|
||||||
|
past = outputs[1]
|
||||||
|
|
||||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||||
if repetition_penalty != 1.0:
|
if repetition_penalty != 1.0:
|
||||||
|
|||||||
Reference in New Issue
Block a user