Fix PABEE division by zero error (#5233)
* Fix PABEE division by zero error * patience=0 by default
This commit is contained in:
@@ -254,12 +254,15 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
return global_step, tr_loss / global_step
|
return global_step, tr_loss / global_step
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args, model, tokenizer, prefix=""):
|
def evaluate(args, model, tokenizer, prefix="", patience=0):
|
||||||
|
|
||||||
# PABEE STATS
|
|
||||||
if args.model_type == "albert":
|
if args.model_type == "albert":
|
||||||
|
model.albert.set_regression_threshold(args.regression_threshold)
|
||||||
|
model.albert.set_patience(patience)
|
||||||
model.albert.reset_stats()
|
model.albert.reset_stats()
|
||||||
elif args.model_type == "bert":
|
elif args.model_type == "bert":
|
||||||
|
model.bert.set_regression_threshold(args.regression_threshold)
|
||||||
|
model.bert.set_patience(patience)
|
||||||
model.bert.reset_stats()
|
model.bert.reset_stats()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -331,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
print(" %s = %s" % (key, str(result[key])))
|
print(" %s = %s" % (key, str(result[key])))
|
||||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||||
|
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints and patience != 0:
|
||||||
if args.model_type == "albert":
|
if args.model_type == "albert":
|
||||||
model.albert.log_stats()
|
model.albert.log_stats()
|
||||||
elif args.model_type == "bert":
|
elif args.model_type == "bert":
|
||||||
@@ -690,15 +693,7 @@ def main():
|
|||||||
|
|
||||||
print(f"Evaluation for checkpoint {prefix}")
|
print(f"Evaluation for checkpoint {prefix}")
|
||||||
for patience in patience_list:
|
for patience in patience_list:
|
||||||
if args.model_type == "albert":
|
result = evaluate(args, model, tokenizer, prefix=prefix, patience=patience)
|
||||||
model.albert.set_regression_threshold(args.regression_threshold)
|
|
||||||
model.albert.set_patience(patience)
|
|
||||||
elif args.model_type == "bert":
|
|
||||||
model.bert.set_regression_threshold(args.regression_threshold)
|
|
||||||
model.bert.set_patience(patience)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
|
||||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
return results
|
return results
|
||||||
|
|||||||
Reference in New Issue
Block a user