This commit is contained in:
LysandreJik
2019-12-04 16:24:15 -05:00
parent a7ca6d738b
commit f7e4a7cdfa
6 changed files with 191 additions and 24 deletions

View File

@@ -630,12 +630,12 @@ def compute_predictions_log_probs(
for i in range(start_n_top):
for j in range(end_n_top):
start_log_prob = result.start_top_log_probs[i]
start_log_prob = result.start_logits[i]
start_index = result.start_top_index[i]
j_index = i * end_n_top + j
end_log_prob = result.end_top_log_probs[j_index]
end_log_prob = result.end_logits[j_index]
end_index = result.end_top_index[j_index]
# We could hypothetically create invalid predictions, e.g., predict

View File

@@ -146,7 +146,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
token_to_orig_map = {}
for i in range(paragraph_len):
index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
encoded_dict["paragraph_len"] = paragraph_len
@@ -166,7 +166,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
for doc_span_index in range(len(spans)):
for j in range(spans[doc_span_index]["paragraph_len"]):
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
index = j if sequence_a_is_doc else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
index = j if tokenizer.padding_side == "left" else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
for span in spans:
@@ -179,7 +179,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
p_mask = np.minimum(p_mask, 1)
if not sequence_a_is_doc:
if tokenizer.padding_side == "right":
# Limit positive values to one
p_mask = 1 - p_mask
@@ -207,7 +207,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
end_position = cls_index
span_is_impossible = True
else:
if sequence_a_is_doc:
if tokenizer.padding_side == "left":
doc_offset = 0
else:
doc_offset = len(truncated_query) + sequence_added_tokens
@@ -270,7 +270,29 @@ class SquadProcessor(DataProcessor):
)
def get_examples_from_dataset(self, dataset, evaluate=False):
"""See base class."""
"""
Creates a list of :class:`~transformers.data.processors.squad.SquadExample` using a TFDS dataset.
Args:
dataset: The tfds dataset loaded from `tensorflow_datasets.load("squad")`
evaluate: boolean specifying if in evaluation mode or in training mode
Returns:
List of SquadExample
Examples::
import tensorflow_datasets as tfds
dataset = tfds.load("squad")
training_examples = get_examples_from_dataset(dataset, evaluate=False)
evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
"""
if evaluate:
dataset = dataset["validation"]
else:
dataset = dataset["train"]
examples = []
for tensor_dict in tqdm(dataset):
@@ -455,8 +477,8 @@ class SquadResult(object):
end_logits: The logits corresponding to the end of the answer
"""
def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
self.start_top_log_probs = start_logits
self.end_top_log_probs = end_logits
self.start_logits = start_logits
self.end_logits = end_logits
self.unique_id = unique_id
if start_top_index: