[T5, Tests] Add extensive hard-coded integration tests and make sure PT and TF give equal results (#3550)
* add some t5 integration tests * finish summarization and translation integration tests for T5 - results look good * add tf test * fix == vs is bug * fix tf beam search error and make tf t5 tests pass
This commit is contained in:
committed by
GitHub
parent
8538ce9044
commit
b815edf69f
@@ -960,6 +960,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
.to(input_ids.device)
|
||||
)
|
||||
encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
|
||||
|
||||
else:
|
||||
encoder_outputs = None
|
||||
cur_len = input_ids.shape[-1]
|
||||
@@ -1284,14 +1285,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
|
||||
zip(next_tokens[batch_idx], next_scores[batch_idx])
|
||||
):
|
||||
# get beam and word IDs
|
||||
# get beam and token IDs
|
||||
beam_id = beam_token_id // vocab_size
|
||||
token_id = beam_token_id % vocab_size
|
||||
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (token_id.item() is eos_token_id):
|
||||
# add to generated hypotheses if end of sentence or last iteration
|
||||
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
@@ -1300,7 +1300,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids[effective_beam_id].clone(), beam_token_score.item(),
|
||||
)
|
||||
else:
|
||||
# add next predicted word if it is not eos_token
|
||||
# add next predicted token if it is not eos_token
|
||||
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
||||
|
||||
# the beam for next step is full
|
||||
@@ -1330,7 +1330,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# re-order batch
|
||||
input_ids = input_ids[beam_idx, :]
|
||||
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
|
||||
|
||||
# re-order internal states
|
||||
if past is not None:
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
Reference in New Issue
Block a user