Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive Embeddings in TransfoXL (#3286)

* first commit

* work in progress

* make language generation task pass

* update to working version for LM

* delete print

* remove dead code

* make style
This commit is contained in:
Patrick von Platen
2020-03-18 14:24:27 +01:00
committed by GitHub
parent efdb46b6e2
commit 292186a3e7
7 changed files with 79 additions and 147 deletions

View File

@@ -129,10 +129,10 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
def check_transfo_xl_model_output(self, result):
self.parent.assertListEqual(
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size],
)
self.parent.assertListEqual(
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
@@ -166,7 +166,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
def check_transfo_xl_lm_head_output(self, result):
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
@@ -175,7 +175,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),