transposing the inputs of Transformer-XL to have a unified interface
This commit is contained in:
@@ -67,12 +67,12 @@ class TransfoXLModelTest(unittest.TestCase):
|
||||
self.seed = seed
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids_1 = TransfoXLModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
|
||||
input_ids_2 = TransfoXLModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
|
||||
input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
lm_labels = None
|
||||
if self.use_labels:
|
||||
lm_labels = TransfoXLModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
|
||||
lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
config = TransfoXLConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
@@ -110,13 +110,13 @@ class TransfoXLModelTest(unittest.TestCase):
|
||||
def check_transfo_xl_model_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_1"].size()),
|
||||
[self.seq_length, self.batch_size, self.d_model])
|
||||
[self.batch_size, self.seq_length, self.d_model])
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_2"].size()),
|
||||
[self.batch_size, self.seq_length, self.d_model])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_2"].size()),
|
||||
[self.seq_length, self.batch_size, self.d_model])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
@@ -147,13 +147,13 @@ class TransfoXLModelTest(unittest.TestCase):
|
||||
def check_transfo_xl_lm_head_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss_1"].size()),
|
||||
[self.seq_length, self.batch_size])
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1a"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()),
|
||||
[self.seq_length, self.batch_size, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1b"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
@@ -163,13 +163,13 @@ class TransfoXLModelTest(unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss_2"].size()),
|
||||
[self.seq_length, self.batch_size])
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2a"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()),
|
||||
[self.seq_length, self.batch_size, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2b"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
|
||||
Reference in New Issue
Block a user