remove old methods
This commit is contained in:
@@ -305,37 +305,6 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
|||||||
else:
|
else:
|
||||||
tokens_b.pop()
|
tokens_b.pop()
|
||||||
|
|
||||||
|
|
||||||
def input_fn_builder(features, seq_length, train_batch_size):
|
|
||||||
# TODO: delete
|
|
||||||
"""Creates an `input_fn` closure to be passed to TPUEstimator.""" ### ATTENTION - To rewrite ###
|
|
||||||
|
|
||||||
all_input_ids = [f.input_ids for feature in features]
|
|
||||||
all_input_mask = [f.input_mask for feature in features]
|
|
||||||
all_segment_ids = [f.segment_ids for feature in features]
|
|
||||||
all_label_ids = [f.label_id for feature in features]
|
|
||||||
|
|
||||||
# for feature in features:
|
|
||||||
# all_input_ids.append(feature.input_ids)
|
|
||||||
# all_input_mask.append(feature.input_mask)
|
|
||||||
# all_segment_ids.append(feature.segment_ids)
|
|
||||||
# all_label_ids.append(feature.label_id)
|
|
||||||
|
|
||||||
input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.Long)
|
|
||||||
input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.Long)
|
|
||||||
segment_tensor = torch.tensor(all_segment_ids, dtype=torch.Long)
|
|
||||||
label_tensor = torch.tensor(all_label_ids, dtype=torch.Long)
|
|
||||||
|
|
||||||
train_data = TensorDataset(input_ids_tensor, input_mask_tensor,
|
|
||||||
segment_tensor, label_tensor)
|
|
||||||
if args.local_rank == -1:
|
|
||||||
train_sampler = RandomSampler(train_data)
|
|
||||||
else:
|
|
||||||
train_sampler = DistributedSampler(train_data)
|
|
||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)
|
|
||||||
|
|
||||||
return train_dataloader
|
|
||||||
|
|
||||||
def accuracy(out, labels):
|
def accuracy(out, labels):
|
||||||
outputs = np.argmax(out, axis=1)
|
outputs = np.argmax(out, axis=1)
|
||||||
return np.sum(outputs==labels)
|
return np.sum(outputs==labels)
|
||||||
|
|||||||
Reference in New Issue
Block a user