Add docstring for processor method

This commit is contained in:
Agrin Hilmkil
2019-09-27 17:32:28 +02:00
parent e31a472801
commit 795b3e76ff
2 changed files with 18 additions and 0 deletions

View File

@@ -155,6 +155,7 @@ class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
@@ -194,6 +195,7 @@ class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['premise'].numpy().decode('utf-8'),
tensor_dict['hypothesis'].numpy().decode('utf-8'),
@@ -243,6 +245,7 @@ class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'),
None,
@@ -278,6 +281,7 @@ class Sst2Processor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'),
None,
@@ -315,6 +319,7 @@ class StsbProcessor(DataProcessor):
"""Processor for the STS-B data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
@@ -353,6 +358,7 @@ class QqpProcessor(DataProcessor):
"""Processor for the QQP data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['question1'].numpy().decode('utf-8'),
tensor_dict['question2'].numpy().decode('utf-8'),
@@ -394,6 +400,7 @@ class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['question'].numpy().decode('utf-8'),
tensor_dict['sentence'].numpy().decode('utf-8'),
@@ -433,6 +440,7 @@ class RteProcessor(DataProcessor):
"""Processor for the RTE data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
@@ -471,6 +479,7 @@ class WnliProcessor(DataProcessor):
"""Processor for the WNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),

View File

@@ -86,6 +86,15 @@ class InputFeatures(object):
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_example_from_tensor_dict(self, tensor_dict):
"""Gets an example from a dict with tensorflow tensors
Args:
tensor_dict: Keys and values should match the corresponding Glue
tensorflow_dataset examples.
"""
raise NotImplementedError()
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()