[T5, Testst] Add extensive hard-coded integration tests and make sure PT and TF give equal results (#3550)

* add some t5 integration tests

* finish summarization and translation integration tests for T5 - results loook good

* add tf test

* fix == vs is bug

* fix tf beam search error and make tf t5 tests pass
This commit is contained in:
Patrick von Platen
2020-04-01 18:01:33 +02:00
committed by GitHub
parent 8538ce9044
commit b815edf69f
5 changed files with 326 additions and 36 deletions

View File

@@ -960,6 +960,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
.to(input_ids.device)
)
encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
else:
encoder_outputs = None
cur_len = input_ids.shape[-1]
@@ -1284,14 +1285,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx])
):
# get beam and word IDs
# get beam and token IDs
beam_id = beam_token_id // vocab_size
token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (token_id.item() is eos_token_id):
# add to generated hypotheses if end of sentence or last iteration
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
@@ -1300,7 +1300,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids[effective_beam_id].clone(), beam_token_score.item(),
)
else:
# add next predicted word if it is not eos_token
# add next predicted token if it is not eos_token
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
# the beam for next step is full
@@ -1330,7 +1330,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# re-order batch
input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
# re-order internal states
if past is not None:
past = self._reorder_cache(past, beam_idx)