Reformat source code with black.
This is the result of:
$ black --line-length 119 examples templates transformers utils hubconf.py setup.py
There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.
This is also Thomas' preference, because it allows for explicit variable
names, which make the code easier to understand.
This commit is contained in:
@@ -61,9 +61,7 @@ def read_examples_from_file(data_dir, mode):
|
||||
for line in f:
|
||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||
if words:
|
||||
examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
|
||||
words=words,
|
||||
labels=labels))
|
||||
examples.append(InputExample(guid="{}-{}".format(mode, guid_index), words=words, labels=labels))
|
||||
guid_index += 1
|
||||
words = []
|
||||
labels = []
|
||||
@@ -76,27 +74,27 @@ def read_examples_from_file(data_dir, mode):
|
||||
# Examples could have no label for mode = "test"
|
||||
labels.append("O")
|
||||
if words:
|
||||
examples.append(InputExample(guid="%s-%d".format(mode, guid_index),
|
||||
words=words,
|
||||
labels=labels))
|
||||
examples.append(InputExample(guid="%s-%d".format(mode, guid_index), words=words, labels=labels))
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples,
|
||||
label_list,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=False,
|
||||
cls_token="[CLS]",
|
||||
cls_token_segment_id=1,
|
||||
sep_token="[SEP]",
|
||||
sep_token_extra=False,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
pad_token_segment_id=0,
|
||||
pad_token_label_id=-100,
|
||||
sequence_a_segment_id=0,
|
||||
mask_padding_with_zero=True):
|
||||
def convert_examples_to_features(
|
||||
examples,
|
||||
label_list,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=False,
|
||||
cls_token="[CLS]",
|
||||
cls_token_segment_id=1,
|
||||
sep_token="[SEP]",
|
||||
sep_token_extra=False,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
pad_token_segment_id=0,
|
||||
pad_token_label_id=-100,
|
||||
sequence_a_segment_id=0,
|
||||
mask_padding_with_zero=True,
|
||||
):
|
||||
""" Loads a data file into a list of `InputBatch`s
|
||||
`cls_token_at_end` define the location of the CLS token:
|
||||
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||
@@ -122,8 +120,8 @@ def convert_examples_to_features(examples,
|
||||
# Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
|
||||
special_tokens_count = 3 if sep_token_extra else 2
|
||||
if len(tokens) > max_seq_length - special_tokens_count:
|
||||
tokens = tokens[:(max_seq_length - special_tokens_count)]
|
||||
label_ids = label_ids[:(max_seq_length - special_tokens_count)]
|
||||
tokens = tokens[: (max_seq_length - special_tokens_count)]
|
||||
label_ids = label_ids[: (max_seq_length - special_tokens_count)]
|
||||
|
||||
# The convention in BERT is:
|
||||
# (a) For sequence pairs:
|
||||
@@ -174,10 +172,10 @@ def convert_examples_to_features(examples,
|
||||
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
|
||||
label_ids = ([pad_token_label_id] * padding_length) + label_ids
|
||||
else:
|
||||
input_ids += ([pad_token] * padding_length)
|
||||
input_mask += ([0 if mask_padding_with_zero else 1] * padding_length)
|
||||
segment_ids += ([pad_token_segment_id] * padding_length)
|
||||
label_ids += ([pad_token_label_id] * padding_length)
|
||||
input_ids += [pad_token] * padding_length
|
||||
input_mask += [0 if mask_padding_with_zero else 1] * padding_length
|
||||
segment_ids += [pad_token_segment_id] * padding_length
|
||||
label_ids += [pad_token_label_id] * padding_length
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
@@ -194,10 +192,8 @@ def convert_examples_to_features(examples,
|
||||
logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
|
||||
|
||||
features.append(
|
||||
InputFeatures(input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
label_ids=label_ids))
|
||||
InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@@ -209,4 +205,4 @@ def get_labels(path):
|
||||
labels = ["O"] + labels
|
||||
return labels
|
||||
else:
|
||||
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
|
||||
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
|
||||
|
||||
Reference in New Issue
Block a user