pruning in bertology

This commit is contained in:
thomwolf
2019-06-19 15:25:49 +02:00
parent dc8e0019b7
commit 34d706a0e1
9 changed files with 137 additions and 66 deletions

View File

@@ -308,7 +308,7 @@ def main():
input_ids, input_mask, segment_ids, label_ids = batch
# define a new function to compute loss values for both output_modes
logits = model(input_ids, segment_ids, input_mask)
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
if output_mode == "classification":
loss_fct = CrossEntropyLoss()
@@ -422,7 +422,7 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
logits = model(input_ids, segment_ids, input_mask)
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
# create eval loss and other metric required by the task
if output_mode == "classification":
@@ -503,7 +503,7 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
logits = model(input_ids, segment_ids, input_mask, labels=None)
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))