cleanup tf unittests: part 2 (#6260)
* cleanup torch unittests: part 2 * remove trailing comma added by isort, and which breaks flake * one more comma * revert odd balls * part 3: odd cases * more ["key"] -> .key refactoring * .numpy() is not needed * more unncessary .numpy() removed * more simplification
This commit is contained in:
@@ -100,19 +100,15 @@ class TransfoXLModelTester:
|
||||
return outputs
|
||||
|
||||
def check_transfo_xl_model_output(self, result):
|
||||
self.parent.assertEqual(result["hidden_states_1"].shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result["hidden_states_2"].shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
[mem.shape for mem in result["mems_1"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result["mems_2"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
@@ -136,22 +132,18 @@ class TransfoXLModelTester:
|
||||
return outputs
|
||||
|
||||
def check_transfo_xl_lm_head_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length - 1])
|
||||
self.parent.assertEqual(result["loss_1"].shape, (self.batch_size, self.seq_length - 1))
|
||||
self.parent.assertEqual(result["lm_logits_1"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result["mems_1"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length - 1])
|
||||
self.parent.assertEqual(result["loss_2"].shape, (self.batch_size, self.seq_length - 1))
|
||||
self.parent.assertEqual(result["lm_logits_2"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result["mems_2"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
||||
Reference in New Issue
Block a user