From 9a25c5bd3afeab85a80acb1a5348beec1d2cbbfd Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 18 Dec 2020 14:19:24 -0500 Subject: [PATCH] Add new run_swag example (#9175) * Add new run_swag example * Add check * Add sample * Apply suggestions from code review Co-authored-by: Lysandre Debut * Very important change to make Lysandre happy Co-authored-by: Lysandre Debut --- examples/README.md | 2 +- .../multiple_choice}/run_multiple_choice.py | 0 .../multiple_choice/utils_multiple_choice.py | 579 ++++++++++++++++++ examples/multiple-choice/README.md | 13 +- examples/multiple-choice/run_swag.py | 349 +++++++++++ examples/test_examples.py | 28 + tests/fixtures/tests_samples/swag/sample.json | 10 + 7 files changed, 970 insertions(+), 11 deletions(-) rename examples/{multiple-choice => legacy/multiple_choice}/run_multiple_choice.py (100%) create mode 100644 examples/legacy/multiple_choice/utils_multiple_choice.py create mode 100644 examples/multiple-choice/run_swag.py create mode 100644 tests/fixtures/tests_samples/swag/sample.json diff --git a/examples/README.md b/examples/README.md index dce3f09892..077758bc2d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -54,7 +54,7 @@ Coming soon! | Task | Example datasets | Trainer support | TFTrainer support | 🤗 Datasets | Colab |---|---|:---:|:---:|:---:|:---:| | [**`language-modeling`**](https://github.com/huggingface/transformers/tree/master/examples/language-modeling) | Raw text | ✅ | - | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb) -| [**`multiple-choice`**](https://github.com/huggingface/transformers/tree/master/examples/multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb) +| [**`multiple-choice`**](https://github.com/huggingface/transformers/tree/master/examples/multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb) | [**`question-answering`**](https://github.com/huggingface/transformers/tree/master/examples/question-answering) | SQuAD | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/huggingface/notebooks/blob/master/examples/question_answering.ipynb) | [**`summarization`**](https://github.com/huggingface/transformers/tree/master/examples/seq2seq) | CNN/Daily Mail | ✅ | - | - | - | [**`text-classification`**](https://github.com/huggingface/transformers/tree/master/examples/text-classification) | GLUE, XNLI | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/huggingface/notebooks/blob/master/examples/text_classification.ipynb) diff --git a/examples/multiple-choice/run_multiple_choice.py b/examples/legacy/multiple_choice/run_multiple_choice.py similarity index 100% rename from examples/multiple-choice/run_multiple_choice.py rename to examples/legacy/multiple_choice/run_multiple_choice.py diff --git a/examples/legacy/multiple_choice/utils_multiple_choice.py b/examples/legacy/multiple_choice/utils_multiple_choice.py new file mode 100644 index 0000000000..784a7578d3 --- /dev/null +++ b/examples/legacy/multiple_choice/utils_multiple_choice.py @@ -0,0 +1,579 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """ + + +import csv +import glob +import json +import logging +import os +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + +import tqdm + +from filelock import FileLock +from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class InputExample: + """ + A single training/test example for multiple choice + + Args: + example_id: Unique id for the example. + question: string. The untokenized text of the second sequence (question). + contexts: list of str. The untokenized text of the first sequence (context of corresponding question). + endings: list of str. multiple choice's options. Its length must be equal to contexts' length. + label: (Optional) string. The label of the example. This should be + specified for train and dev examples, but not for test examples. + """ + + example_id: str + question: str + contexts: List[str] + endings: List[str] + label: Optional[str] + + +@dataclass(frozen=True) +class InputFeatures: + """ + A single set of features of data. + Property names are the same names as the corresponding inputs to a model. + """ + + example_id: str + input_ids: List[List[int]] + attention_mask: Optional[List[List[int]]] + token_type_ids: Optional[List[List[int]]] + label: Optional[int] + + +class Split(Enum): + train = "train" + dev = "dev" + test = "test" + + +if is_torch_available(): + import torch + from torch.utils.data.dataset import Dataset + + class MultipleChoiceDataset(Dataset): + """ + This will be superseded by a framework-agnostic approach + soon. + """ + + features: List[InputFeatures] + + def __init__( + self, + data_dir: str, + tokenizer: PreTrainedTokenizer, + task: str, + max_seq_length: Optional[int] = None, + overwrite_cache=False, + mode: Split = Split.train, + ): + processor = processors[task]() + + cached_features_file = os.path.join( + data_dir, + "cached_{}_{}_{}_{}".format( + mode.value, + tokenizer.__class__.__name__, + str(max_seq_length), + task, + ), + ) + + # Make sure only the first process in distributed training processes the dataset, + # and the others will use the cache. + lock_path = cached_features_file + ".lock" + with FileLock(lock_path): + + if os.path.exists(cached_features_file) and not overwrite_cache: + logger.info(f"Loading features from cached file {cached_features_file}") + self.features = torch.load(cached_features_file) + else: + logger.info(f"Creating features from dataset file at {data_dir}") + label_list = processor.get_labels() + if mode == Split.dev: + examples = processor.get_dev_examples(data_dir) + elif mode == Split.test: + examples = processor.get_test_examples(data_dir) + else: + examples = processor.get_train_examples(data_dir) + logger.info("Training examples: %s", len(examples)) + self.features = convert_examples_to_features( + examples, + label_list, + max_seq_length, + tokenizer, + ) + logger.info("Saving features into cached file %s", cached_features_file) + torch.save(self.features, cached_features_file) + + def __len__(self): + return len(self.features) + + def __getitem__(self, i) -> InputFeatures: + return self.features[i] + + +if is_tf_available(): + import tensorflow as tf + + class TFMultipleChoiceDataset: + """ + This will be superseded by a framework-agnostic approach + soon. + """ + + features: List[InputFeatures] + + def __init__( + self, + data_dir: str, + tokenizer: PreTrainedTokenizer, + task: str, + max_seq_length: Optional[int] = 128, + overwrite_cache=False, + mode: Split = Split.train, + ): + processor = processors[task]() + + logger.info(f"Creating features from dataset file at {data_dir}") + label_list = processor.get_labels() + if mode == Split.dev: + examples = processor.get_dev_examples(data_dir) + elif mode == Split.test: + examples = processor.get_test_examples(data_dir) + else: + examples = processor.get_train_examples(data_dir) + logger.info("Training examples: %s", len(examples)) + + self.features = convert_examples_to_features( + examples, + label_list, + max_seq_length, + tokenizer, + ) + + def gen(): + for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"): + if ex_index % 10000 == 0: + logger.info("Writing example %d of %d" % (ex_index, len(examples))) + + yield ( + { + "example_id": 0, + "input_ids": ex.input_ids, + "attention_mask": ex.attention_mask, + "token_type_ids": ex.token_type_ids, + }, + ex.label, + ) + + self.dataset = tf.data.Dataset.from_generator( + gen, + ( + { + "example_id": tf.int32, + "input_ids": tf.int32, + "attention_mask": tf.int32, + "token_type_ids": tf.int32, + }, + tf.int64, + ), + ( + { + "example_id": tf.TensorShape([]), + "input_ids": tf.TensorShape([None, None]), + "attention_mask": tf.TensorShape([None, None]), + "token_type_ids": tf.TensorShape([None, None]), + }, + tf.TensorShape([]), + ), + ) + + def get_dataset(self): + self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features))) + + return self.dataset + + def __len__(self): + return len(self.features) + + def __getitem__(self, i) -> InputFeatures: + return self.features[i] + + +class DataProcessor: + """Base class for data converters for multiple choice data sets.""" + + def get_train_examples(self, data_dir): + """Gets a collection of `InputExample`s for the train set.""" + raise NotImplementedError() + + def get_dev_examples(self, data_dir): + """Gets a collection of `InputExample`s for the dev set.""" + raise NotImplementedError() + + def get_test_examples(self, data_dir): + """Gets a collection of `InputExample`s for the test set.""" + raise NotImplementedError() + + def get_labels(self): + """Gets the list of labels for this data set.""" + raise NotImplementedError() + + +class RaceProcessor(DataProcessor): + """Processor for the RACE data set.""" + + def get_train_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} train".format(data_dir)) + high = os.path.join(data_dir, "train/high") + middle = os.path.join(data_dir, "train/middle") + high = self._read_txt(high) + middle = self._read_txt(middle) + return self._create_examples(high + middle, "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} dev".format(data_dir)) + high = os.path.join(data_dir, "dev/high") + middle = os.path.join(data_dir, "dev/middle") + high = self._read_txt(high) + middle = self._read_txt(middle) + return self._create_examples(high + middle, "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} test".format(data_dir)) + high = os.path.join(data_dir, "test/high") + middle = os.path.join(data_dir, "test/middle") + high = self._read_txt(high) + middle = self._read_txt(middle) + return self._create_examples(high + middle, "test") + + def get_labels(self): + """See base class.""" + return ["0", "1", "2", "3"] + + def _read_txt(self, input_dir): + lines = [] + files = glob.glob(input_dir + "/*txt") + for file in tqdm.tqdm(files, desc="read files"): + with open(file, "r", encoding="utf-8") as fin: + data_raw = json.load(fin) + data_raw["race_id"] = file + lines.append(data_raw) + return lines + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (_, data_raw) in enumerate(lines): + race_id = "%s-%s" % (set_type, data_raw["race_id"]) + article = data_raw["article"] + for i in range(len(data_raw["answers"])): + truth = str(ord(data_raw["answers"][i]) - ord("A")) + question = data_raw["questions"][i] + options = data_raw["options"][i] + + examples.append( + InputExample( + example_id=race_id, + question=question, + contexts=[article, article, article, article], # this is not efficient but convenient + endings=[options[0], options[1], options[2], options[3]], + label=truth, + ) + ) + return examples + + +class SynonymProcessor(DataProcessor): + """Processor for the Synonym data set.""" + + def get_train_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} train".format(data_dir)) + return self._create_examples(self._read_csv(os.path.join(data_dir, "mctrain.csv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} dev".format(data_dir)) + return self._create_examples(self._read_csv(os.path.join(data_dir, "mchp.csv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} dev".format(data_dir)) + + return self._create_examples(self._read_csv(os.path.join(data_dir, "mctest.csv")), "test") + + def get_labels(self): + """See base class.""" + return ["0", "1", "2", "3", "4"] + + def _read_csv(self, input_file): + with open(input_file, "r", encoding="utf-8") as f: + return list(csv.reader(f)) + + def _create_examples(self, lines: List[List[str]], type: str): + """Creates examples for the training and dev sets.""" + + examples = [ + InputExample( + example_id=line[0], + question="", # in the swag dataset, the + # common beginning of each + # choice is stored in "sent2". + contexts=[line[1], line[1], line[1], line[1], line[1]], + endings=[line[2], line[3], line[4], line[5], line[6]], + label=line[7], + ) + for line in lines # we skip the line with the column names + ] + + return examples + + +class SwagProcessor(DataProcessor): + """Processor for the SWAG data set.""" + + def get_train_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} train".format(data_dir)) + return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} dev".format(data_dir)) + return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} dev".format(data_dir)) + raise ValueError( + "For swag testing, the input file does not contain a label column. It can not be tested in current code" + "setting!" + ) + return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test") + + def get_labels(self): + """See base class.""" + return ["0", "1", "2", "3"] + + def _read_csv(self, input_file): + with open(input_file, "r", encoding="utf-8") as f: + return list(csv.reader(f)) + + def _create_examples(self, lines: List[List[str]], type: str): + """Creates examples for the training and dev sets.""" + if type == "train" and lines[0][-1] != "label": + raise ValueError("For training, the input file must contain a label column.") + + examples = [ + InputExample( + example_id=line[2], + question=line[5], # in the swag dataset, the + # common beginning of each + # choice is stored in "sent2". + contexts=[line[4], line[4], line[4], line[4]], + endings=[line[7], line[8], line[9], line[10]], + label=line[11], + ) + for line in lines[1:] # we skip the line with the column names + ] + + return examples + + +class ArcProcessor(DataProcessor): + """Processor for the ARC data set (request from allennlp).""" + + def get_train_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} train".format(data_dir)) + return self._create_examples(self._read_json(os.path.join(data_dir, "train.jsonl")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} dev".format(data_dir)) + return self._create_examples(self._read_json(os.path.join(data_dir, "dev.jsonl")), "dev") + + def get_test_examples(self, data_dir): + logger.info("LOOKING AT {} test".format(data_dir)) + return self._create_examples(self._read_json(os.path.join(data_dir, "test.jsonl")), "test") + + def get_labels(self): + """See base class.""" + return ["0", "1", "2", "3"] + + def _read_json(self, input_file): + with open(input_file, "r", encoding="utf-8") as fin: + lines = fin.readlines() + return lines + + def _create_examples(self, lines, type): + """Creates examples for the training and dev sets.""" + + # There are two types of labels. They should be normalized + def normalize(truth): + if truth in "ABCD": + return ord(truth) - ord("A") + elif truth in "1234": + return int(truth) - 1 + else: + logger.info("truth ERROR! %s", str(truth)) + return None + + examples = [] + three_choice = 0 + four_choice = 0 + five_choice = 0 + other_choices = 0 + # we deleted example which has more than or less than four choices + for line in tqdm.tqdm(lines, desc="read arc data"): + data_raw = json.loads(line.strip("\n")) + if len(data_raw["question"]["choices"]) == 3: + three_choice += 1 + continue + elif len(data_raw["question"]["choices"]) == 5: + five_choice += 1 + continue + elif len(data_raw["question"]["choices"]) != 4: + other_choices += 1 + continue + four_choice += 1 + truth = str(normalize(data_raw["answerKey"])) + assert truth != "None" + question_choices = data_raw["question"] + question = question_choices["stem"] + id = data_raw["id"] + options = question_choices["choices"] + if len(options) == 4: + examples.append( + InputExample( + example_id=id, + question=question, + contexts=[ + options[0]["para"].replace("_", ""), + options[1]["para"].replace("_", ""), + options[2]["para"].replace("_", ""), + options[3]["para"].replace("_", ""), + ], + endings=[options[0]["text"], options[1]["text"], options[2]["text"], options[3]["text"]], + label=truth, + ) + ) + + if type == "train": + assert len(examples) > 1 + assert examples[0].label is not None + logger.info("len examples: %s}", str(len(examples))) + logger.info("Three choices: %s", str(three_choice)) + logger.info("Five choices: %s", str(five_choice)) + logger.info("Other choices: %s", str(other_choices)) + logger.info("four choices: %s", str(four_choice)) + + return examples + + +def convert_examples_to_features( + examples: List[InputExample], + label_list: List[str], + max_length: int, + tokenizer: PreTrainedTokenizer, +) -> List[InputFeatures]: + """ + Loads a data file into a list of `InputFeatures` + """ + + label_map = {label: i for i, label in enumerate(label_list)} + + features = [] + for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"): + if ex_index % 10000 == 0: + logger.info("Writing example %d of %d" % (ex_index, len(examples))) + choices_inputs = [] + for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)): + text_a = context + if example.question.find("_") != -1: + # this is for cloze question + text_b = example.question.replace("_", ending) + else: + text_b = example.question + " " + ending + + inputs = tokenizer( + text_a, + text_b, + add_special_tokens=True, + max_length=max_length, + padding="max_length", + truncation=True, + return_overflowing_tokens=True, + ) + if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0: + logger.info( + "Attention! you are cropping tokens (swag task is ok). " + "If you are training ARC and RACE and you are poping question + options," + "you need to try to use a bigger max seq length!" + ) + + choices_inputs.append(inputs) + + label = label_map[example.label] + + input_ids = [x["input_ids"] for x in choices_inputs] + attention_mask = ( + [x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None + ) + token_type_ids = ( + [x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None + ) + + features.append( + InputFeatures( + example_id=example.example_id, + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + label=label, + ) + ) + + for f in features[:2]: + logger.info("*** Example ***") + logger.info("feature: %s" % f) + + return features + + +processors = {"race": RaceProcessor, "swag": SwagProcessor, "arc": ArcProcessor, "syn": SynonymProcessor} +MULTIPLE_CHOICE_TASKS_NUM_LABELS = {"race", 4, "swag", 4, "arc", 4, "syn", 5} diff --git a/examples/multiple-choice/README.md b/examples/multiple-choice/README.md index 3d0a643cd8..34d1dfee13 100644 --- a/examples/multiple-choice/README.md +++ b/examples/multiple-choice/README.md @@ -16,27 +16,20 @@ limitations under the License. ## Multiple Choice -Based on the script [`run_multiple_choice.py`](). +Based on the script [`run_swag.py`](). #### Fine-tuning on SWAG -Download [swag](https://github.com/rowanz/swagaf/tree/master/data) data ```bash -#training on 4 tesla V100(16GB) GPUS -export SWAG_DIR=/path/to/swag_data_dir -python ./examples/multiple-choice/run_multiple_choice.py \ ---task_name swag \ +python examples/multiple-choice/run_swag.py \ --model_name_or_path roberta-base \ --do_train \ --do_eval \ ---data_dir $SWAG_DIR \ --learning_rate 5e-5 \ --num_train_epochs 3 \ ---max_seq_length 80 \ ---output_dir models_bert/swag_base \ +--output_dir /tmp/swag_base \ --per_gpu_eval_batch_size=16 \ --per_device_train_batch_size=16 \ ---gradient_accumulation_steps 2 \ --overwrite_output ``` Training with the defined hyper-parameters yields the following results: diff --git a/examples/multiple-choice/run_swag.py b/examples/multiple-choice/run_swag.py new file mode 100644 index 0000000000..a8e232d90e --- /dev/null +++ b/examples/multiple-choice/run_swag.py @@ -0,0 +1,349 @@ +# coding=utf-8 +# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for multiple choice. +""" +# You can also adapt this script on your own multiple choice task. Pointers for this are left as comments. + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional, Union + +import numpy as np +import torch +from datasets import load_dataset + +import transformers +from transformers import ( + AutoConfig, + AutoModelForMultipleChoice, + AutoTokenizer, + HfArgumentParser, + Trainer, + TrainingArguments, + default_data_collator, + set_seed, +) +from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase +from transformers.trainer_utils import is_main_process + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_seq_length: int = field( + default=None, + metadata={ + "help": "The maximum total input sequence length after tokenization. If passed, sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to the maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + + def __post_init__(self): + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + + +@dataclass +class DataCollatorForMultipleChoice: + """ + Data collator that will dynamically pad the inputs for multiple choice received. + + Args: + tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): + The tokenizer used for encoding the data. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + """ + + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + + def __call__(self, features): + label_name = "label" if "label" in features[0].keys() else "labels" + labels = [feature.pop(label_name) for feature in features] + batch_size = len(features) + num_choices = len(features[0]["input_ids"]) + flattened_features = [ + [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features + ] + flattened_features = sum(flattened_features, []) + + batch = self.tokenizer.pad( + flattened_features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + + # Un-flatten + batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()} + # Add back labels + batch["labels"] = torch.tensor(labels, dtype=torch.int64) + return batch + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN, + ) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.train_file is not None or data_args.validation_file is not None: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.train_file.split(".")[-1] + datasets = load_dataset(extension, data_files=data_files) + else: + # Downloading and loading the swag dataset from the hub. + datasets = load_dataset("swag", "regular") + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + ) + model = AutoModelForMultipleChoice.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + ) + + # When using your own dataset or a different dataset from swag, you will probably need to change this. + ending_names = [f"ending{i}" for i in range(4)] + context_name = "sent1" + question_header_name = "sent2" + + # Preprocessing the datasets. + def preprocess_function(examples): + first_sentences = [[context] * 4 for context in examples[context_name]] + question_headers = examples[question_header_name] + second_sentences = [ + [f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers) + ] + + # Flatten out + first_sentences = sum(first_sentences, []) + second_sentences = sum(second_sentences, []) + + # Tokenize + tokenized_examples = tokenizer( + first_sentences, + second_sentences, + truncation=True, + max_length=data_args.max_seq_length, + padding="max_length" if data_args.pad_to_max_length else False, + ) + # Un-flatten + return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()} + + tokenized_datasets = datasets.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) + + # Data collator + data_collator = ( + default_data_collator if data_args.pad_to_max_length else DataCollatorForMultipleChoice(tokenizer=tokenizer) + ) + + # Metric + def compute_metrics(eval_predictions): + predictions, label_ids = eval_predictions + preds = np.argmax(predictions, axis=1) + return {"accuracy": (preds == label_ids).astype(np.float32).mean().item()} + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_datasets["train"] if training_args.do_train else None, + eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics, + ) + + # Training + if training_args.do_train: + trainer.train( + model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + ) + trainer.save_model() # Saves the tokenizer too for easy upload + + # Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + results = trainer.evaluate() + + output_eval_file = os.path.join(training_args.output_dir, "eval_results_swag.txt") + if trainer.is_world_process_zero(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + for key, value in results.items(): + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") + + return results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/examples/test_examples.py b/examples/test_examples.py index 1b5811255d..e4ef9da6ae 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -33,6 +33,7 @@ SRC_DIRS = [ "text-classification", "token-classification", "language-modeling", + "multiple-choice", "question-answering", ] ] @@ -46,6 +47,7 @@ if SRC_DIRS is not None: import run_mlm import run_ner import run_qa as run_squad + import run_swag logging.basicConfig(level=logging.DEBUG) @@ -216,6 +218,32 @@ class ExamplesTests(TestCasePlus): self.assertGreaterEqual(result["f1"], 30) self.assertGreaterEqual(result["exact"], 30) + @require_torch_non_multi_gpu_but_fix_me + def test_run_swag(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_swag.py + --model_name_or_path bert-base-uncased + --train_file tests/fixtures/tests_samples/swag/sample.json + --validation_file tests/fixtures/tests_samples/swag/sample.json + --output_dir {tmp_dir} + --overwrite_output_dir + --max_steps=20 + --warmup_steps=2 + --do_train + --do_eval + --learning_rate=2e-4 + --per_device_train_batch_size=2 + --per_device_eval_batch_size=1 + """.split() + + with patch.object(sys, "argv", testargs): + result = run_swag.main() + self.assertGreaterEqual(result["eval_accuracy"], 0.8) + @require_torch_non_multi_gpu_but_fix_me def test_generation(self): stream_handler = logging.StreamHandler(sys.stdout) diff --git a/tests/fixtures/tests_samples/swag/sample.json b/tests/fixtures/tests_samples/swag/sample.json new file mode 100644 index 0000000000..d00ad8d184 --- /dev/null +++ b/tests/fixtures/tests_samples/swag/sample.json @@ -0,0 +1,10 @@ +{"ending0": "passes by walking down the street playing their instruments.", "ending1": "has heard approaching them.", "ending2": "arrives and they're outside dancing and asleep.", "ending3": "turns the lead singer watches the performance.", "label": 0, "sent1": "Members of the procession walk down the street holding small horn brass instruments.", "sent2": "A drum line"} +{"ending0": "are playing ping pong and celebrating one left each in quick.", "ending1": "wait slowly towards the cadets.", "ending2": "continues to play as well along the crowd along with the band being interviewed.", "ending3": "continue to play marching, interspersed.", "label": 3, "sent1": "A drum line passes by walking down the street playing their instruments.", "sent2": "Members of the procession"} +{"ending0": "pay the other coaches to cheer as people this chatter dips in lawn sheets.", "ending1": "walk down the street holding small horn brass instruments.", "ending2": "is seen in the background.", "ending3": "are talking a couple of people playing a game of tug of war.", "label": 1, "sent1": "A group of members in green uniforms walks waving flags.", "sent2": "Members of the procession"} +{"ending0": "are playing ping pong and celebrating one left each in quick.", "ending1": "wait slowly towards the cadets.", "ending2": "makes a square call and ends by jumping down into snowy streets where fans begin to take their positions.", "ending3": "play and go back and forth hitting the drums while the audience claps for them.", "label": 3, "sent1": "A drum line passes by walking down the street playing their instruments.", "sent2": "Members of the procession"} +{"ending0": "finishes the song and lowers the instrument.", "ending1": "hits the saxophone and demonstrates how to properly use the racquet.", "ending2": "finishes massage the instrument again and continues.", "ending3": "continues dancing while the man gore the music outside while drums.", "label": 0, "sent1": "The person plays a song on the violin.", "sent2": "The man"} +{"ending0": "finishes playing then marches their tenderly.", "ending1": "walks in frame and rubs on his hands, and then walks into a room.", "ending2": "continues playing guitar while moving from the camera.", "ending3": "plays a song on the violin.", "label": 3, "sent1": "The person holds up the violin to his chin and gets ready.", "sent2": "The person"} +{"ending0": "examines the instrument in his hand.", "ending1": "stops playing the drums and waves over the other boys.", "ending2": "lights the cigarette and sticks his head in.", "ending3": "drags off the vacuum.", "label": 0, "sent1": "A person retrieves an instrument from a closet.", "sent2": "The man"} +{"ending0": "studies a picture of the man playing the violin.", "ending1": "holds up the violin to his chin and gets ready.", "ending2": "stops to speak to the camera again.", "ending3": "puts his arm around the man and backs away.", "label": 1, "sent1": "The man examines the instrument in his hand.", "sent2": "The person"} +{"ending0": "hands her another phone.", "ending1": "takes the drink, then holds it.", "ending2": "looks off then looks at someone.", "ending3": "stares blearily down at the floor.", "label": 3, "sent1": "Someone walks over to the radio.", "sent2": "Someone"} +{"ending0": "looks off then looks at someone.", "ending1": "hands her another phone.", "ending2": "takes the drink, then holds it.", "ending3": "turns on a monitor.", "label": 3, "sent1": "Someone walks over to the radio.", "sent2": "Someone"}