TFDS dataset can now be evaluated
This commit is contained in:
@@ -245,22 +245,37 @@ class SquadProcessor(DataProcessor):
|
||||
train_file = None
|
||||
dev_file = None
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
def get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
|
||||
|
||||
if not evaluate:
|
||||
answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8')
|
||||
answer_start = tensor_dict['answers']['answer_start'][0].numpy()
|
||||
answers = None
|
||||
else:
|
||||
answers = [{
|
||||
"answer_start": start.numpy(),
|
||||
"text": text.numpy().decode('utf-8')
|
||||
} for start, text in zip(tensor_dict['answers']["answer_start"], tensor_dict['answers']["text"])]
|
||||
|
||||
answer = None
|
||||
answer_start = None
|
||||
|
||||
return SquadExample(
|
||||
tensor_dict['id'].numpy().decode("utf-8"),
|
||||
tensor_dict['question'].numpy().decode('utf-8'),
|
||||
tensor_dict['context'].numpy().decode('utf-8'),
|
||||
tensor_dict['answers']['text'][0].numpy().decode('utf-8'),
|
||||
tensor_dict['answers']['answer_start'][0].numpy(),
|
||||
tensor_dict['title'].numpy().decode('utf-8')
|
||||
qas_id=tensor_dict['id'].numpy().decode("utf-8"),
|
||||
question_text=tensor_dict['question'].numpy().decode('utf-8'),
|
||||
context_text=tensor_dict['context'].numpy().decode('utf-8'),
|
||||
answer_text=answer,
|
||||
start_position_character=answer_start,
|
||||
title=tensor_dict['title'].numpy().decode('utf-8'),
|
||||
answers=answers
|
||||
)
|
||||
|
||||
def get_examples_from_dataset(self, dataset):
|
||||
def get_examples_from_dataset(self, dataset, evaluate=False):
|
||||
"""See base class."""
|
||||
|
||||
examples = []
|
||||
for tensor_dict in tqdm(dataset):
|
||||
examples.append(self.get_example_from_tensor_dict(tensor_dict))
|
||||
examples.append(self.get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
|
||||
|
||||
return examples
|
||||
|
||||
@@ -300,6 +315,7 @@ class SquadProcessor(DataProcessor):
|
||||
question_text = qa["question"]
|
||||
start_position_character = None
|
||||
answer_text = None
|
||||
answers = None
|
||||
|
||||
if "is_impossible" in qa:
|
||||
is_impossible = qa["is_impossible"]
|
||||
|
||||
Reference in New Issue
Block a user