From 54e9ce785d1b2d3c7c6d4c1d4f5c5850f451ac52 Mon Sep 17 00:00:00 2001 From: Kevin Canwen Xu Date: Wed, 24 Jun 2020 16:10:36 +0800 Subject: [PATCH] Fix PABEE division by zero error (#5233) * Fix PABEE division by zero error * patience=0 by default --- .../run_glue_with_pabee.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/examples/bert-loses-patience/run_glue_with_pabee.py b/examples/bert-loses-patience/run_glue_with_pabee.py index 575d2aabae..1ef2b8488e 100755 --- a/examples/bert-loses-patience/run_glue_with_pabee.py +++ b/examples/bert-loses-patience/run_glue_with_pabee.py @@ -254,12 +254,15 @@ def train(args, train_dataset, model, tokenizer): return global_step, tr_loss / global_step -def evaluate(args, model, tokenizer, prefix=""): +def evaluate(args, model, tokenizer, prefix="", patience=0): - # PABEE STATS if args.model_type == "albert": + model.albert.set_regression_threshold(args.regression_threshold) + model.albert.set_patience(patience) model.albert.reset_stats() elif args.model_type == "bert": + model.bert.set_regression_threshold(args.regression_threshold) + model.bert.set_patience(patience) model.bert.reset_stats() else: raise NotImplementedError() @@ -331,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""): print(" %s = %s" % (key, str(result[key]))) writer.write("%s = %s\n" % (key, str(result[key]))) - if args.eval_all_checkpoints: + if args.eval_all_checkpoints and patience != 0: if args.model_type == "albert": model.albert.log_stats() elif args.model_type == "bert": @@ -690,15 +693,7 @@ def main(): print(f"Evaluation for checkpoint {prefix}") for patience in patience_list: - if args.model_type == "albert": - model.albert.set_regression_threshold(args.regression_threshold) - model.albert.set_patience(patience) - elif args.model_type == "bert": - model.bert.set_regression_threshold(args.regression_threshold) - model.bert.set_patience(patience) - else: - raise NotImplementedError() - result = evaluate(args, model, tokenizer, prefix=prefix) + result = evaluate(args, model, tokenizer, prefix=prefix, patience=patience) result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) results.update(result) return results