Fixing problems in convert_examples_to_features.
This commit is contained in:
@@ -70,20 +70,13 @@ class SwagExample(object):
|
|||||||
|
|
||||||
class InputFeatures(object):
|
class InputFeatures(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
unique_id,
|
|
||||||
example_id,
|
example_id,
|
||||||
input_ids,
|
choices_features,
|
||||||
input_mask,
|
label
|
||||||
segment_ids,
|
|
||||||
label_id
|
|
||||||
):
|
):
|
||||||
self.unique_id = unique_id
|
|
||||||
self.example_id = example_id
|
self.example_id = example_id
|
||||||
self.input_ids = input_ids
|
self.choices_features = choices_features
|
||||||
self.input_mask = input_mask
|
self.label = label
|
||||||
self.segment_ids = segment_ids
|
|
||||||
self.label_id = label_id
|
|
||||||
|
|
||||||
|
|
||||||
def read_swag_examples(input_file, is_training):
|
def read_swag_examples(input_file, is_training):
|
||||||
input_df = pd.read_csv(input_file)
|
input_df = pd.read_csv(input_file)
|
||||||
@@ -145,7 +138,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
# place so that the total length is less than the
|
# place so that the total length is less than the
|
||||||
# specified length. Account for [CLS], [SEP], [SEP] with
|
# specified length. Account for [CLS], [SEP], [SEP] with
|
||||||
# "- 3"
|
# "- 3"
|
||||||
_truncate_seq_pair(context_tokens, ending_tokens, max_seq_length - 3)
|
_truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
|
||||||
|
|
||||||
tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
|
tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
|
||||||
segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
|
segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
|
||||||
@@ -178,7 +171,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
if is_training:
|
if is_training:
|
||||||
logger.info(f"label: {label}")
|
logger.info(f"label: {label}")
|
||||||
|
|
||||||
|
features.append(
|
||||||
|
InputFeatures(
|
||||||
|
example_id = example.swag_id,
|
||||||
|
choices_features = choices_features,
|
||||||
|
label = label
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||||
"""Truncates a sequence pair in place to the maximum length."""
|
"""Truncates a sequence pair in place to the maximum length."""
|
||||||
@@ -206,4 +207,4 @@ if __name__ == "__main__":
|
|||||||
print("###########################")
|
print("###########################")
|
||||||
print(example)
|
print(example)
|
||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)
|
features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)
|
||||||
|
|||||||
Reference in New Issue
Block a user