cleanup tf unittests: part 2 (#6260)
* cleanup torch unittests: part 2 * remove trailing comma added by isort, and which breaks flake * one more comma * revert odd balls * part 3: odd cases * more ["key"] -> .key refactoring * .numpy() is not needed * more unncessary .numpy() removed * more simplification
This commit is contained in:
@@ -192,8 +192,8 @@ class XLNetModelTester:
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_model_use_cache(
|
||||
@@ -305,22 +305,22 @@ class XLNetModelTester:
|
||||
|
||||
result1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
|
||||
|
||||
result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1["mems"])
|
||||
result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1.mems)
|
||||
|
||||
_ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
|
||||
|
||||
self.parent.assertEqual(result1.loss.shape, ())
|
||||
self.parent.assertEqual(result1.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result1["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result1.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
self.parent.assertEqual(result2.loss.shape, ())
|
||||
self.parent.assertEqual(result2.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result2["mems"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result2.mems],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_qa(
|
||||
@@ -378,8 +378,8 @@ class XLNetModelTester:
|
||||
)
|
||||
self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_token_classif(
|
||||
@@ -407,8 +407,8 @@ class XLNetModelTester:
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.type_sequence_label_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_sequence_classif(
|
||||
@@ -436,8 +436,8 @@ class XLNetModelTester:
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
||||
Reference in New Issue
Block a user