NER - pl example (#3180)

* 1. seqeval required by ner pl example. install from examples/requirements. 2. unrecognized arguments: save_steps

* pl checkpoint callback filenotfound error: make directory and pass

* #3159 pl checkpoint path difference

* 1. Updated Readme for pl 2. pl script now also correct displays logs 3. pass gpu ids compared to number of gpus

* Updated results in readme

* 1. updated readme 2. removing deprecated pl methods 3. finalizing scripts

* comment length check

* using deprecated validation_end for stable results

* style related changes
This commit is contained in:
Shubham Agarwal
2020-03-10 00:43:38 +00:00
committed by GitHub
parent f51ba059b9
commit 5ca356a464
3 changed files with 34 additions and 9 deletions

View File

@@ -141,10 +141,14 @@ class NERTransformer(BaseTransformer):
return ret, preds_list, out_label_list
def validation_end(self, outputs):
# todo: update to validation_epoch_end instead of deprecated validation_end
# when stable
ret, preds, targets = self._eval_end(outputs)
return ret
logs = ret["log"]
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
def test_end(self, outputs):
def test_epoch_end(self, outputs):
# updating to test_epoch_end instead of deprecated test_end
ret, predictions, targets = self._eval_end(outputs)
if self.is_logger():
@@ -172,7 +176,12 @@ class NERTransformer(BaseTransformer):
logger.warning(
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
)
return ret
# Converting to the dic required by pl
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
# pytorch_lightning/trainer/logging.py#L139
logs = ret["log"]
# `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
@staticmethod
def add_model_specific_args(parser, root_dir):
@@ -217,6 +226,10 @@ if __name__ == "__main__":
trainer = generic_train(model, args)
if args.do_predict:
checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpoint_*.ckpt", recursive=True)))
# See https://github.com/huggingface/transformers/issues/3159
# pl use this format to create a checkpoint:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpointepoch=*.ckpt", recursive=True)))
NERTransformer.load_from_checkpoint(checkpoints[-1])
trainer.test(model)