Fix PABEE division by zero error (#5233)
* Fix PABEE division by zero error * patience=0 by default
This commit is contained in:
@@ -254,12 +254,15 @@ def train(args, train_dataset, model, tokenizer):
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
def evaluate(args, model, tokenizer, prefix="", patience=0):
|
||||
|
||||
# PABEE STATS
|
||||
if args.model_type == "albert":
|
||||
model.albert.set_regression_threshold(args.regression_threshold)
|
||||
model.albert.set_patience(patience)
|
||||
model.albert.reset_stats()
|
||||
elif args.model_type == "bert":
|
||||
model.bert.set_regression_threshold(args.regression_threshold)
|
||||
model.bert.set_patience(patience)
|
||||
model.bert.reset_stats()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
@@ -331,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
print(" %s = %s" % (key, str(result[key])))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
if args.eval_all_checkpoints:
|
||||
if args.eval_all_checkpoints and patience != 0:
|
||||
if args.model_type == "albert":
|
||||
model.albert.log_stats()
|
||||
elif args.model_type == "bert":
|
||||
@@ -690,15 +693,7 @@ def main():
|
||||
|
||||
print(f"Evaluation for checkpoint {prefix}")
|
||||
for patience in patience_list:
|
||||
if args.model_type == "albert":
|
||||
model.albert.set_regression_threshold(args.regression_threshold)
|
||||
model.albert.set_patience(patience)
|
||||
elif args.model_type == "bert":
|
||||
model.bert.set_regression_threshold(args.regression_threshold)
|
||||
model.bert.set_patience(patience)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix, patience=patience)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user