Fix tensorflow_dataset glue support

`glue_convert_examples_to_features` assumed that tensorflow_dataset
examples contains the features `'sentence1'` and `'sentence2'`. This
commit encapsulates the choice of features in the glue processor and
uses that to parse examples.
This commit is contained in:
Agrin Hilmkil
2019-09-27 16:51:17 +02:00
parent ca559826c4
commit e31a472801

View File

@@ -79,10 +79,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
if ex_index % 10000 == 0: if ex_index % 10000 == 0:
logger.info("Writing example %d" % (ex_index)) logger.info("Writing example %d" % (ex_index))
if is_tf_dataset: if is_tf_dataset:
example = InputExample(example['idx'].numpy(), example = processor.get_example_from_tensor_dict(example)
example['sentence1'].numpy().decode('utf-8'),
example['sentence2'].numpy().decode('utf-8'),
str(example['label'].numpy()))
inputs = tokenizer.encode_plus( inputs = tokenizer.encode_plus(
example.text_a, example.text_a,
@@ -157,6 +154,12 @@ def glue_convert_examples_to_features(examples, tokenizer,
class MrpcProcessor(DataProcessor): class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version).""" """Processor for the MRPC data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
@@ -190,6 +193,12 @@ class MrpcProcessor(DataProcessor):
class MnliProcessor(DataProcessor): class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version).""" """Processor for the MultiNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['premise'].numpy().decode('utf-8'),
tensor_dict['hypothesis'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
@@ -233,6 +242,12 @@ class MnliMismatchedProcessor(MnliProcessor):
class ColaProcessor(DataProcessor): class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version).""" """Processor for the CoLA data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'),
None,
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
@@ -262,6 +277,12 @@ class ColaProcessor(DataProcessor):
class Sst2Processor(DataProcessor): class Sst2Processor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version).""" """Processor for the SST-2 data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'),
None,
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
@@ -293,6 +314,12 @@ class Sst2Processor(DataProcessor):
class StsbProcessor(DataProcessor): class StsbProcessor(DataProcessor):
"""Processor for the STS-B data set (GLUE version).""" """Processor for the STS-B data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
@@ -325,6 +352,12 @@ class StsbProcessor(DataProcessor):
class QqpProcessor(DataProcessor): class QqpProcessor(DataProcessor):
"""Processor for the QQP data set (GLUE version).""" """Processor for the QQP data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['question1'].numpy().decode('utf-8'),
tensor_dict['question2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
@@ -360,6 +393,12 @@ class QqpProcessor(DataProcessor):
class QnliProcessor(DataProcessor): class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version).""" """Processor for the QNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['question'].numpy().decode('utf-8'),
tensor_dict['sentence'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
@@ -393,6 +432,12 @@ class QnliProcessor(DataProcessor):
class RteProcessor(DataProcessor): class RteProcessor(DataProcessor):
"""Processor for the RTE data set (GLUE version).""" """Processor for the RTE data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
@@ -425,6 +470,12 @@ class RteProcessor(DataProcessor):
class WnliProcessor(DataProcessor): class WnliProcessor(DataProcessor):
"""Processor for the WNLI data set (GLUE version).""" """Processor for the WNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(