Upgrade examples to pl=0.8.1(#5146)

This commit is contained in:
Sam Shleifer
2020-06-22 20:40:10 -04:00
committed by GitHub
parent 06b60c8b05
commit f5c2a122e3
11 changed files with 53 additions and 150 deletions

View File

@@ -108,7 +108,7 @@ class GLUETransformer(BaseTransformer):
return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}
def _eval_end(self, outputs):
def _eval_end(self, outputs) -> tuple:
val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
preds = np.concatenate([x["pred"] for x in outputs], axis=0)
@@ -132,20 +132,14 @@ class GLUETransformer(BaseTransformer):
logs = ret["log"]
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
def test_epoch_end(self, outputs):
# updating to test_epoch_end instead of deprecated test_end
def test_epoch_end(self, outputs) -> dict:
ret, predictions, targets = self._eval_end(outputs)
# Converting to the dic required by pl
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
# pytorch_lightning/trainer/logging.py#L139
logs = ret["log"]
# `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
@staticmethod
def add_model_specific_args(parser, root_dir):
# Add NER specific options
BaseTransformer.add_model_specific_args(parser, root_dir)
parser.add_argument(
"--max_seq_length",