`generate` code that produces 99% identical summarizations to fairseq on CNN test data, with caching.
This commit is contained in:
Sam Shleifer
2020-03-02 10:35:53 -05:00
committed by GitHub
parent 6b1ff25084
commit b54ef78d0c
8 changed files with 544 additions and 152 deletions

View File

@@ -171,7 +171,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else:
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
@@ -558,7 +558,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
model.__class__.__name__, "\n\t".join(error_msgs)
)
)
model.tie_weights() # make sure word embedding weights are still tied if needed
# Set model in evaluation mode to desactivate DropOut modules by default
@@ -574,16 +573,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return {"input_ids": input_ids}
def _do_output_past(self, outputs):
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
if has_output_past and not has_mem_len and len(outputs) > 1:
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
has_output_past = getattr(self.config, "output_past", False)
mem_len = getattr(self.config, "mem_len", 0)
if len(outputs) <= 1:
return False
if mem_len > 0 or has_output_past:
return True
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
return True
return False
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
for i in range(batch_size * num_beams):
for previous_token in set(prev_output_tokens[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if lprobs[i, previous_token] < 0:
lprobs[i, previous_token] *= repetition_penalty
else:
lprobs[i, previous_token] /= repetition_penalty
@torch.no_grad()
def generate(
self,
@@ -761,7 +769,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # (batch_size * num_return_sequences, cur_len)
) # shape: (batch_size * num_return_sequences, cur_len)
effective_batch_size = batch_size * num_return_sequences
else:
effective_batch_size = batch_size
@@ -822,9 +830,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
@@ -834,13 +842,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if next_token_logits[i, previous_token] < 0:
next_token_logits[i, previous_token] *= repetition_penalty
else:
next_token_logits[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@@ -911,6 +913,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
# assert input_ids.shape == (batch_size * num_beams, cur_len)
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
@@ -941,13 +944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size * num_beams):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if scores[i, previous_token] < 0:
scores[i, previous_token] *= repetition_penalty
else:
scores[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@@ -1039,16 +1036,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# re-order internal states
if past:
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
past = self._reorder_cache(past, beam_idx)
# update current length
cur_len = cur_len + 1
@@ -1096,6 +1084,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return decoded
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
@@ -1164,17 +1166,22 @@ class BeamHypotheses(object):
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs):
def is_done(self, best_sum_logprobs, cur_len=None):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
class Conv1D(nn.Module):