refactor variable naming and improve tf generate in line with torch generate

This commit is contained in:
Patrick von Platen
2020-03-09 19:18:11 +01:00
parent 41b437ea3a
commit ca2047bc35
2 changed files with 219 additions and 91 deletions

View File

@@ -722,6 +722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
@@ -852,7 +853,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
device=next(self.parameters()).device,
)
cur_len = 1
self.model.decoder.generation_mode = True
# put model in generation mode if it has one
if hasattr(self.model, "generation_mode"):
self.model.decoder.generation_mode = True
else:
encoder_inputs = None
cur_len = input_ids.shape[-1]
@@ -860,44 +864,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if num_beams > 1:
output = self._generate_beam_search(
input_ids,
cur_len,
max_length,
min_length,
do_sample,
early_stopping,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
effective_batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
encoder_inputs,
attention_mask,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_inputs=encoder_inputs,
attention_mask=attention_mask,
)
else:
output = self._generate_no_beam_search(
input_ids,
cur_len,
max_length,
min_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
pad_token_id,
eos_token_ids,
effective_batch_size,
encoder_inputs,
attention_mask,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
batch_size=effective_batch_size,
encoder_inputs=encoder_inputs,
attention_mask=attention_mask,
)
return output
@@ -1157,24 +1161,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_sent_beam = []
# next tokens for this sentence
for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx])
):
# get beam and word IDs
beam_id = idx // vocab_size
token_id = idx % vocab_size
beam_id = beam_token_id // vocab_size
token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
# when passed to num_beams hypotheses, continue
if i >= num_beams:
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
continue
generated_hyps[batch_idx].add(
input_ids[effective_beam_id].clone(), score.item(),
input_ids[effective_beam_id].clone(), beam_token_score.item(),
)
else:
# add next predicted word if it is not eos_token
next_sent_beam.append((score, token_id, effective_beam_id))
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
# the beam for next step is full
if len(next_sent_beam) == num_beams: