Bart-CNN (#3059)
`generate` code that produces 99% identical summarizations to fairseq on CNN test data, with caching.
This commit is contained in:
@@ -171,7 +171,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
else:
|
||||
output_embeddings.weight = input_embeddings.weight
|
||||
|
||||
if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
|
||||
if getattr(output_embeddings, "bias", None) is not None:
|
||||
output_embeddings.bias.data = torch.nn.functional.pad(
|
||||
output_embeddings.bias.data,
|
||||
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
|
||||
@@ -558,7 +558,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)
|
||||
)
|
||||
)
|
||||
|
||||
model.tie_weights() # make sure word embedding weights are still tied if needed
|
||||
|
||||
# Set model in evaluation mode to desactivate DropOut modules by default
|
||||
@@ -574,16 +573,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
def _do_output_past(self, outputs):
|
||||
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
|
||||
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
|
||||
|
||||
if has_output_past and not has_mem_len and len(outputs) > 1:
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
has_output_past = getattr(self.config, "output_past", False)
|
||||
mem_len = getattr(self.config, "mem_len", 0)
|
||||
if len(outputs) <= 1:
|
||||
return False
|
||||
if mem_len > 0 or has_output_past:
|
||||
return True
|
||||
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
|
||||
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
|
||||
for i in range(batch_size * num_beams):
|
||||
for previous_token in set(prev_output_tokens[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if lprobs[i, previous_token] < 0:
|
||||
lprobs[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
lprobs[i, previous_token] /= repetition_penalty
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
@@ -761,7 +769,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
|
||||
input_ids = input_ids.contiguous().view(
|
||||
batch_size * num_return_sequences, cur_len
|
||||
) # (batch_size * num_return_sequences, cur_len)
|
||||
) # shape: (batch_size * num_return_sequences, cur_len)
|
||||
effective_batch_size = batch_size * num_return_sequences
|
||||
else:
|
||||
effective_batch_size = batch_size
|
||||
@@ -822,9 +830,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||
|
||||
past = None
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
|
||||
@@ -834,13 +842,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(batch_size):
|
||||
for previous_token in set(input_ids[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if next_token_logits[i, previous_token] < 0:
|
||||
next_token_logits[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
next_token_logits[i, previous_token] /= repetition_penalty
|
||||
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -911,6 +913,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
# Expand input to num beams
|
||||
# assert input_ids.shape == (batch_size * num_beams, cur_len)
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
||||
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
|
||||
|
||||
@@ -941,13 +944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(batch_size * num_beams):
|
||||
for previous_token in set(input_ids[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if scores[i, previous_token] < 0:
|
||||
scores[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
scores[i, previous_token] /= repetition_penalty
|
||||
self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -1039,16 +1036,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# re-order internal states
|
||||
if past:
|
||||
reordered_past = []
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` and `mems` is at 2nd position
|
||||
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
# check that shape matches
|
||||
assert reordered_layer_past.shape == layer_past.shape
|
||||
reordered_past.append(reordered_layer_past)
|
||||
past = tuple(reordered_past)
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
# update current length
|
||||
cur_len = cur_len + 1
|
||||
@@ -1096,6 +1084,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
return decoded
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = []
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` and `mems` is at 2nd position
|
||||
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
# check that shape matches
|
||||
assert reordered_layer_past.shape == layer_past.shape
|
||||
reordered_past.append(reordered_layer_past)
|
||||
past = tuple(reordered_past)
|
||||
return past
|
||||
|
||||
|
||||
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
@@ -1164,17 +1166,22 @@ class BeamHypotheses(object):
|
||||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs):
|
||||
def is_done(self, best_sum_logprobs, cur_len=None):
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated
|
||||
can become better than the worst one in the heap, then we are done with this sentence.
|
||||
"""
|
||||
|
||||
if len(self) < self.num_beams:
|
||||
return False
|
||||
elif self.early_stopping:
|
||||
return True
|
||||
else:
|
||||
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
|
||||
if cur_len is None:
|
||||
cur_len = self.max_length
|
||||
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
||||
ret = self.worst_score >= cur_score
|
||||
return ret
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user