[WIP] Lightning glue example (#3290)
* ✨ Alter base pl transformer to use automodels * 🐛 Add batch size env variable to function call * 💄 Apply black code style from Makefile * 🚚 Move lightning base out of ner directory * ✨ Add lightning glue example * 💄 self * move _feature_file to base class * ✨ Move eval logging to custom callback * 💄 Apply black code style * 🐛 Add parent to pythonpath, remove copy command * 🐛 Add missing max_length kwarg
This commit is contained in:
@@ -21,11 +21,13 @@ class NERTransformer(BaseTransformer):
|
||||
A training module for NER. See BaseTransformer for the core options.
|
||||
"""
|
||||
|
||||
mode = "token-classification"
|
||||
|
||||
def __init__(self, hparams):
|
||||
self.labels = get_labels(hparams.labels)
|
||||
num_labels = len(self.labels)
|
||||
self.pad_token_label_id = CrossEntropyLoss().ignore_index
|
||||
super(NERTransformer, self).__init__(hparams, num_labels)
|
||||
super(NERTransformer, self).__init__(hparams, num_labels, self.mode)
|
||||
|
||||
def forward(self, **inputs):
|
||||
return self.model(**inputs)
|
||||
@@ -38,21 +40,11 @@ class NERTransformer(BaseTransformer):
|
||||
batch[2] if self.hparams.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM and RoBERTa don"t use segment_ids
|
||||
|
||||
outputs = self.forward(**inputs)
|
||||
outputs = self(**inputs)
|
||||
loss = outputs[0]
|
||||
tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
|
||||
return {"loss": loss, "log": tensorboard_logs}
|
||||
|
||||
def _feature_file(self, mode):
|
||||
return os.path.join(
|
||||
self.hparams.data_dir,
|
||||
"cached_{}_{}_{}".format(
|
||||
mode,
|
||||
list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
|
||||
str(self.hparams.max_seq_length),
|
||||
),
|
||||
)
|
||||
|
||||
def prepare_data(self):
|
||||
"Called to initialize data. Use the call to construct features"
|
||||
args = self.hparams
|
||||
@@ -100,7 +92,7 @@ class NERTransformer(BaseTransformer):
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if self.hparams.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM and RoBERTa don"t use segment_ids
|
||||
outputs = self.forward(**inputs)
|
||||
outputs = self(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
@@ -130,14 +122,8 @@ class NERTransformer(BaseTransformer):
|
||||
"f1": f1_score(out_label_list, preds_list),
|
||||
}
|
||||
|
||||
if self.is_logger():
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(results.keys()):
|
||||
logger.info(" %s = %s", key, str(results[key]))
|
||||
|
||||
tensorboard_logs = results
|
||||
ret = {k: v for k, v in results.items()}
|
||||
ret["log"] = tensorboard_logs
|
||||
ret["log"] = results
|
||||
return ret, preds_list, out_label_list
|
||||
|
||||
def validation_end(self, outputs):
|
||||
@@ -151,32 +137,7 @@ class NERTransformer(BaseTransformer):
|
||||
# updating to test_epoch_end instead of deprecated test_end
|
||||
ret, predictions, targets = self._eval_end(outputs)
|
||||
|
||||
if self.is_logger():
|
||||
# Write output to a file:
|
||||
# Save results
|
||||
output_test_results_file = os.path.join(self.hparams.output_dir, "test_results.txt")
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
for key in sorted(ret.keys()):
|
||||
if key != "log":
|
||||
writer.write("{} = {}\n".format(key, str(ret[key])))
|
||||
# Save predictions
|
||||
output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
|
||||
with open(output_test_predictions_file, "w") as writer:
|
||||
with open(os.path.join(self.hparams.data_dir, "test.txt"), "r") as f:
|
||||
example_id = 0
|
||||
for line in f:
|
||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||
writer.write(line)
|
||||
if not predictions[example_id]:
|
||||
example_id += 1
|
||||
elif predictions[example_id]:
|
||||
output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n"
|
||||
writer.write(output_line)
|
||||
else:
|
||||
logger.warning(
|
||||
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
|
||||
)
|
||||
# Converting to the dic required by pl
|
||||
# Converting to the dict required by pl
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
|
||||
# pytorch_lightning/trainer/logging.py#L139
|
||||
logs = ret["log"]
|
||||
@@ -230,6 +191,6 @@ if __name__ == "__main__":
|
||||
# pl use this format to create a checkpoint:
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
|
||||
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
|
||||
checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpointepoch=*.ckpt", recursive=True)))
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
|
||||
NERTransformer.load_from_checkpoint(checkpoints[-1])
|
||||
trainer.test(model)
|
||||
|
||||
Reference in New Issue
Block a user