# Provenance (git format-patch header this excerpt was extracted from):
#   From 1efc208ff386fb6df56302c8f6f9484ddf93b92a Mon Sep 17 00:00:00 2001
#   From: Lysandre Debut
#   Date: Mon, 6 Jan 2020 15:02:25 +0100
#   Subject: [PATCH] Complete DataProcessor class
#   ---
#   src/transformers/data/processors/utils.py | 27 +++++++++++++++++++++++
#   1 file changed, 27 insertions(+)
#   diff --git a/src/transformers/data/processors/utils.py b/src/transformers/data/processors/utils.py
#   index d16d72f6cb..4cc931cdf9 100644
#   @@ -93,6 +93,33 @@ class InputFeatures(object):


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors.

        Args:
            tensor_dict: Keys and values should match the corresponding Glue
                tensorflow_dataset examples.

        Raises:
            NotImplementedError: Always; this base class provides no implementation.
        """
        raise NotImplementedError()

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set.

        Args:
            data_dir: Not used by this base implementation.

        Raises:
            NotImplementedError: Always; this base class provides no implementation.
        """
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set.

        Args:
            data_dir: Not used by this base implementation.

        Raises:
            NotImplementedError: Always; this base class provides no implementation.
        """
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set.

        Raises:
            NotImplementedError: Always; this base class provides no implementation.
        """
        raise NotImplementedError()

    def tfds_map(self, example):
        """Converts a tensorflow_datasets example to the GLUE label format.

        Some tensorflow_datasets datasets are not formatted the same way the
        GLUE datasets are. When `get_labels()` returns more than one label,
        `example.label` is treated as an index into that list and replaced by
        the label stored at that index. Otherwise the label is left untouched.

        Args:
            example: Object with a `label` attribute. It is modified in place.

        Returns:
            The same `example` object that was passed in.
        """
        # Fetch the labels once instead of once for the length check and again
        # for the lookup.
        labels = self.get_labels()
        if len(labels) > 1:
            # NOTE(review): a negative index would wrap around through normal
            # sequence indexing and pick a label from the end -- confirm that
            # callers never pass such values (e.g. unlabeled examples).
            example.label = labels[int(example.label)]
        return example

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        # Trailing context of the patch hunk: the body of this method lies
        # beyond the visible excerpt and is intentionally not reconstructed.