Shift labels internally within TransfoXLLMHeadModel when called with labels (#3716)

* Shifting labels inside TransfoXLLMHead

* Changed doc to reflect change

* Updated pytorch test

* removed IDE whitespace changes

* black reformat

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
Teven
2020-04-13 18:11:23 +02:00
committed by GitHub
parent 5ebd898953
commit 352d5472b0
3 changed files with 12 additions and 6 deletions

View File

@@ -164,7 +164,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
return outputs
def check_transfo_xl_lm_head_output(self, result):
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length - 1])
self.parent.assertListEqual(
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
@@ -173,7 +173,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length - 1])
self.parent.assertListEqual(
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)