Model output test (#6155)
* Use return_dict=True in all tests * Formatting
This commit is contained in:
@@ -83,6 +83,7 @@ class T5ModelTester:
|
||||
bos_token_id=self.pad_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -136,13 +137,17 @@ class T5ModelTester:
|
||||
model = T5Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
decoder_output, decoder_past, encoder_output = model(
|
||||
result = model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
decoder_output, decoder_past, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
decoder_output = result["last_hidden_state"]
|
||||
decoder_past = result["decoder_past_key_values"]
|
||||
encoder_output = result["encoder_last_hidden_state"]
|
||||
|
||||
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(decoder_past), 2)
|
||||
@@ -162,10 +167,9 @@ class T5ModelTester:
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
labels=lm_labels,
|
||||
)
|
||||
loss, prediction_scores, _, _ = outputs
|
||||
self.parent.assertEqual(len(outputs), 4)
|
||||
self.parent.assertEqual(prediction_scores.size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(loss.size(), ())
|
||||
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(outputs["loss"].size(), ())
|
||||
|
||||
def create_and_check_t5_decoder_model_past(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
@@ -179,7 +183,7 @@ class T5ModelTester:
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past_key_value_states = outputs
|
||||
output, past_key_value_states = outputs.to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
@@ -187,8 +191,8 @@ class T5ModelTester:
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)[0]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
@@ -212,7 +216,7 @@ class T5ModelTester:
|
||||
attn_mask[:, half_seq_length:] = 0
|
||||
|
||||
# first forward pass
|
||||
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
||||
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
@@ -229,8 +233,10 @@ class T5ModelTester:
|
||||
)
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[0]
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
@@ -256,7 +262,7 @@ class T5ModelTester:
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
model = T5Model(config=config).to(torch_device).half().eval()
|
||||
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0]
|
||||
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
||||
Reference in New Issue
Block a user