Cleanup
This commit is contained in:
@@ -146,7 +146,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
|
||||
token_to_orig_map = {}
|
||||
for i in range(paragraph_len):
|
||||
index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
|
||||
index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
|
||||
token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
|
||||
|
||||
encoded_dict["paragraph_len"] = paragraph_len
|
||||
@@ -166,7 +166,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
for doc_span_index in range(len(spans)):
|
||||
for j in range(spans[doc_span_index]["paragraph_len"]):
|
||||
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
|
||||
index = j if sequence_a_is_doc else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
|
||||
index = j if tokenizer.padding_side == "left" else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
|
||||
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
|
||||
|
||||
for span in spans:
|
||||
@@ -179,7 +179,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
|
||||
p_mask = np.minimum(p_mask, 1)
|
||||
|
||||
if not sequence_a_is_doc:
|
||||
if tokenizer.padding_side == "right":
|
||||
# Limit positive values to one
|
||||
p_mask = 1 - p_mask
|
||||
|
||||
@@ -207,7 +207,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
end_position = cls_index
|
||||
span_is_impossible = True
|
||||
else:
|
||||
if sequence_a_is_doc:
|
||||
if tokenizer.padding_side == "left":
|
||||
doc_offset = 0
|
||||
else:
|
||||
doc_offset = len(truncated_query) + sequence_added_tokens
|
||||
@@ -270,7 +270,29 @@ class SquadProcessor(DataProcessor):
|
||||
)
|
||||
|
||||
def get_examples_from_dataset(self, dataset, evaluate=False):
|
||||
"""See base class."""
|
||||
"""
|
||||
Creates a list of :class:`~transformers.data.processors.squad.SquadExample` using a TFDS dataset.
|
||||
|
||||
Args:
|
||||
dataset: The tfds dataset loaded from `tensorflow_datasets.load("squad")`
|
||||
evaluate: boolean specifying if in evaluation mode or in training mode
|
||||
|
||||
Returns:
|
||||
List of SquadExample
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow_datasets as tfds
|
||||
dataset = tfds.load("squad")
|
||||
|
||||
training_examples = get_examples_from_dataset(dataset, evaluate=False)
|
||||
evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
|
||||
"""
|
||||
|
||||
if evaluate:
|
||||
dataset = dataset["validation"]
|
||||
else:
|
||||
dataset = dataset["train"]
|
||||
|
||||
examples = []
|
||||
for tensor_dict in tqdm(dataset):
|
||||
@@ -455,8 +477,8 @@ class SquadResult(object):
|
||||
end_logits: The logits corresponding to the end of the answer
|
||||
"""
|
||||
def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
|
||||
self.start_top_log_probs = start_logits
|
||||
self.end_top_log_probs = end_logits
|
||||
self.start_logits = start_logits
|
||||
self.end_logits = end_logits
|
||||
self.unique_id = unique_id
|
||||
|
||||
if start_top_index:
|
||||
|
||||
Reference in New Issue
Block a user