clean up glue example

This commit is contained in:
thomwolf
2019-07-05 16:31:13 +02:00
parent 162ba383b0
commit 1113f97f33
4 changed files with 423 additions and 17 deletions

View File

@@ -309,14 +309,7 @@ def main():
# define a new function to compute loss values for both output_modes
ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
loss =
if output_mode == "classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
elif output_mode == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), label_ids.view(-1))
loss = ouputs[0]
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
@@ -423,15 +416,8 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
# create eval loss and other metric required by the task
if output_mode == "classification":
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
elif output_mode == "regression":
loss_fct = MSELoss()
tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
outputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
tmp_eval_loss, logits = outputs[:2]
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1