Store each choice's features (input_ids, input_mask, segment_ids) as a dict for readability.
This commit is contained in:
@@ -73,9 +73,17 @@ class InputFeatures(object):
|
||||
example_id,
|
||||
choices_features,
|
||||
label
|
||||
|
||||
):
|
||||
self.example_id = example_id
|
||||
self.choices_features = choices_features
|
||||
self.choices_features = [
|
||||
{
|
||||
'input_ids': input_ids,
|
||||
'input_mask': input_mask,
|
||||
'segment_ids': segment_ids
|
||||
}
|
||||
for _, input_ids, input_mask, segment_ids in choices_features
|
||||
]
|
||||
self.label = label
|
||||
|
||||
def read_swag_examples(input_file, is_training):
|
||||
@@ -181,6 +189,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
@@ -207,4 +216,11 @@ if __name__ == "__main__":
|
||||
print("###########################")
|
||||
print(example)
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)
|
||||
features = convert_examples_to_features(examples[:500], tokenizer, max_seq_length, is_training)
|
||||
for i in range(10):
|
||||
choice_feature_list = features[i].choices_features
|
||||
for choice_idx, choice_feature in enumerate(choice_feature_list):
|
||||
print(f'choice_idx: {choice_idx}')
|
||||
print(f'input_ids: {" ".join(map(str, choice_feature["input_ids"]))}')
|
||||
print(f'input_mask: {" ".join(map(str, choice_feature["input_mask"]))}')
|
||||
print(f'segment_ids: {" ".join(map(str, choice_feature["segment_ids"]))}')
|
||||
|
||||
Reference in New Issue
Block a user