Updated the custom_models.md changed cross_entropy code (#33118)

This commit is contained in:
S M Jishanul Islam
2024-08-26 17:15:43 +06:00
committed by GitHub
parent 0a7af19f4d
commit 8defc95df3
7 changed files with 7 additions and 7 deletions

View File

@@ -185,7 +185,7 @@ class ResnetModelForImageClassification(PreTrainedModel):
def forward(self, tensor, labels=None):
logits = self.model(tensor)
if labels is not None:
loss = torch.nn.cross_entropy(logits, labels)
loss = torch.nn.functional.cross_entropy(logits, labels)
return {"loss": loss, "logits": logits}
return {"logits": logits}
```