Fix TF Causal LM models' returned logits (#15256)
* Fix TF Causal LM models' returned logits * Fix expected shape in the tests Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -240,7 +240,7 @@ class TFEncoderDecoderMixin:
|
||||
assert "loss" in outputs_encoder_decoder
|
||||
|
||||
batch_size, seq_len = decoder_input_ids.shape
|
||||
expected_shape = (batch_size, seq_len - 1, decoder_config.vocab_size)
|
||||
expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
|
||||
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
|
||||
Reference in New Issue
Block a user