pruning in bertology
This commit is contained in:
@@ -308,7 +308,7 @@ def main():
|
||||
input_ids, input_mask, segment_ids, label_ids = batch
|
||||
|
||||
# define a new function to compute loss values for both output_modes
|
||||
logits = model(input_ids, segment_ids, input_mask)
|
||||
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
|
||||
|
||||
if output_mode == "classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
@@ -422,7 +422,7 @@ def main():
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_ids, segment_ids, input_mask)
|
||||
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
|
||||
|
||||
# create eval loss and other metric required by the task
|
||||
if output_mode == "classification":
|
||||
@@ -503,7 +503,7 @@ def main():
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_ids, segment_ids, input_mask, labels=None)
|
||||
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
|
||||
|
||||
loss_fct = CrossEntropyLoss()
|
||||
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
||||
|
||||
Reference in New Issue
Block a user