[s2s] only save metrics.json from rank zero (#7331)
This commit is contained in:
@@ -30,7 +30,6 @@ from utils import (
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
@@ -189,7 +188,7 @@ class SummarizationModule(BaseTransformer):
|
||||
losses.update(generative_metrics)
|
||||
all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
all_metrics["step_count"] = self.step_count
|
||||
self.save_metrics(all_metrics, prefix) # writes to self.metrics_save_path
|
||||
self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {
|
||||
"log": all_metrics,
|
||||
@@ -198,10 +197,6 @@ class SummarizationModule(BaseTransformer):
|
||||
f"{prefix}_{self.val_metric}": metric_tensor,
|
||||
}
|
||||
|
||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||
self.metrics[type_path].append(latest_metrics)
|
||||
save_json(self.metrics, self.metrics_save_path)
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_rouge(preds, target)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user