Compute loss method (#7074)

This commit is contained in:
Sylvain Gugger
2020-09-11 12:06:31 -04:00
committed by GitHub
parent ae736163d0
commit 4cbd50e611
2 changed files with 28 additions and 8 deletions

View File

@@ -1024,15 +1024,9 @@ class Trainer:
if self.args.fp16 and _use_native_amp:
with autocast():
outputs = model(**inputs)
loss = outputs[0]
loss = self.compute_loss(model, inputs)
else:
outputs = model(**inputs)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs[0]
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
@@ -1050,6 +1044,19 @@ class Trainer:
return loss.detach()
def compute_loss(self, model, inputs):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
outputs = model(**inputs)
# Save past state if it exists
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs[0]
def is_local_master(self) -> bool:
"""
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on