NER - pl example (#3180)
* 1. seqeval required by ner pl example. install from examples/requirements. 2. unrecognized arguments: save_steps * pl checkpoint callback filenotfound error: make directory and pass * #3159 pl checkpoint path difference * 1. Updated Readme for pl 2. pl script now also correct displays logs 3. pass gpu ids compared to number of gpus * Updated results in readme * 1. updated readme 2. removing deprecated pl methods 3. finalizing scripts * comment length check * using deprecated validation_end for stable results * style related changes
This commit is contained in:
@@ -141,10 +141,14 @@ class NERTransformer(BaseTransformer):
|
||||
return ret, preds_list, out_label_list
|
||||
|
||||
def validation_end(self, outputs):
|
||||
# todo: update to validation_epoch_end instead of deprecated validation_end
|
||||
# when stable
|
||||
ret, preds, targets = self._eval_end(outputs)
|
||||
return ret
|
||||
logs = ret["log"]
|
||||
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
|
||||
|
||||
def test_end(self, outputs):
|
||||
def test_epoch_end(self, outputs):
|
||||
# updating to test_epoch_end instead of deprecated test_end
|
||||
ret, predictions, targets = self._eval_end(outputs)
|
||||
|
||||
if self.is_logger():
|
||||
@@ -172,7 +176,12 @@ class NERTransformer(BaseTransformer):
|
||||
logger.warning(
|
||||
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
|
||||
)
|
||||
return ret
|
||||
# Converting to the dic required by pl
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
|
||||
# pytorch_lightning/trainer/logging.py#L139
|
||||
logs = ret["log"]
|
||||
# `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
|
||||
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
@@ -217,6 +226,10 @@ if __name__ == "__main__":
|
||||
trainer = generic_train(model, args)
|
||||
|
||||
if args.do_predict:
|
||||
checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpoint_*.ckpt", recursive=True)))
|
||||
# See https://github.com/huggingface/transformers/issues/3159
|
||||
# pl use this format to create a checkpoint:
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
|
||||
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
|
||||
checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpointepoch=*.ckpt", recursive=True)))
|
||||
NERTransformer.load_from_checkpoint(checkpoints[-1])
|
||||
trainer.test(model)
|
||||
|
||||
Reference in New Issue
Block a user