Flos fix (#7384)
This commit is contained in:
@@ -695,7 +695,7 @@ class Trainer:
|
||||
# set global_step to global_step of last saved checkpoint from model path
|
||||
try:
|
||||
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
||||
self.total_flos = getattr(model.config, "total_flos", 0)
|
||||
self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
|
||||
|
||||
epochs_trained = self.global_step // num_update_steps_per_epoch
|
||||
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
|
||||
@@ -1448,15 +1448,29 @@ class Trainer:
|
||||
:obj:`int`: The number of floating-point operations.
|
||||
"""
|
||||
|
||||
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
|
||||
self.model, torch.nn.parallel.DistributedDataParallel
|
||||
):
|
||||
model = self.model.module
|
||||
else:
|
||||
model = self.model
|
||||
model = self._actual_model(self.model)
|
||||
|
||||
if hasattr(model, "floating_point_ops"):
|
||||
return model.floating_point_ops(inputs)
|
||||
|
||||
else:
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _actual_model(
|
||||
model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
|
||||
) -> torch.nn.modules.Module:
|
||||
"""
|
||||
|
||||
Args:
|
||||
model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
|
||||
Model object used during training
|
||||
|
||||
Returns:
|
||||
:obj:`torch.nn.modules.Module`: unwrapped module
|
||||
"""
|
||||
if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
model = model.module
|
||||
else:
|
||||
model = model
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user