Model output test (#6155)

* Use return_dict=True in all tests

* Formatting
This commit is contained in:
Sylvain Gugger
2020-07-31 09:44:37 -04:00
committed by GitHub
parent 86caab1e0b
commit d951c14ae4
26 changed files with 320 additions and 765 deletions

View File

@@ -83,6 +83,7 @@ class T5ModelTester:
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
return_dict=True,
)
return (
@@ -136,13 +137,17 @@ class T5ModelTester:
model = T5Model(config=config)
model.to(torch_device)
model.eval()
decoder_output, decoder_past, encoder_output = model(
result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
decoder_output, decoder_past, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
decoder_output = result["last_hidden_state"]
decoder_past = result["decoder_past_key_values"]
encoder_output = result["encoder_last_hidden_state"]
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
self.parent.assertEqual(len(decoder_past), 2)
@@ -162,10 +167,9 @@ class T5ModelTester:
decoder_attention_mask=decoder_attention_mask,
labels=lm_labels,
)
loss, prediction_scores, _, _ = outputs
self.parent.assertEqual(len(outputs), 4)
self.parent.assertEqual(prediction_scores.size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
self.parent.assertEqual(loss.size(), ())
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
self.parent.assertEqual(outputs["loss"].size(), ())
def create_and_check_t5_decoder_model_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
@@ -179,7 +183,7 @@ class T5ModelTester:
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past_key_value_states = outputs
output, past_key_value_states = outputs.to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -187,8 +191,8 @@ class T5ModelTester:
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)[0]
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
@@ -212,7 +216,7 @@ class T5ModelTester:
attn_mask[:, half_seq_length:] = 0
# first forward pass
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True)
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -229,8 +233,10 @@ class T5ModelTester:
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[0]
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
@@ -256,7 +262,7 @@ class T5ModelTester:
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config).to(torch_device).half().eval()
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0]
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
def prepare_config_and_inputs_for_common(self):