rename variables named 'word' to 'token' in generate fn (#3119)
* fix conflicts * fixed naming bug * make style
This commit is contained in:
committed by
GitHub
parent
71c8711970
commit
006097f8ad
@@ -242,7 +242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# initialize all new embeddings (in particular added tokens)
|
# initialize all new embeddings (in particular added tokens)
|
||||||
self._init_weights(new_embeddings)
|
self._init_weights(new_embeddings)
|
||||||
|
|
||||||
# Copy word embeddings from the previous weights
|
# Copy token embeddings from the previous weights
|
||||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||||
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
|
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
|
||||||
|
|
||||||
@@ -558,7 +558,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
model.__class__.__name__, "\n\t".join(error_msgs)
|
model.__class__.__name__, "\n\t".join(error_msgs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
model.tie_weights() # make sure word embedding weights are still tied if needed
|
model.tie_weights() # make sure token embedding weights are still tied if needed
|
||||||
|
|
||||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -843,8 +843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||||
All returned sequences are generated independently.
|
All returned sequences are generated independently.
|
||||||
"""
|
"""
|
||||||
# current position / max lengths / length of generated sentences / unfinished sentences
|
# length of generated sentences / unfinished sentences
|
||||||
|
|
||||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||||
|
|
||||||
@@ -934,7 +933,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Expand input to num beams
|
# Expand input to num beams
|
||||||
# assert input_ids.shape == (batch_size * num_beams, cur_len)
|
|
||||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
||||||
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
|
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
|
||||||
|
|
||||||
@@ -946,7 +944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# scores for each sentence in the beam
|
# scores for each sentence in the beam
|
||||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||||
|
|
||||||
# Greedy decoding: it is made sure that only words of the first beam are considered, to avoid sampling the exact same words three times
|
# Greedy decoding: it is made sure that only tokens of the first beam are considered, to avoid sampling the exact same tokens three times
|
||||||
if do_sample is False:
|
if do_sample is False:
|
||||||
beam_scores[:, 1:] = -1e9
|
beam_scores[:, 1:] = -1e9
|
||||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||||
@@ -960,7 +958,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||||
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
# if model has past, then set the past variable to speed up decoding
|
# if model has past, then set the past variable to speed up decoding
|
||||||
if self._do_output_past(outputs):
|
if self._do_output_past(outputs):
|
||||||
@@ -968,14 +966,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||||
if repetition_penalty != 1.0:
|
if repetition_penalty != 1.0:
|
||||||
self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)
|
self.enforce_repetition_penalty_(
|
||||||
|
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
|
||||||
|
)
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
scores = scores / temperature
|
next_token_logits = next_token_logits / temperature
|
||||||
|
|
||||||
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
# Top-p/top-k filtering
|
# Top-p/top-k filtering
|
||||||
@@ -988,25 +988,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
batch_size, num_beams * vocab_size
|
batch_size, num_beams * vocab_size
|
||||||
) # (batch_size, num_beams * vocab_size)
|
) # (batch_size, num_beams * vocab_size)
|
||||||
|
|
||||||
# Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
|
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
|
||||||
next_words = torch.multinomial(
|
next_tokens = torch.multinomial(
|
||||||
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
|
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
|
||||||
) # (batch_size, num_beams * 2)
|
) # (batch_size, num_beams * 2)
|
||||||
|
|
||||||
# Compute next scores
|
# Compute next scores
|
||||||
next_scores = torch.gather(_scores, -1, next_words) # (batch_size, num_beams * 2)
|
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# do greedy beam search
|
# do greedy beam search
|
||||||
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
assert scores.size() == (batch_size * num_beams, vocab_size)
|
assert scores.size() == (batch_size * num_beams, vocab_size)
|
||||||
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
||||||
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||||
# re-organize to group the beam together (we are keeping top hypothesis across beams)
|
# re-organize to group the beam together (we are keeping top hypothesis across beams)
|
||||||
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
|
next_scores = next_scores.view(
|
||||||
next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
|
batch_size, num_beams * vocab_size
|
||||||
|
) # (batch_size, num_beams * vocab_size)
|
||||||
|
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
|
||||||
|
|
||||||
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
|
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
|
||||||
|
|
||||||
# next batch beam content
|
# next batch beam content
|
||||||
# list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
|
# list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
|
||||||
@@ -1032,21 +1034,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# next sentence beam content
|
# next sentence beam content
|
||||||
next_sent_beam = []
|
next_sent_beam = []
|
||||||
|
|
||||||
# next words for this sentence
|
# next tokens for this sentence
|
||||||
for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):
|
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
|
||||||
|
|
||||||
# get beam and word IDs
|
# get beam and word IDs
|
||||||
beam_id = idx // vocab_size
|
beam_id = idx // vocab_size
|
||||||
word_id = idx % vocab_size
|
token_id = idx % vocab_size
|
||||||
|
|
||||||
# add to generated hypotheses if end of sentence or last iteration
|
# add to generated hypotheses if end of sentence or last iteration
|
||||||
if eos_token_ids is not None and word_id.item() in eos_token_ids:
|
if eos_token_ids is not None and token_id.item() in eos_token_ids:
|
||||||
generated_hyps[batch_idx].add(
|
generated_hyps[batch_idx].add(
|
||||||
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
|
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# add next predicted word if it is not eos_token
|
# add next predicted word if it is not eos_token
|
||||||
next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
|
next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
|
||||||
|
|
||||||
# the beam for next step is full
|
# the beam for next step is full
|
||||||
if len(next_sent_beam) == num_beams:
|
if len(next_sent_beam) == num_beams:
|
||||||
@@ -1060,12 +1062,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# sanity check / prepare next batch
|
# sanity check / prepare next batch
|
||||||
assert len(next_batch_beam) == batch_size * num_beams
|
assert len(next_batch_beam) == batch_size * num_beams
|
||||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||||
beam_words = input_ids.new([x[1] for x in next_batch_beam])
|
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
|
||||||
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
||||||
|
|
||||||
# re-order batch
|
# re-order batch
|
||||||
input_ids = input_ids[beam_idx, :]
|
input_ids = input_ids[beam_idx, :]
|
||||||
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
|
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
|
||||||
|
|
||||||
# re-order internal states
|
# re-order internal states
|
||||||
if past:
|
if past:
|
||||||
@@ -1081,11 +1083,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
# Add all open beam hypothesis to generated_hyps
|
# Add all open beam hypothesis to generated_hyps
|
||||||
if not done[batch_idx]:
|
if not done[batch_idx]:
|
||||||
for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):
|
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
|
||||||
|
|
||||||
# get beam and word IDs
|
# get beam and word IDs
|
||||||
beam_id = idx // vocab_size
|
beam_id = idx // vocab_size
|
||||||
word_id = idx % vocab_size
|
token_id = idx % vocab_size
|
||||||
generated_hyps[batch_idx].add(
|
generated_hyps[batch_idx].add(
|
||||||
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
|
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user