Fixing problems in convert_examples_to_features.
This commit is contained in:
@@ -70,20 +70,13 @@ class SwagExample(object):
|
||||
|
||||
class InputFeatures(object):
|
||||
def __init__(self,
|
||||
unique_id,
|
||||
example_id,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
label_id
|
||||
choices_features,
|
||||
label
|
||||
):
|
||||
self.unique_id = unique_id
|
||||
self.example_id = example_id
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.label_id = label_id
|
||||
|
||||
self.choices_features = choices_features
|
||||
self.label = label
|
||||
|
||||
def read_swag_examples(input_file, is_training):
|
||||
input_df = pd.read_csv(input_file)
|
||||
@@ -145,7 +138,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
# place so that the total length is less than the
|
||||
# specified length. Account for [CLS], [SEP], [SEP] with
|
||||
# "- 3"
|
||||
_truncate_seq_pair(context_tokens, ending_tokens, max_seq_length - 3)
|
||||
_truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
|
||||
|
||||
tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
|
||||
segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
|
||||
@@ -178,7 +171,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
if is_training:
|
||||
logger.info(f"label: {label}")
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
example_id = example.swag_id,
|
||||
choices_features = choices_features,
|
||||
label = label
|
||||
)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
@@ -206,4 +207,4 @@ if __name__ == "__main__":
|
||||
print("###########################")
|
||||
print(example)
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)
|
||||
features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)
|
||||
|
||||
Reference in New Issue
Block a user