From e31a4728013db1914b7dd0f50cc18c5702db5b79 Mon Sep 17 00:00:00 2001 From: Agrin Hilmkil Date: Fri, 27 Sep 2019 16:51:17 +0200 Subject: [PATCH] Fix tensorflow_dataset glue support `glue_convert_examples_to_features` assumed that tensorflow_dataset examples contain the features `'sentence1'` and `'sentence2'`. This commit encapsulates the choice of features in the glue processor and uses that to parse examples. --- transformers/data/processors/glue.py | 59 ++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/transformers/data/processors/glue.py b/transformers/data/processors/glue.py index 2322f58604..8bf209a08b 100644 --- a/transformers/data/processors/glue.py +++ b/transformers/data/processors/glue.py @@ -79,10 +79,7 @@ def glue_convert_examples_to_features(examples, tokenizer, if ex_index % 10000 == 0: logger.info("Writing example %d" % (ex_index)) if is_tf_dataset: - example = InputExample(example['idx'].numpy(), - example['sentence1'].numpy().decode('utf-8'), - example['sentence2'].numpy().decode('utf-8'), - str(example['label'].numpy())) + example = processor.get_example_from_tensor_dict(example) inputs = tokenizer.encode_plus( example.text_a, @@ -157,6 +154,12 @@ def glue_convert_examples_to_features(examples, tokenizer, class MrpcProcessor(DataProcessor): """Processor for the MRPC data set (GLUE version).""" + def get_example_from_tensor_dict(self, tensor_dict): + return InputExample(tensor_dict['idx'].numpy(), + tensor_dict['sentence1'].numpy().decode('utf-8'), + tensor_dict['sentence2'].numpy().decode('utf-8'), + str(tensor_dict['label'].numpy())) + def get_train_examples(self, data_dir): """See base class.""" logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) @@ -190,6 +193,12 @@ class MrpcProcessor(DataProcessor): class MnliProcessor(DataProcessor): """Processor for the MultiNLI data set (GLUE version).""" + def get_example_from_tensor_dict(self, tensor_dict): + return 
InputExample(tensor_dict['idx'].numpy(), + tensor_dict['premise'].numpy().decode('utf-8'), + tensor_dict['hypothesis'].numpy().decode('utf-8'), + str(tensor_dict['label'].numpy())) + def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( @@ -233,6 +242,12 @@ class MnliMismatchedProcessor(MnliProcessor): class ColaProcessor(DataProcessor): """Processor for the CoLA data set (GLUE version).""" + def get_example_from_tensor_dict(self, tensor_dict): + return InputExample(tensor_dict['idx'].numpy(), + tensor_dict['sentence'].numpy().decode('utf-8'), + None, + str(tensor_dict['label'].numpy())) + def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( @@ -262,6 +277,12 @@ class ColaProcessor(DataProcessor): class Sst2Processor(DataProcessor): """Processor for the SST-2 data set (GLUE version).""" + def get_example_from_tensor_dict(self, tensor_dict): + return InputExample(tensor_dict['idx'].numpy(), + tensor_dict['sentence'].numpy().decode('utf-8'), + None, + str(tensor_dict['label'].numpy())) + def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( @@ -293,6 +314,12 @@ class Sst2Processor(DataProcessor): class StsbProcessor(DataProcessor): """Processor for the STS-B data set (GLUE version).""" + def get_example_from_tensor_dict(self, tensor_dict): + return InputExample(tensor_dict['idx'].numpy(), + tensor_dict['sentence1'].numpy().decode('utf-8'), + tensor_dict['sentence2'].numpy().decode('utf-8'), + str(tensor_dict['label'].numpy())) + def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( @@ -325,6 +352,12 @@ class StsbProcessor(DataProcessor): class QqpProcessor(DataProcessor): """Processor for the QQP data set (GLUE version).""" + def get_example_from_tensor_dict(self, tensor_dict): + return InputExample(tensor_dict['idx'].numpy(), + tensor_dict['question1'].numpy().decode('utf-8'), + 
tensor_dict['question2'].numpy().decode('utf-8'), + str(tensor_dict['label'].numpy())) + def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( @@ -360,6 +393,12 @@ class QqpProcessor(DataProcessor): class QnliProcessor(DataProcessor): """Processor for the QNLI data set (GLUE version).""" + def get_example_from_tensor_dict(self, tensor_dict): + return InputExample(tensor_dict['idx'].numpy(), + tensor_dict['question'].numpy().decode('utf-8'), + tensor_dict['sentence'].numpy().decode('utf-8'), + str(tensor_dict['label'].numpy())) + def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( @@ -393,6 +432,12 @@ class QnliProcessor(DataProcessor): class RteProcessor(DataProcessor): """Processor for the RTE data set (GLUE version).""" + def get_example_from_tensor_dict(self, tensor_dict): + return InputExample(tensor_dict['idx'].numpy(), + tensor_dict['sentence1'].numpy().decode('utf-8'), + tensor_dict['sentence2'].numpy().decode('utf-8'), + str(tensor_dict['label'].numpy())) + def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( @@ -425,6 +470,12 @@ class RteProcessor(DataProcessor): class WnliProcessor(DataProcessor): """Processor for the WNLI data set (GLUE version).""" + def get_example_from_tensor_dict(self, tensor_dict): + return InputExample(tensor_dict['idx'].numpy(), + tensor_dict['sentence1'].numpy().decode('utf-8'), + tensor_dict['sentence2'].numpy().decode('utf-8'), + str(tensor_dict['label'].numpy())) + def get_train_examples(self, data_dir): """See base class.""" return self._create_examples(