Model output test (#6155)
* Use return_dict=True in all tests * Formatting
This commit is contained in:
@@ -75,6 +75,7 @@ class TransfoXLModelTester:
|
||||
div_val=self.div_val,
|
||||
n_layer=self.num_hidden_layers,
|
||||
eos_token_id=self.eos_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||
@@ -88,13 +89,13 @@ class TransfoXLModelTester:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_1, mems_1 = model(input_ids_1)
|
||||
hidden_states_2, mems_2 = model(input_ids_2, mems_1)
|
||||
outputs1 = model(input_ids_1)
|
||||
outputs2 = model(input_ids_2, outputs1["mems"])
|
||||
outputs = {
|
||||
"hidden_states_1": hidden_states_1,
|
||||
"mems_1": mems_1,
|
||||
"hidden_states_2": hidden_states_2,
|
||||
"mems_2": mems_2,
|
||||
"hidden_states_1": outputs1["last_hidden_state"],
|
||||
"mems_1": outputs1["mems"],
|
||||
"hidden_states_2": outputs2["last_hidden_state"],
|
||||
"mems_2": outputs2["mems"],
|
||||
}
|
||||
return outputs
|
||||
|
||||
@@ -119,17 +120,17 @@ class TransfoXLModelTester:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
lm_logits_1, mems_1 = model(input_ids_1)
|
||||
loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
|
||||
lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
|
||||
loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
|
||||
lm_logits_1 = model(input_ids_1)["prediction_scores"]
|
||||
outputs1 = model(input_ids_1, labels=lm_labels)
|
||||
lm_logits_2 = model(input_ids_2, mems=outputs1["mems"])["prediction_scores"]
|
||||
outputs2 = model(input_ids_2, labels=lm_labels, mems=outputs1["mems"])
|
||||
|
||||
outputs = {
|
||||
"loss_1": loss_1,
|
||||
"mems_1": mems_1,
|
||||
"loss_1": outputs1["losses"],
|
||||
"mems_1": outputs1["mems"],
|
||||
"lm_logits_1": lm_logits_1,
|
||||
"loss_2": loss_2,
|
||||
"mems_2": mems_2,
|
||||
"loss_2": outputs2["losses"],
|
||||
"mems_2": outputs2["mems"],
|
||||
"lm_logits_2": lm_logits_2,
|
||||
}
|
||||
return outputs
|
||||
|
||||
Reference in New Issue
Block a user