Quick fix metrics evaluation on run_classif_pytorch
This commit is contained in:
@@ -425,7 +425,7 @@ def input_fn_builder(features, seq_length, train_batch_size):
|
|||||||
|
|
||||||
def accuracy(out, labels):
|
def accuracy(out, labels):
|
||||||
outputs = np.argmax(out, axis=1)
|
outputs = np.argmax(out, axis=1)
|
||||||
return np.sum(outputs==labels)/float(labels.size)
|
return np.sum(outputs==labels)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
processors = {
|
processors = {
|
||||||
@@ -491,6 +491,7 @@ def main():
|
|||||||
t_total=num_train_steps)
|
t_total=num_train_steps)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
total_tr_loss = 0
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_features = convert_examples_to_features(
|
train_features = convert_examples_to_features(
|
||||||
train_examples, label_list, args.max_seq_length, tokenizer)
|
train_examples, label_list, args.max_seq_length, tokenizer)
|
||||||
@@ -512,6 +513,7 @@ def main():
|
|||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
nb_tr_examples = 0
|
||||||
for epoch in range(int(args.num_train_epochs)):
|
for epoch in range(int(args.num_train_epochs)):
|
||||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
@@ -520,6 +522,8 @@ def main():
|
|||||||
label_ids = label_ids.to(device)
|
label_ids = label_ids.to(device)
|
||||||
|
|
||||||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
||||||
|
total_tr_loss += loss.item()
|
||||||
|
nb_tr_examples += input_ids.size(0)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
@@ -572,7 +576,7 @@ def main():
|
|||||||
result = {'eval_loss': eval_loss,
|
result = {'eval_loss': eval_loss,
|
||||||
'eval_accuracy': eval_accuracy,
|
'eval_accuracy': eval_accuracy,
|
||||||
'global_step': global_step,
|
'global_step': global_step,
|
||||||
'loss': loss.item()}
|
'loss': total_tr_loss/nb_tr_examples}#'loss': loss.item()}
|
||||||
|
|
||||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
|
|||||||
Reference in New Issue
Block a user