Fix TF Causal LM models' returned logits (#15256)

* Fix TF Causal LM models' returned logits

* Fix expected shape in the tests

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-02-01 12:04:07 +01:00
committed by GitHub
parent af5c3329d7
commit dc05dd539f
10 changed files with 18 additions and 18 deletions

View File

@@ -1262,9 +1262,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
if inputs["labels"] is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=logits)
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]