@@ -60,7 +60,7 @@ from transformers import Trainer
|
|||||||
|
|
||||||
class CustomTrainer(Trainer):
|
class CustomTrainer(Trainer):
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
labels = inputs.get("labels")
|
labels = inputs.pop("labels")
|
||||||
# forward pass
|
# forward pass
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs.get("logits")
|
logits = outputs.get("logits")
|
||||||
|
|||||||
Reference in New Issue
Block a user