Model output test (#6155)
* Use return_dict=True in all tests * Formatting
This commit is contained in:
@@ -165,6 +165,7 @@ class ReformerModelTester:
|
||||
attn_layers=self.attn_layers,
|
||||
pad_token_id=self.pad_token_id,
|
||||
hash_seed=self.hash_seed,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -181,15 +182,12 @@ class ReformerModelTester:
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, _ = model(input_ids, attention_mask=input_mask)
|
||||
sequence_output, _ = model(input_ids)
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
}
|
||||
# 2 * hidden_size because we use reversible resnet layers
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
|
||||
)
|
||||
|
||||
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
|
||||
@@ -198,7 +196,7 @@ class ReformerModelTester:
|
||||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
|
||||
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)["loss"]
|
||||
loss.backward()
|
||||
|
||||
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
|
||||
@@ -207,13 +205,9 @@ class ReformerModelTester:
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
@@ -222,13 +216,9 @@ class ReformerModelTester:
|
||||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
@@ -325,7 +315,7 @@ class ReformerModelTester:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
|
||||
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
|
||||
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
|
||||
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
@@ -408,7 +398,7 @@ class ReformerModelTester:
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
model.eval()
|
||||
output = model(input_ids, attention_mask=input_mask)[0]
|
||||
output = model(input_ids, attention_mask=input_mask)["last_input_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
|
||||
@@ -444,21 +434,16 @@ class ReformerModelTester:
|
||||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
output_logits = model(input_ids, attention_mask=input_mask)[0]
|
||||
output_logits = model(input_ids, attention_mask=input_mask)["logits"]
|
||||
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
|
||||
|
||||
def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
@@ -474,11 +459,11 @@ class ReformerModelTester:
|
||||
input_ids_second = input_ids[:, -1:]
|
||||
|
||||
# return saved cache
|
||||
_, past_buckets_states = model(input_ids_first, use_cache=True)
|
||||
past_buckets_states = model(input_ids_first, use_cache=True)["past_buckets_states"]
|
||||
|
||||
# calculate last output with and without cache
|
||||
outputs_with_cache, _ = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)
|
||||
outputs_without_cache = model(input_ids)[0][:, -1]
|
||||
outputs_with_cache = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)["logits"]
|
||||
outputs_without_cache = model(input_ids)["logits"][:, -1]
|
||||
|
||||
# select random slice idx
|
||||
random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item()
|
||||
@@ -504,11 +489,7 @@ class ReformerModelTester:
|
||||
model = ReformerForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user