From bf119c0568dfc1ea5ce0a34359e33ca002266e96 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 4 Dec 2019 11:34:59 -0500 Subject: [PATCH] TFDS dataset can now be evaluated --- transformers/data/processors/squad.py | 34 ++++++++++++++++++++------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/transformers/data/processors/squad.py b/transformers/data/processors/squad.py index 70dc9faf54..2e50ac8a8c 100644 --- a/transformers/data/processors/squad.py +++ b/transformers/data/processors/squad.py @@ -245,22 +245,37 @@ class SquadProcessor(DataProcessor): train_file = None dev_file = None - def get_example_from_tensor_dict(self, tensor_dict): + def get_example_from_tensor_dict(self, tensor_dict, evaluate=False): + + if not evaluate: + answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8') + answer_start = tensor_dict['answers']['answer_start'][0].numpy() + answers = None + else: + answers = [{ + "answer_start": start.numpy(), + "text": text.numpy().decode('utf-8') + } for start, text in zip(tensor_dict['answers']["answer_start"], tensor_dict['answers']["text"])] + + answer = None + answer_start = None + return SquadExample( - tensor_dict['id'].numpy().decode("utf-8"), - tensor_dict['question'].numpy().decode('utf-8'), - tensor_dict['context'].numpy().decode('utf-8'), - tensor_dict['answers']['text'][0].numpy().decode('utf-8'), - tensor_dict['answers']['answer_start'][0].numpy(), - tensor_dict['title'].numpy().decode('utf-8') + qas_id=tensor_dict['id'].numpy().decode("utf-8"), + question_text=tensor_dict['question'].numpy().decode('utf-8'), + context_text=tensor_dict['context'].numpy().decode('utf-8'), + answer_text=answer, + start_position_character=answer_start, + title=tensor_dict['title'].numpy().decode('utf-8'), + answers=answers ) - def get_examples_from_dataset(self, dataset): + def get_examples_from_dataset(self, dataset, evaluate=False): """See base class.""" examples = [] for tensor_dict in tqdm(dataset): - examples.append(self.get_example_from_tensor_dict(tensor_dict)) + examples.append(self.get_example_from_tensor_dict(tensor_dict, evaluate=evaluate)) return examples @@ -300,6 +315,7 @@ class SquadProcessor(DataProcessor): question_text = qa["question"] start_position_character = None answer_text = None + answers = None if "is_impossible" in qa: is_impossible = qa["is_impossible"]