Store the features of each choice (input_ids, input_mask, segment_ids) as a dict for readability.
This commit is contained in:
@@ -73,9 +73,17 @@ class InputFeatures(object):
|
|||||||
example_id,
|
example_id,
|
||||||
choices_features,
|
choices_features,
|
||||||
label
|
label
|
||||||
|
|
||||||
):
|
):
|
||||||
self.example_id = example_id
|
self.example_id = example_id
|
||||||
self.choices_features = choices_features
|
self.choices_features = [
|
||||||
|
{
|
||||||
|
'input_ids': input_ids,
|
||||||
|
'input_mask': input_mask,
|
||||||
|
'segment_ids': segment_ids
|
||||||
|
}
|
||||||
|
for _, input_ids, input_mask, segment_ids in choices_features
|
||||||
|
]
|
||||||
self.label = label
|
self.label = label
|
||||||
|
|
||||||
def read_swag_examples(input_file, is_training):
|
def read_swag_examples(input_file, is_training):
|
||||||
@@ -181,6 +189,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||||
"""Truncates a sequence pair in place to the maximum length."""
|
"""Truncates a sequence pair in place to the maximum length."""
|
||||||
|
|
||||||
@@ -207,4 +216,11 @@ if __name__ == "__main__":
|
|||||||
print("###########################")
|
print("###########################")
|
||||||
print(example)
|
print(example)
|
||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)
|
features = convert_examples_to_features(examples[:500], tokenizer, max_seq_length, is_training)
|
||||||
|
for i in range(10):
|
||||||
|
choice_feature_list = features[i].choices_features
|
||||||
|
for choice_idx, choice_feature in enumerate(choice_feature_list):
|
||||||
|
print(f'choice_idx: {choice_idx}')
|
||||||
|
print(f'input_ids: {" ".join(map(str, choice_feature["input_ids"]))}')
|
||||||
|
print(f'input_mask: {" ".join(map(str, choice_feature["input_mask"]))}')
|
||||||
|
print(f'segment_ids: {" ".join(map(str, choice_feature["segment_ids"]))}')
|
||||||
|
|||||||
Reference in New Issue
Block a user