From f671997ef74199823db83ed7b43340764888e129 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 28 Nov 2019 17:17:20 -0500 Subject: [PATCH] Interface with TFDS --- transformers/data/processors/squad.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/transformers/data/processors/squad.py b/transformers/data/processors/squad.py index 3d5a3eca80..52c2c28add 100644 --- a/transformers/data/processors/squad.py +++ b/transformers/data/processors/squad.py @@ -246,16 +246,24 @@ class SquadProcessor(DataProcessor): dev_file = None def get_example_from_tensor_dict(self, tensor_dict): - """See base class.""" return NewSquadExample( - tensor_dict['id'].numpy(), + tensor_dict['id'].numpy().decode("utf-8"), tensor_dict['question'].numpy().decode('utf-8'), tensor_dict['context'].numpy().decode('utf-8'), - tensor_dict['answers']['text'].numpy().decode('utf-8'), - tensor_dict['answers']['answers_start'].numpy().decode('utf-8'), + tensor_dict['answers']['text'][0].numpy().decode('utf-8'), + tensor_dict['answers']['answer_start'][0].numpy(), tensor_dict['title'].numpy().decode('utf-8') ) + def get_examples_from_dataset(self, dataset): + """See base class.""" + + examples = [] + for tensor_dict in tqdm(dataset): + examples.append(self.get_example_from_tensor_dict(tensor_dict)) + + return examples + def get_train_examples(self, data_dir, only_first=None): """See base class.""" if self.train_file is None: