rename variables named 'word' to 'token' in generate fn (#3119)
* fix conflits * fixed naming bug * make style
This commit is contained in:
committed by
GitHub
parent
71c8711970
commit
006097f8ad
@@ -242,7 +242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# initialize all new embeddings (in particular added tokens)
|
||||
self._init_weights(new_embeddings)
|
||||
|
||||
# Copy word embeddings from the previous weights
|
||||
# Copy token embeddings from the previous weights
|
||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
|
||||
|
||||
@@ -558,7 +558,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)
|
||||
)
|
||||
)
|
||||
model.tie_weights() # make sure word embedding weights are still tied if needed
|
||||
model.tie_weights() # make sure token embedding weights are still tied if needed
|
||||
|
||||
# Set model in evaluation mode to desactivate DropOut modules by default
|
||||
model.eval()
|
||||
@@ -843,8 +843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
"""
|
||||
# current position / max lengths / length of generated sentences / unfinished sentences
|
||||
|
||||
# length of generated sentences / unfinished sentences
|
||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||
|
||||
@@ -934,7 +933,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
"""
|
||||
|
||||
# Expand input to num beams
|
||||
# assert input_ids.shape == (batch_size * num_beams, cur_len)
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
||||
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
|
||||
|
||||
@@ -946,7 +944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# scores for each sentence in the beam
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
|
||||
# Greedy decoding it is made sure that only words of the first beam are considered to avoid sampling the exact same words three times
|
||||
# Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
|
||||
if do_sample is False:
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||
@@ -960,7 +958,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._do_output_past(outputs):
|
||||
@@ -968,14 +966,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)
|
||||
self.enforce_repetition_penalty_(
|
||||
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
|
||||
)
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
if temperature != 1.0:
|
||||
scores = scores / temperature
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# Top-p/top-k filtering
|
||||
@@ -988,25 +988,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
batch_size, num_beams * vocab_size
|
||||
) # (batch_size, num_beams * vocab_size)
|
||||
|
||||
# Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
|
||||
next_words = torch.multinomial(
|
||||
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
|
||||
next_tokens = torch.multinomial(
|
||||
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
|
||||
) # (batch_size, num_beams * 2)
|
||||
|
||||
# Compute next scores
|
||||
next_scores = torch.gather(_scores, -1, next_words) # (batch_size, num_beams * 2)
|
||||
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
|
||||
|
||||
else:
|
||||
# do greedy beam search
|
||||
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
assert scores.size() == (batch_size * num_beams, vocab_size)
|
||||
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
||||
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
|
||||
next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
|
||||
next_scores = next_scores.view(
|
||||
batch_size, num_beams * vocab_size
|
||||
) # (batch_size, num_beams * vocab_size)
|
||||
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
|
||||
|
||||
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
|
||||
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
|
||||
|
||||
# next batch beam content
|
||||
# list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
|
||||
@@ -1032,21 +1034,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# next sentence beam content
|
||||
next_sent_beam = []
|
||||
|
||||
# next words for this sentence
|
||||
for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):
|
||||
# next tokens for this sentence
|
||||
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
|
||||
|
||||
# get beam and word IDs
|
||||
beam_id = idx // vocab_size
|
||||
word_id = idx % vocab_size
|
||||
token_id = idx % vocab_size
|
||||
|
||||
# add to generated hypotheses if end of sentence or last iteration
|
||||
if eos_token_ids is not None and word_id.item() in eos_token_ids:
|
||||
if eos_token_ids is not None and token_id.item() in eos_token_ids:
|
||||
generated_hyps[batch_idx].add(
|
||||
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
|
||||
)
|
||||
else:
|
||||
# add next predicted word if it is not eos_token
|
||||
next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
|
||||
next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
|
||||
|
||||
# the beam for next step is full
|
||||
if len(next_sent_beam) == num_beams:
|
||||
@@ -1060,12 +1062,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# sanity check / prepare next batch
|
||||
assert len(next_batch_beam) == batch_size * num_beams
|
||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||
beam_words = input_ids.new([x[1] for x in next_batch_beam])
|
||||
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
|
||||
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
||||
|
||||
# re-order batch
|
||||
input_ids = input_ids[beam_idx, :]
|
||||
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
|
||||
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
|
||||
|
||||
# re-order internal states
|
||||
if past:
|
||||
@@ -1081,11 +1083,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
for batch_idx in range(batch_size):
|
||||
# Add all open beam hypothesis to generated_hyps
|
||||
if not done[batch_idx]:
|
||||
for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):
|
||||
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
|
||||
|
||||
# get beam and word IDs
|
||||
beam_id = idx // vocab_size
|
||||
word_id = idx % vocab_size
|
||||
token_id = idx % vocab_size
|
||||
generated_hyps[batch_idx].add(
|
||||
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user