Fix quality due to ruff release

This commit is contained in:
Sylvain
2023-03-22 20:45:08 -04:00
parent 73fdc8c5b4
commit ef28df0572
28 changed files with 40 additions and 58 deletions

View File

@@ -170,7 +170,7 @@ class SummarizationModule(BaseTransformer):
def training_step(self, batch, batch_idx) -> Dict:
loss_tensors = self._step(batch)
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
logs = dict(zip(self.loss_names, loss_tensors))
# tokens per batch
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
logs["bs"] = batch["input_ids"].shape[0]
@@ -225,7 +225,7 @@ class SummarizationModule(BaseTransformer):
preds: List[str] = self.ids_to_clean_text(generated_ids)
target: List[str] = self.ids_to_clean_text(batch["labels"])
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
base_metrics = dict(zip(self.loss_names, loss_tensors))
rouge: Dict = self.calc_generative_metrics(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)