fix model loading

This commit is contained in:
thomwolf
2019-07-05 15:57:14 +02:00
parent 6dacc79d39
commit 162ba383b0
3 changed files with 28 additions and 4 deletions

View File

@@ -308,7 +308,8 @@ def main():
input_ids, input_mask, segment_ids, label_ids = batch
# define a new function to compute loss values for both output_modes
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
loss =
if output_mode == "classification":
loss_fct = CrossEntropyLoss()