Fix TF Causal LM models' returned logits (#15256)
* Fix TF Causal LM models' returned logits * Fix expected shape in the tests Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1542,9 +1542,9 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
# shift labels to the left and cut last logit token
|
# shift labels to the left and cut last logit token
|
||||||
logits = logits[:, :-1]
|
shifted_logits = logits[:, :-1]
|
||||||
labels = inputs["labels"][:, 1:]
|
labels = inputs["labels"][:, 1:]
|
||||||
loss = self.hf_compute_loss(labels=labels, logits=logits)
|
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -735,9 +735,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
loss = None
|
loss = None
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
# shift labels to the left and cut last logit token
|
# shift labels to the left and cut last logit token
|
||||||
logits = logits[:, :-1]
|
shifted_logits = logits[:, :-1]
|
||||||
labels = inputs["labels"][:, 1:]
|
labels = inputs["labels"][:, 1:]
|
||||||
loss = self.hf_compute_loss(labels, logits)
|
loss = self.hf_compute_loss(labels, shifted_logits)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
output = (logits,) + transformer_outputs[1:]
|
output = (logits,) + transformer_outputs[1:]
|
||||||
|
|||||||
@@ -949,9 +949,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
loss = None
|
loss = None
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
# shift labels to the left and cut last logit token
|
# shift labels to the left and cut last logit token
|
||||||
logits = logits[:, :-1]
|
shifted_logits = logits[:, :-1]
|
||||||
labels = inputs["labels"][:, 1:]
|
labels = inputs["labels"][:, 1:]
|
||||||
loss = self.hf_compute_loss(labels, logits)
|
loss = self.hf_compute_loss(labels, shifted_logits)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
output = (logits,) + transformer_outputs[1:]
|
output = (logits,) + transformer_outputs[1:]
|
||||||
|
|||||||
@@ -656,9 +656,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
|
|||||||
loss = None
|
loss = None
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
# shift labels to the left and cut last logit token
|
# shift labels to the left and cut last logit token
|
||||||
logits = logits[:, :-1]
|
shifted_logits = logits[:, :-1]
|
||||||
labels = inputs["labels"][:, 1:]
|
labels = inputs["labels"][:, 1:]
|
||||||
loss = self.hf_compute_loss(labels, logits)
|
loss = self.hf_compute_loss(labels, shifted_logits)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
output = (logits,) + transformer_outputs[1:]
|
output = (logits,) + transformer_outputs[1:]
|
||||||
|
|||||||
@@ -1275,9 +1275,9 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
# shift labels to the left and cut last logit token
|
# shift labels to the left and cut last logit token
|
||||||
logits = logits[:, :-1]
|
shifted_logits = logits[:, :-1]
|
||||||
labels = inputs["labels"][:, 1:]
|
labels = inputs["labels"][:, 1:]
|
||||||
loss = self.hf_compute_loss(labels=labels, logits=logits)
|
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1310,9 +1310,9 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
# shift labels to the left and cut last logit token
|
# shift labels to the left and cut last logit token
|
||||||
logits = logits[:, :-1]
|
shifted_logits = logits[:, :-1]
|
||||||
labels = inputs["labels"][:, 1:]
|
labels = inputs["labels"][:, 1:]
|
||||||
loss = self.hf_compute_loss(labels=labels, logits=logits)
|
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1035,9 +1035,9 @@ class TFRoFormerForCausalLM(TFRoFormerPreTrainedModel, TFCausalLanguageModelingL
|
|||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
# shift labels to the left and cut last logit token
|
# shift labels to the left and cut last logit token
|
||||||
logits = logits[:, :-1]
|
shifted_logits = logits[:, :-1]
|
||||||
labels = inputs["labels"][:, 1:]
|
labels = inputs["labels"][:, 1:]
|
||||||
loss = self.hf_compute_loss(labels=labels, logits=logits)
|
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1262,9 +1262,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
|||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
# shift labels to the left and cut last logit token
|
# shift labels to the left and cut last logit token
|
||||||
logits = logits[:, :-1]
|
shifted_logits = logits[:, :-1]
|
||||||
labels = inputs["labels"][:, 1:]
|
labels = inputs["labels"][:, 1:]
|
||||||
loss = self.hf_compute_loss(labels=labels, logits=logits)
|
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -240,7 +240,7 @@ class TFEncoderDecoderMixin:
|
|||||||
assert "loss" in outputs_encoder_decoder
|
assert "loss" in outputs_encoder_decoder
|
||||||
|
|
||||||
batch_size, seq_len = decoder_input_ids.shape
|
batch_size, seq_len = decoder_input_ids.shape
|
||||||
expected_shape = (batch_size, seq_len - 1, decoder_config.vocab_size)
|
expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
|
||||||
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
|
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
self.assertIn("loss", outputs_encoder_decoder)
|
self.assertIn("loss", outputs_encoder_decoder)
|
||||||
|
|
||||||
batch_size, seq_len = decoder_input_ids.shape
|
batch_size, seq_len = decoder_input_ids.shape
|
||||||
expected_shape = (batch_size, seq_len - 1, decoder_config.vocab_size)
|
expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
|
||||||
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
|
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
|
||||||
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
|
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
|
||||||
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
|
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user