Refactor variable naming and bring TF generate in line with torch generate
This commit is contained in:
@@ -722,6 +722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
@@ -852,7 +853,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
cur_len = 1
|
||||
self.model.decoder.generation_mode = True
|
||||
|
||||
# put model in generation mode if it has one
|
||||
if hasattr(self.model, "generation_mode"):
|
||||
self.model.decoder.generation_mode = True
|
||||
else:
|
||||
encoder_inputs = None
|
||||
cur_len = input_ids.shape[-1]
|
||||
@@ -860,44 +864,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if num_beams > 1:
|
||||
output = self._generate_beam_search(
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
min_length,
|
||||
do_sample,
|
||||
early_stopping,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
effective_batch_size,
|
||||
num_return_sequences,
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
encoder_inputs,
|
||||
attention_mask,
|
||||
cur_len=cur_len,
|
||||
max_length=max_length,
|
||||
min_length=min_length,
|
||||
do_sample=do_sample,
|
||||
early_stopping=early_stopping,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
batch_size=effective_batch_size,
|
||||
num_return_sequences=num_return_sequences,
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
vocab_size=vocab_size,
|
||||
encoder_inputs=encoder_inputs,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
output = self._generate_no_beam_search(
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
min_length,
|
||||
do_sample,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
no_repeat_ngram_size,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
effective_batch_size,
|
||||
encoder_inputs,
|
||||
attention_mask,
|
||||
cur_len=cur_len,
|
||||
max_length=max_length,
|
||||
min_length=min_length,
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
batch_size=effective_batch_size,
|
||||
encoder_inputs=encoder_inputs,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -1157,24 +1161,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_sent_beam = []
|
||||
|
||||
# next tokens for this sentence
|
||||
for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
|
||||
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
|
||||
zip(next_tokens[batch_idx], next_scores[batch_idx])
|
||||
):
|
||||
# get beam and word IDs
|
||||
beam_id = idx // vocab_size
|
||||
token_id = idx % vocab_size
|
||||
beam_id = beam_token_id // vocab_size
|
||||
token_id = beam_token_id % vocab_size
|
||||
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
|
||||
# when passed to num_beams hypotheses, continue
|
||||
if i >= num_beams:
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
continue
|
||||
generated_hyps[batch_idx].add(
|
||||
input_ids[effective_beam_id].clone(), score.item(),
|
||||
input_ids[effective_beam_id].clone(), beam_token_score.item(),
|
||||
)
|
||||
else:
|
||||
# add next predicted word if it is not eos_token
|
||||
next_sent_beam.append((score, token_id, effective_beam_id))
|
||||
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
||||
|
||||
# the beam for next step is full
|
||||
if len(next_sent_beam) == num_beams:
|
||||
|
||||
Reference in New Issue
Block a user