Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive Embeddings in TransfoXL (#3286)
* first commit * work in progress * make language generation task pass * update to working version for LM * delete print * remove dead code * make style
This commit is contained in:
committed by
GitHub
parent
efdb46b6e2
commit
292186a3e7
@@ -129,10 +129,10 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def check_transfo_xl_model_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -166,7 +166,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def check_transfo_xl_lm_head_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -175,7 +175,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
|
||||
Reference in New Issue
Block a user