fix model loading
This commit is contained in:
@@ -308,7 +308,8 @@ def main():
|
||||
input_ids, input_mask, segment_ids, label_ids = batch
|
||||
|
||||
# define a new function to compute loss values for both output_modes
|
||||
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
|
||||
ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
|
||||
loss =
|
||||
|
||||
if output_mode == "classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
|
||||
Reference in New Issue
Block a user