model fixes + ipnb fixes
This commit is contained in:
@@ -26,6 +26,7 @@ import json
|
||||
import re
|
||||
|
||||
import tokenization
|
||||
import torch
|
||||
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
@@ -251,10 +252,9 @@ def main():
|
||||
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
|
||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
|
||||
if args.local_rank == -1:
|
||||
eval_sampler = SequentialSampler(eval_data)
|
||||
else:
|
||||
@@ -263,12 +263,11 @@ def main():
|
||||
|
||||
model.eval()
|
||||
with open(args.output_file, "w", encoding='utf-8') as writer:
|
||||
for input_ids, input_mask, segment_ids, example_indices in eval_dataloader:
|
||||
for input_ids, input_mask, example_indices in eval_dataloader:
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
|
||||
all_encoder_layers, _ = model(input_ids, segment_ids, input_mask)
|
||||
all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
|
||||
|
||||
for enc_layers, example_index in zip(all_encoder_layers, example_indices):
|
||||
feature = features[example_index.item()]
|
||||
|
||||
Reference in New Issue
Block a user