TFDS dataset can now be evaluated
This commit is contained in:
@@ -245,22 +245,37 @@ class SquadProcessor(DataProcessor):
|
|||||||
train_file = None
|
train_file = None
|
||||||
dev_file = None
|
dev_file = None
|
||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
|
||||||
|
|
||||||
|
if not evaluate:
|
||||||
|
answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8')
|
||||||
|
answer_start = tensor_dict['answers']['answer_start'][0].numpy()
|
||||||
|
answers = None
|
||||||
|
else:
|
||||||
|
answers = [{
|
||||||
|
"answer_start": start.numpy(),
|
||||||
|
"text": text.numpy().decode('utf-8')
|
||||||
|
} for start, text in zip(tensor_dict['answers']["answer_start"], tensor_dict['answers']["text"])]
|
||||||
|
|
||||||
|
answer = None
|
||||||
|
answer_start = None
|
||||||
|
|
||||||
return SquadExample(
|
return SquadExample(
|
||||||
tensor_dict['id'].numpy().decode("utf-8"),
|
qas_id=tensor_dict['id'].numpy().decode("utf-8"),
|
||||||
tensor_dict['question'].numpy().decode('utf-8'),
|
question_text=tensor_dict['question'].numpy().decode('utf-8'),
|
||||||
tensor_dict['context'].numpy().decode('utf-8'),
|
context_text=tensor_dict['context'].numpy().decode('utf-8'),
|
||||||
tensor_dict['answers']['text'][0].numpy().decode('utf-8'),
|
answer_text=answer,
|
||||||
tensor_dict['answers']['answer_start'][0].numpy(),
|
start_position_character=answer_start,
|
||||||
tensor_dict['title'].numpy().decode('utf-8')
|
title=tensor_dict['title'].numpy().decode('utf-8'),
|
||||||
|
answers=answers
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_examples_from_dataset(self, dataset):
|
def get_examples_from_dataset(self, dataset, evaluate=False):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
|
|
||||||
examples = []
|
examples = []
|
||||||
for tensor_dict in tqdm(dataset):
|
for tensor_dict in tqdm(dataset):
|
||||||
examples.append(self.get_example_from_tensor_dict(tensor_dict))
|
examples.append(self.get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
|
||||||
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -300,6 +315,7 @@ class SquadProcessor(DataProcessor):
|
|||||||
question_text = qa["question"]
|
question_text = qa["question"]
|
||||||
start_position_character = None
|
start_position_character = None
|
||||||
answer_text = None
|
answer_text = None
|
||||||
|
answers = None
|
||||||
|
|
||||||
if "is_impossible" in qa:
|
if "is_impossible" in qa:
|
||||||
is_impossible = qa["is_impossible"]
|
is_impossible = qa["is_impossible"]
|
||||||
|
|||||||
Reference in New Issue
Block a user