From 795b3e76ffbeb224ee334252abc1b3c359b26067 Mon Sep 17 00:00:00 2001 From: Agrin Hilmkil Date: Fri, 27 Sep 2019 17:32:28 +0200 Subject: [PATCH] Add docstring for processor method --- transformers/data/processors/glue.py | 9 +++++++++ transformers/data/processors/utils.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/transformers/data/processors/glue.py b/transformers/data/processors/glue.py index 8bf209a08b..61bca8c11b 100644 --- a/transformers/data/processors/glue.py +++ b/transformers/data/processors/glue.py @@ -155,6 +155,7 @@ class MrpcProcessor(DataProcessor): """Processor for the MRPC data set (GLUE version).""" def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" return InputExample(tensor_dict['idx'].numpy(), tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict['sentence2'].numpy().decode('utf-8'), @@ -194,6 +195,7 @@ class MnliProcessor(DataProcessor): """Processor for the MultiNLI data set (GLUE version).""" def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" return InputExample(tensor_dict['idx'].numpy(), tensor_dict['premise'].numpy().decode('utf-8'), tensor_dict['hypothesis'].numpy().decode('utf-8'), @@ -243,6 +245,7 @@ class ColaProcessor(DataProcessor): """Processor for the CoLA data set (GLUE version).""" def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" return InputExample(tensor_dict['idx'].numpy(), tensor_dict['sentence'].numpy().decode('utf-8'), None, @@ -278,6 +281,7 @@ class Sst2Processor(DataProcessor): """Processor for the SST-2 data set (GLUE version).""" def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" return InputExample(tensor_dict['idx'].numpy(), tensor_dict['sentence'].numpy().decode('utf-8'), None, @@ -315,6 +319,7 @@ class StsbProcessor(DataProcessor): """Processor for the STS-B data set (GLUE version).""" def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" return InputExample(tensor_dict['idx'].numpy(), tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict['sentence2'].numpy().decode('utf-8'), @@ -353,6 +358,7 @@ class QqpProcessor(DataProcessor): """Processor for the QQP data set (GLUE version).""" def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" return InputExample(tensor_dict['idx'].numpy(), tensor_dict['question1'].numpy().decode('utf-8'), tensor_dict['question2'].numpy().decode('utf-8'), @@ -394,6 +400,7 @@ class QnliProcessor(DataProcessor): """Processor for the QNLI data set (GLUE version).""" def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" return InputExample(tensor_dict['idx'].numpy(), tensor_dict['question'].numpy().decode('utf-8'), tensor_dict['sentence'].numpy().decode('utf-8'), @@ -433,6 +440,7 @@ class RteProcessor(DataProcessor): """Processor for the RTE data set (GLUE version).""" def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" return InputExample(tensor_dict['idx'].numpy(), tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict['sentence2'].numpy().decode('utf-8'), @@ -471,6 +479,7 @@ class WnliProcessor(DataProcessor): """Processor for the WNLI data set (GLUE version).""" def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" return InputExample(tensor_dict['idx'].numpy(), tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict['sentence2'].numpy().decode('utf-8'), diff --git a/transformers/data/processors/utils.py b/transformers/data/processors/utils.py index d16ea786a0..27138f9959 100644 --- a/transformers/data/processors/utils.py +++ b/transformers/data/processors/utils.py @@ -86,6 +86,15 @@ class InputFeatures(object): class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" + def get_example_from_tensor_dict(self, tensor_dict): + """Gets an example from a dict with tensorflow tensors + + Args: + tensor_dict: Keys and values should match the corresponding Glue + tensorflow_dataset examples. + """ + raise NotImplementedError() + def get_train_examples(self, data_dir): """Gets a collection of `InputExample`s for the train set.""" raise NotImplementedError()