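Transpose the inputs of Transformer-XL (and the returned hidden states, losses and logits) from (seq_length, batch_size) to batch-first (batch_size, seq_length) ordering to have a unified interface; the mems keep their (mem_len, batch_size, d_model) shape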

This commit is contained in:
thomwolf
2019-02-11 13:19:59 +01:00
parent 32fea876bb
commit 884ca81d87
4 changed files with 57 additions and 44 deletions

View File

@@ -67,12 +67,12 @@ class TransfoXLModelTest(unittest.TestCase):
self.seed = seed
def prepare_config_and_inputs(self):
input_ids_1 = TransfoXLModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
input_ids_2 = TransfoXLModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
lm_labels = None
if self.use_labels:
lm_labels = TransfoXLModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = TransfoXLConfig(
vocab_size_or_config_json_file=self.vocab_size,
@@ -110,13 +110,13 @@ class TransfoXLModelTest(unittest.TestCase):
def check_transfo_xl_model_output(self, result):
self.parent.assertListEqual(
list(result["hidden_states_1"].size()),
[self.seq_length, self.batch_size, self.d_model])
[self.batch_size, self.seq_length, self.d_model])
self.parent.assertListEqual(
list(result["hidden_states_2"].size()),
[self.batch_size, self.seq_length, self.d_model])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(result["hidden_states_2"].size()),
[self.seq_length, self.batch_size, self.d_model])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
@@ -147,13 +147,13 @@ class TransfoXLModelTest(unittest.TestCase):
def check_transfo_xl_lm_head_output(self, result):
self.parent.assertListEqual(
list(result["loss_1"].size()),
[self.seq_length, self.batch_size])
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["lm_logits_1"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1a"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(result["lm_logits_1"].size()),
[self.seq_length, self.batch_size, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1b"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
@@ -163,13 +163,13 @@ class TransfoXLModelTest(unittest.TestCase):
self.parent.assertListEqual(
list(result["loss_2"].size()),
[self.seq_length, self.batch_size])
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["lm_logits_2"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2a"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(result["lm_logits_2"].size()),
[self.seq_length, self.batch_size, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2b"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)