From f09e5ecef0566ce485c7ee913602dd502fcaeeb8 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Tue, 24 Sep 2019 09:47:34 -0400 Subject: [PATCH] [Proposal] GLUE processors included in library --- examples/run_glue.py | 5 +- .../preprocessing/__init__.py | 56 ++++ .../preprocessing/glue.py | 274 +++++------------- pytorch_transformers/preprocessing/utils.py | 99 +++++++ 4 files changed, 230 insertions(+), 204 deletions(-) create mode 100644 pytorch_transformers/preprocessing/__init__.py rename examples/utils_glue.py => pytorch_transformers/preprocessing/glue.py (75%) create mode 100644 pytorch_transformers/preprocessing/utils.py diff --git a/examples/run_glue.py b/examples/run_glue.py index e8c2edc833..3748c1a13a 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -46,8 +46,7 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig, from pytorch_transformers import AdamW, WarmupLinearSchedule -from utils_glue import (compute_metrics, convert_examples_to_features, - output_modes, processors) +from pytorch_transformers.preprocessing import (compute_metrics, output_modes, processors, convert_examples_to_glue_features) logger = logging.getLogger(__name__) @@ -276,7 +275,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): # HACK(label indices are swapped in RoBERTa pretrained model) label_list[1], label_list[2] = label_list[2], label_list[1] examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) - features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode, + features = convert_examples_to_glue_features(examples, label_list, args.max_seq_length, tokenizer, output_mode, pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, diff --git a/pytorch_transformers/preprocessing/__init__.py b/pytorch_transformers/preprocessing/__init__.py new file mode 100644 index 0000000000..33426f06d6 --- /dev/null +++ b/pytorch_transformers/preprocessing/__init__.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from glue import (ColaProcessor, + MnliProcessor, + MnliMismatchedProcessor, + MrpcProcessor, + Sst2Processor, + StsbProcessor, + QqpProcessor, + QnliProcessor, + RteProcessor, + WnliProcessor, + convert_examples_to_glue_features, + ) + +from utils import DataProcessor, simple_accuracy, acc_and_f1, pearson_and_spearman, compute_metrics + +processors = { + "cola": ColaProcessor, + "mnli": MnliProcessor, + "mnli-mm": MnliMismatchedProcessor, + "mrpc": MrpcProcessor, + "sst-2": Sst2Processor, + "sts-b": StsbProcessor, + "qqp": QqpProcessor, + "qnli": QnliProcessor, + "rte": RteProcessor, + "wnli": WnliProcessor, +} + +output_modes = { + "cola": "classification", + "mnli": "classification", + "mnli-mm": "classification", + "mrpc": "classification", + "sst-2": "classification", + "sts-b": "regression", + "qqp": "classification", + "qnli": "classification", + "rte": "classification", + "wnli": "classification", +} diff --git a/examples/utils_glue.py b/pytorch_transformers/preprocessing/glue.py similarity index 75% rename from examples/utils_glue.py rename to pytorch_transformers/preprocessing/glue.py index 2557540cc6..b36ecebed6 100644 --- a/examples/utils_glue.py +++ b/pytorch_transformers/preprocessing/glue.py @@ -13,22 +13,84 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" BERT classification fine-tuning: utilities to work with GLUE tasks """ +""" GLUE processors and helpers """ -from __future__ import absolute_import, division, print_function - -import csv +from utils import DataProcessor import logging import os -import sys -from io import open - -from scipy.stats import pearsonr, spearmanr -from sklearn.metrics import matthews_corrcoef, f1_score logger = logging.getLogger(__name__) +def convert_examples_to_glue_features(examples, label_list, max_seq_length, + tokenizer, output_mode, + pad_on_left=False, + pad_token=0, + pad_token_segment_id=0, + mask_padding_with_zero=True): + """ + Loads a data file into a list of `InputBatch`s + """ + + label_map = {label: i for i, label in enumerate(label_list)} + + features = [] + for (ex_index, example) in enumerate(examples): + if ex_index % 10000 == 0: + logger.info("Writing example %d of %d" % (ex_index, len(examples))) + + inputs = tokenizer.encode_plus( + example.text_a, + example.text_b, + add_special_tokens=True, + output_token_type=True, + max_length=max_seq_length, + truncate_first_sequence=True # We're truncating the first sequence as a priority + ) + input_ids, segment_ids = inputs["input_ids"], inputs["token_type_ids"] + + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) + + # Zero-pad up to the sequence length. + padding_length = max_seq_length - len(input_ids) + if pad_on_left: + input_ids = ([pad_token] * padding_length) + input_ids + input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask + segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids + else: + input_ids = input_ids + ([pad_token] * padding_length) + input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length) + segment_ids = segment_ids + ([pad_token_segment_id] * padding_length) + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + + if output_mode == "classification": + label_id = label_map[example.label] + elif output_mode == "regression": + label_id = float(example.label) + else: + raise KeyError(output_mode) + + if ex_index < 5: + logger.info("*** Example ***") + logger.info("guid: %s" % (example.guid)) + logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) + logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) + logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) + logger.info("label: %s (id = %d)" % (example.label, label_id)) + + features.append( + InputFeatures(input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + label_id=label_id)) + return features + + class InputExample(object): """A single training/test example for simple sequence classification.""" @@ -60,34 +122,6 @@ class InputFeatures(object): self.label_id = label_id -class DataProcessor(object): - """Base class for data converters for sequence classification data sets.""" - - def get_train_examples(self, data_dir): - """Gets a collection of `InputExample`s for the train set.""" - raise NotImplementedError() - - def get_dev_examples(self, data_dir): - """Gets a collection of `InputExample`s for the dev set.""" - raise NotImplementedError() - - def get_labels(self): - """Gets the list of labels for this data set.""" - raise NotImplementedError() - - @classmethod - def _read_tsv(cls, input_file, quotechar=None): - """Reads a tab separated value file.""" - with open(input_file, "r", encoding="utf-8-sig") as f: - reader = csv.reader(f, delimiter="\t", quotechar=quotechar) - lines = [] - for line in reader: - if sys.version_info[0] == 2: - line = list(unicode(cell, 'utf-8') for cell in line) - lines.append(line) - return lines - - class MrpcProcessor(DataProcessor): """Processor for the MRPC data set (GLUE version).""" @@ -302,7 +336,7 @@ class QnliProcessor(DataProcessor): def get_dev_examples(self, data_dir): """See base class.""" return self._create_examples( - self._read_tsv(os.path.join(data_dir, "dev.tsv")), + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched") def get_labels(self): @@ -387,168 +421,6 @@ class WnliProcessor(DataProcessor): InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples - -def convert_examples_to_features(examples, label_list, max_seq_length, - tokenizer, output_mode, - pad_on_left=False, - pad_token=0, - pad_token_segment_id=0, - mask_padding_with_zero=True): - """ - Loads a data file into a list of `InputBatch`s - """ - - label_map = {label : i for i, label in enumerate(label_list)} - - features = [] - for (ex_index, example) in enumerate(examples): - if ex_index % 10000 == 0: - logger.info("Writing example %d of %d" % (ex_index, len(examples))) - - inputs = tokenizer.encode_plus( - example.text_a, - example.text_b, - add_special_tokens=True, - output_token_type=True, - max_length=max_seq_length, - truncate_first_sequence=True # We're truncating the first sequence as a priority - ) - input_ids, segment_ids = inputs["input_ids"], inputs["token_type_ids"] - - # The mask has 1 for real tokens and 0 for padding tokens. Only real - # tokens are attended to. - input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) - - # Zero-pad up to the sequence length. - padding_length = max_seq_length - len(input_ids) - if pad_on_left: - input_ids = ([pad_token] * padding_length) + input_ids - input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask - segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids - else: - input_ids = input_ids + ([pad_token] * padding_length) - input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length) - segment_ids = segment_ids + ([pad_token_segment_id] * padding_length) - - assert len(input_ids) == max_seq_length - assert len(input_mask) == max_seq_length - assert len(segment_ids) == max_seq_length - - if output_mode == "classification": - label_id = label_map[example.label] - elif output_mode == "regression": - label_id = float(example.label) - else: - raise KeyError(output_mode) - - if ex_index < 5: - logger.info("*** Example ***") - logger.info("guid: %s" % (example.guid)) - logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) - logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) - logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) - logger.info("label: %s (id = %d)" % (example.label, label_id)) - - features.append( - InputFeatures(input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids, - label_id=label_id)) - return features - - -def _truncate_seq_pair(tokens_a, tokens_b, max_length): - """Truncates a sequence pair in place to the maximum length.""" - - # This is a simple heuristic which will always truncate the longer sequence - # one token at a time. This makes more sense than truncating an equal percent - # of tokens from each, since if one sequence is very short then each token - # that's truncated likely contains more information than a longer sequence. - while True: - total_length = len(tokens_a) + len(tokens_b) - if total_length <= max_length: - break - if len(tokens_a) > len(tokens_b): - tokens_a.pop() - else: - tokens_b.pop() - - -def simple_accuracy(preds, labels): - return (preds == labels).mean() - - -def acc_and_f1(preds, labels): - acc = simple_accuracy(preds, labels) - f1 = f1_score(y_true=labels, y_pred=preds) - return { - "acc": acc, - "f1": f1, - "acc_and_f1": (acc + f1) / 2, - } - - -def pearson_and_spearman(preds, labels): - pearson_corr = pearsonr(preds, labels)[0] - spearman_corr = spearmanr(preds, labels)[0] - return { - "pearson": pearson_corr, - "spearmanr": spearman_corr, - "corr": (pearson_corr + spearman_corr) / 2, - } - - -def compute_metrics(task_name, preds, labels): - assert len(preds) == len(labels) - if task_name == "cola": - return {"mcc": matthews_corrcoef(labels, preds)} - elif task_name == "sst-2": - return {"acc": simple_accuracy(preds, labels)} - elif task_name == "mrpc": - return acc_and_f1(preds, labels) - elif task_name == "sts-b": - return pearson_and_spearman(preds, labels) - elif task_name == "qqp": - return acc_and_f1(preds, labels) - elif task_name == "mnli": - return {"acc": simple_accuracy(preds, labels)} - elif task_name == "mnli-mm": - return {"acc": simple_accuracy(preds, labels)} - elif task_name == "qnli": - return {"acc": simple_accuracy(preds, labels)} - elif task_name == "rte": - return {"acc": simple_accuracy(preds, labels)} - elif task_name == "wnli": - return {"acc": simple_accuracy(preds, labels)} - else: - raise KeyError(task_name) - -processors = { - "cola": ColaProcessor, - "mnli": MnliProcessor, - "mnli-mm": MnliMismatchedProcessor, - "mrpc": MrpcProcessor, - "sst-2": Sst2Processor, - "sts-b": StsbProcessor, - "qqp": QqpProcessor, - "qnli": QnliProcessor, - "rte": RteProcessor, - "wnli": WnliProcessor, -} - -output_modes = { - "cola": "classification", - "mnli": "classification", - "mnli-mm": "classification", - "mrpc": "classification", - "sst-2": "classification", - "sts-b": "regression", - "qqp": "classification", - "qnli": "classification", - "rte": "classification", - "wnli": "classification", -} - GLUE_TASKS_NUM_LABELS = { "cola": 2, "mnli": 3, @@ -559,4 +431,4 @@ GLUE_TASKS_NUM_LABELS = { "qnli": 2, "rte": 2, "wnli": 2, -} +} \ No newline at end of file diff --git a/pytorch_transformers/preprocessing/utils.py b/pytorch_transformers/preprocessing/utils.py new file mode 100644 index 0000000000..b4a3d0d968 --- /dev/null +++ b/pytorch_transformers/preprocessing/utils.py @@ -0,0 +1,99 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import sys + +from scipy.stats import pearsonr, spearmanr +from sklearn.metrics import matthews_corrcoef, f1_score + + +class DataProcessor(object): + """Base class for data converters for sequence classification data sets.""" + + def get_train_examples(self, data_dir): + """Gets a collection of `InputExample`s for the train set.""" + raise NotImplementedError() + + def get_dev_examples(self, data_dir): + """Gets a collection of `InputExample`s for the dev set.""" + raise NotImplementedError() + + def get_labels(self): + """Gets the list of labels for this data set.""" + raise NotImplementedError() + + @classmethod + def _read_tsv(cls, input_file, quotechar=None): + """Reads a tab separated value file.""" + with open(input_file, "r", encoding="utf-8-sig") as f: + reader = csv.reader(f, delimiter="\t", quotechar=quotechar) + lines = [] + for line in reader: + if sys.version_info[0] == 2: + line = list(unicode(cell, 'utf-8') for cell in line) + lines.append(line) + return lines + + +def simple_accuracy(preds, labels): + return (preds == labels).mean() + + +def acc_and_f1(preds, labels): + acc = simple_accuracy(preds, labels) + f1 = f1_score(y_true=labels, y_pred=preds) + return { + "acc": acc, + "f1": f1, + "acc_and_f1": (acc + f1) / 2, + } + + +def pearson_and_spearman(preds, labels): + pearson_corr = pearsonr(preds, labels)[0] + spearman_corr = spearmanr(preds, labels)[0] + return { + "pearson": pearson_corr, + "spearmanr": spearman_corr, + "corr": (pearson_corr + spearman_corr) / 2, + } + + +def compute_metrics(task_name, preds, labels): + assert len(preds) == len(labels) + if task_name == "cola": + return {"mcc": matthews_corrcoef(labels, preds)} + elif task_name == "sst-2": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "mrpc": + return acc_and_f1(preds, labels) + elif task_name == "sts-b": + return pearson_and_spearman(preds, labels) + elif task_name == "qqp": + return acc_and_f1(preds, labels) + elif task_name == "mnli": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "mnli-mm": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "qnli": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "rte": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "wnli": + return {"acc": simple_accuracy(preds, labels)} + else: + raise KeyError(task_name) \ No newline at end of file