Interface with TFDS
This commit is contained in:
@@ -246,16 +246,24 @@ class SquadProcessor(DataProcessor):
|
|||||||
dev_file = None
|
dev_file = None
|
||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
|
||||||
return NewSquadExample(
|
return NewSquadExample(
|
||||||
tensor_dict['id'].numpy(),
|
tensor_dict['id'].numpy().decode("utf-8"),
|
||||||
tensor_dict['question'].numpy().decode('utf-8'),
|
tensor_dict['question'].numpy().decode('utf-8'),
|
||||||
tensor_dict['context'].numpy().decode('utf-8'),
|
tensor_dict['context'].numpy().decode('utf-8'),
|
||||||
tensor_dict['answers']['text'].numpy().decode('utf-8'),
|
tensor_dict['answers']['text'][0].numpy().decode('utf-8'),
|
||||||
tensor_dict['answers']['answers_start'].numpy().decode('utf-8'),
|
tensor_dict['answers']['answer_start'][0].numpy(),
|
||||||
tensor_dict['title'].numpy().decode('utf-8')
|
tensor_dict['title'].numpy().decode('utf-8')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_examples_from_dataset(self, dataset):
|
||||||
|
"""See base class."""
|
||||||
|
|
||||||
|
examples = []
|
||||||
|
for tensor_dict in tqdm(dataset):
|
||||||
|
examples.append(self.get_example_from_tensor_dict(tensor_dict))
|
||||||
|
|
||||||
|
return examples
|
||||||
|
|
||||||
def get_train_examples(self, data_dir, only_first=None):
|
def get_train_examples(self, data_dir, only_first=None):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
if self.train_file is None:
|
if self.train_file is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user