Shift labels internally within TransfoXLLMHeadModel when called with labels (#3716)
* Shifting labels inside TransfoXLLMHead * Changed doc to reflect change * Updated pytorch test * removed IDE whitespace changes * black reformat Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
@@ -164,7 +164,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
return outputs
|
||||
|
||||
def check_transfo_xl_lm_head_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length - 1])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
@@ -173,7 +173,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length - 1])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user