Reorganize examples (#9010)
* Reorganize example folder * Continue reorganization * Change requirements for tests * Final cleanup * Finish regroup with tests all passing * Copyright * Requirements and readme * Make a full link for the documentation * Address review comments * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Add symlink * Reorg again * Apply suggestions from code review Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> * Adapt title * Update to new strucutre * Remove test * Update READMEs Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
28
examples/research_projects/README.md
Normal file
28
examples/research_projects/README.md
Normal file
@@ -0,0 +1,28 @@
|
||||
<!---
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Research projects
|
||||
|
||||
This folder contains various research projects using 🤗 Transformers. They are not maintained and require a specific
|
||||
version of 🤗 Transformers that is indicated in the requirements file of each folder. Updating them to the most recent version of the library will require some work.
|
||||
|
||||
To use any of them, just run the command
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
inside the folder of your choice.
|
||||
|
||||
If you need help with any of those, contact the author(s), indicated at the top of the `README` of each folder.
|
||||
38
examples/research_projects/adversarial/README.md
Normal file
38
examples/research_projects/adversarial/README.md
Normal file
@@ -0,0 +1,38 @@
|
||||
## Adversarial evaluation of model performances
|
||||
|
||||
Here is an example on evaluating a model using adversarial evaluation of natural language inference with the Heuristic Analysis for NLI Systems (HANS) dataset [McCoy et al., 2019](https://arxiv.org/abs/1902.01007). The example was gracefully provided by [Nafise Sadat Moosavi](https://github.com/ns-moosavi).
|
||||
|
||||
The HANS dataset can be downloaded from [this location](https://github.com/tommccoy1/hans).
|
||||
|
||||
This is an example of using test_hans.py:
|
||||
|
||||
```bash
|
||||
export HANS_DIR=path-to-hans
|
||||
export MODEL_TYPE=type-of-the-model-e.g.-bert-roberta-xlnet-etc
|
||||
export MODEL_PATH=path-to-the-model-directory-that-is-trained-on-NLI-e.g.-by-using-run_glue.py
|
||||
|
||||
python run_hans.py \
|
||||
--task_name hans \
|
||||
--model_type $MODEL_TYPE \
|
||||
--do_eval \
|
||||
--data_dir $HANS_DIR \
|
||||
--model_name_or_path $MODEL_PATH \
|
||||
--max_seq_length 128 \
|
||||
--output_dir $MODEL_PATH \
|
||||
```
|
||||
|
||||
This will create the hans_predictions.txt file in MODEL_PATH, which can then be evaluated using hans/evaluate_heur_output.py from the HANS dataset.
|
||||
|
||||
The results of the BERT-base model that is trained on MNLI using batch size 8 and the random seed 42 on the HANS dataset is as follows:
|
||||
|
||||
```bash
|
||||
Heuristic entailed results:
|
||||
lexical_overlap: 0.9702
|
||||
subsequence: 0.9942
|
||||
constituent: 0.9962
|
||||
|
||||
Heuristic non-entailed results:
|
||||
lexical_overlap: 0.199
|
||||
subsequence: 0.0396
|
||||
constituent: 0.118
|
||||
```
|
||||
1
examples/research_projects/adversarial/requirements.txt
Normal file
1
examples/research_projects/adversarial/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
transformers == 3.5.1
|
||||
239
examples/research_projects/adversarial/run_hans.py
Normal file
239
examples/research_projects/adversarial/run_hans.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning the library models for sequence classification on HANS."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
task_name: str = field(
|
||||
metadata={"help": "The name of the task to train selected in the list: " + ", ".join(hans_processors.keys())}
|
||||
)
|
||||
data_dir: str = field(
|
||||
metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
|
||||
|
||||
def hans_data_collator(features: List[InputFeatures]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Data collator that removes the "pairID" key if present.
|
||||
"""
|
||||
batch = default_data_collator(features)
|
||||
_ = batch.pop("pairID", None)
|
||||
return batch
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
training_args.fp16,
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(training_args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
# Set seed
|
||||
set_seed(training_args.seed)
|
||||
|
||||
try:
|
||||
num_labels = hans_tasks_num_labels[data_args.task_name]
|
||||
except KeyError:
|
||||
raise ValueError("Task not found: %s" % (data_args.task_name))
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=data_args.task_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
HansDataset(
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
task=data_args.task_name,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
)
|
||||
if training_args.do_train
|
||||
else None
|
||||
)
|
||||
eval_dataset = (
|
||||
HansDataset(
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
task=data_args.task_name,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
evaluate=True,
|
||||
)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=hans_data_collator,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
trainer.train(
|
||||
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
|
||||
)
|
||||
trainer.save_model()
|
||||
# For convenience, we also re-save the tokenizer to the same directory,
|
||||
# so that you can share your model easily on huggingface.co/models =)
|
||||
if trainer.is_world_master():
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
output = trainer.predict(eval_dataset)
|
||||
preds = output.predictions
|
||||
preds = np.argmax(preds, axis=1)
|
||||
|
||||
pair_ids = [ex.pairID for ex in eval_dataset]
|
||||
output_eval_file = os.path.join(training_args.output_dir, "hans_predictions.txt")
|
||||
label_list = eval_dataset.get_labels()
|
||||
if trainer.is_world_master():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
writer.write("pairID,gold_label\n")
|
||||
for pid, pred in zip(pair_ids, preds):
|
||||
writer.write("ex" + str(pid) + "," + label_list[int(pred)] + "\n")
|
||||
|
||||
trainer._log(output.metrics)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
340
examples/research_projects/adversarial/utils_hans.py
Normal file
340
examples/research_projects/adversarial/utils_hans.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import tqdm
|
||||
|
||||
from filelock import FileLock
|
||||
from transformers import (
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
DataProcessor,
|
||||
PreTrainedTokenizer,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
XLMRobertaTokenizer,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputExample:
|
||||
"""
|
||||
A single training/test example for simple sequence classification.
|
||||
|
||||
Args:
|
||||
guid: Unique id for the example.
|
||||
text_a: string. The untokenized text of the first sequence. For single
|
||||
sequence tasks, only this sequence must be specified.
|
||||
text_b: (Optional) string. The untokenized text of the second sequence.
|
||||
Only must be specified for sequence pair tasks.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
pairID: (Optional) string. Unique identifier for the pair of sentences.
|
||||
"""
|
||||
|
||||
guid: str
|
||||
text_a: str
|
||||
text_b: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
pairID: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputFeatures:
|
||||
"""
|
||||
A single set of features of data.
|
||||
Property names are the same names as the corresponding inputs to a model.
|
||||
|
||||
Args:
|
||||
input_ids: Indices of input sequence tokens in the vocabulary.
|
||||
attention_mask: Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens.
|
||||
token_type_ids: (Optional) Segment token indices to indicate first and second
|
||||
portions of the inputs. Only some models use them.
|
||||
label: (Optional) Label corresponding to the input. Int for classification problems,
|
||||
float for regression problems.
|
||||
pairID: (Optional) Unique identifier for the pair of sentences.
|
||||
"""
|
||||
|
||||
input_ids: List[int]
|
||||
attention_mask: Optional[List[int]] = None
|
||||
token_type_ids: Optional[List[int]] = None
|
||||
label: Optional[Union[int, float]] = None
|
||||
pairID: Optional[int] = None
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
class HansDataset(Dataset):
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
task: str,
|
||||
max_seq_length: Optional[int] = None,
|
||||
overwrite_cache=False,
|
||||
evaluate: bool = False,
|
||||
):
|
||||
processor = hans_processors[task]()
|
||||
|
||||
cached_features_file = os.path.join(
|
||||
data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
tokenizer.__class__.__name__,
|
||||
str(max_seq_length),
|
||||
task,
|
||||
),
|
||||
)
|
||||
label_list = processor.get_labels()
|
||||
if tokenizer.__class__ in (
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
XLMRobertaTokenizer,
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
):
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
self.label_list = label_list
|
||||
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||
|
||||
examples = (
|
||||
processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
||||
)
|
||||
|
||||
logger.info("Training examples: %s", len(examples))
|
||||
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(self.features, cached_features_file)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
def get_labels(self):
|
||||
return self.label_list
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
class TFHansDataset:
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
task: str,
|
||||
max_seq_length: Optional[int] = 128,
|
||||
overwrite_cache=False,
|
||||
evaluate: bool = False,
|
||||
):
|
||||
processor = hans_processors[task]()
|
||||
label_list = processor.get_labels()
|
||||
if tokenizer.__class__ in (
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
XLMRobertaTokenizer,
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
):
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
self.label_list = label_list
|
||||
|
||||
examples = processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
||||
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
||||
|
||||
def gen():
|
||||
for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||
|
||||
yield (
|
||||
{
|
||||
"example_id": 0,
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"token_type_ids": ex.token_type_ids,
|
||||
},
|
||||
ex.label,
|
||||
)
|
||||
|
||||
self.dataset = tf.data.Dataset.from_generator(
|
||||
gen,
|
||||
(
|
||||
{
|
||||
"example_id": tf.int32,
|
||||
"input_ids": tf.int32,
|
||||
"attention_mask": tf.int32,
|
||||
"token_type_ids": tf.int32,
|
||||
},
|
||||
tf.int64,
|
||||
),
|
||||
(
|
||||
{
|
||||
"example_id": tf.TensorShape([]),
|
||||
"input_ids": tf.TensorShape([None, None]),
|
||||
"attention_mask": tf.TensorShape([None, None]),
|
||||
"token_type_ids": tf.TensorShape([None, None]),
|
||||
},
|
||||
tf.TensorShape([]),
|
||||
),
|
||||
)
|
||||
|
||||
def get_dataset(self):
|
||||
return self.dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
def get_labels(self):
|
||||
return self.label_list
|
||||
|
||||
|
||||
class HansProcessor(DataProcessor):
|
||||
"""Processor for the HANS data set."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_train_set.txt")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_evaluation_set.txt")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class.
|
||||
Note that we follow the standard three labels for MNLI
|
||||
(see :class:`~transformers.data.processors.utils.MnliProcessor`)
|
||||
but the HANS evaluation groups `contradiction` and `neutral` into `non-entailment` (label 0) while
|
||||
`entailment` is label 1."""
|
||||
return ["contradiction", "entailment", "neutral"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[5]
|
||||
text_b = line[6]
|
||||
pairID = line[7][2:] if line[7].startswith("ex") else line[7]
|
||||
label = line[0]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, pairID=pairID))
|
||||
return examples
|
||||
|
||||
|
||||
def hans_convert_examples_to_features(
|
||||
examples: List[InputExample],
|
||||
label_list: List[str],
|
||||
max_length: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
):
|
||||
"""
|
||||
Loads a data file into a list of ``InputFeatures``
|
||||
|
||||
Args:
|
||||
examples: List of ``InputExamples`` containing the examples.
|
||||
label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method.
|
||||
max_length: Maximum example length.
|
||||
tokenizer: Instance of a tokenizer that will tokenize the examples.
|
||||
|
||||
Returns:
|
||||
A list of task-specific ``InputFeatures`` which can be fed to the model.
|
||||
|
||||
"""
|
||||
|
||||
label_map = {label: i for i, label in enumerate(label_list)}
|
||||
|
||||
features = []
|
||||
for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d" % (ex_index))
|
||||
|
||||
inputs = tokenizer(
|
||||
example.text_a,
|
||||
example.text_b,
|
||||
add_special_tokens=True,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
|
||||
label = label_map[example.label] if example.label in label_map else 0
|
||||
|
||||
pairID = int(example.pairID)
|
||||
|
||||
features.append(InputFeatures(**inputs, label=label, pairID=pairID))
|
||||
|
||||
for i, example in enumerate(examples[:5]):
|
||||
logger.info("*** Example ***")
|
||||
logger.info(f"guid: {example}")
|
||||
logger.info(f"features: {features[i]}")
|
||||
|
||||
return features
|
||||
|
||||
|
||||
hans_tasks_num_labels = {
|
||||
"hans": 3,
|
||||
}
|
||||
|
||||
hans_processors = {
|
||||
"hans": HansProcessor,
|
||||
}
|
||||
89
examples/research_projects/bert-loses-patience/README.md
Executable file
89
examples/research_projects/bert-loses-patience/README.md
Executable file
@@ -0,0 +1,89 @@
|
||||
# Patience-based Early Exit
|
||||
|
||||
Patience-based Early Exit (PABEE) is a plug-and-play inference method for pretrained language models.
|
||||
We have already implemented it on BERT and ALBERT. Basically, you can make your LM faster and more robust with PABEE. It can even improve the performance of ALBERT on GLUE. The only sacrifice is that the batch size can only be 1.
|
||||
Learn more in the paper ["BERT Loses Patience: Fast and Robust Inference with Early Exit"](https://arxiv.org/abs/2006.04152) and the official [GitHub repo](https://github.com/JetRunner/PABEE).
|
||||
|
||||

|
||||
|
||||
## Training
|
||||
|
||||
You can fine-tune a pretrained language model (you can choose from BERT and ALBERT) and train the internal classifiers by:
|
||||
```bash
|
||||
export GLUE_DIR=/path/to/glue_data
|
||||
export TASK_NAME=MRPC
|
||||
|
||||
python ./run_glue_with_pabee.py \
|
||||
--model_type albert \
|
||||
--model_name_or_path bert-base-uncased/albert-base-v2 \
|
||||
--task_name $TASK_NAME \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir "$GLUE_DIR/$TASK_NAME" \
|
||||
--max_seq_length 128 \
|
||||
--per_gpu_train_batch_size 32 \
|
||||
--per_gpu_eval_batch_size 32 \
|
||||
--learning_rate 2e-5 \
|
||||
--save_steps 50 \
|
||||
--logging_steps 50 \
|
||||
--num_train_epochs 5 \
|
||||
--output_dir /path/to/save/ \
|
||||
--evaluate_during_training
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
You can inference with different patience settings by:
|
||||
```bash
|
||||
export GLUE_DIR=/path/to/glue_data
|
||||
export TASK_NAME=MRPC
|
||||
|
||||
python ./run_glue_with_pabee.py \
|
||||
--model_type albert \
|
||||
--model_name_or_path /path/to/save/ \
|
||||
--task_name $TASK_NAME \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir "$GLUE_DIR/$TASK_NAME" \
|
||||
--max_seq_length 128 \
|
||||
--per_gpu_eval_batch_size 1 \
|
||||
--learning_rate 2e-5 \
|
||||
--logging_steps 50 \
|
||||
--num_train_epochs 15 \
|
||||
--output_dir /path/to/save/ \
|
||||
--eval_all_checkpoints \
|
||||
--patience 3,4,5,6,7,8
|
||||
```
|
||||
where `patience` can be a list of patience settings, separated by a comma. It will help determine which patience works best.
|
||||
|
||||
When evaluating on a regression task (STS-B), you may add `--regression_threshold 0.1` to define the regression threshold.
|
||||
|
||||
## Results
|
||||
On the GLUE dev set:
|
||||
|
||||
| Model | \#Param | Speed | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST\-2 | STS\-B |
|
||||
|--------------|---------|--------|-------|-------|-------|-------|-------|-------|--------|--------|
|
||||
| ALBERT\-base | 12M | | 58\.9 | 84\.6 | 89\.5 | 91\.7 | 89\.6 | 78\.6 | 92\.8 | 89\.5 |
|
||||
| \+PABEE | 12M | 1\.57x | 61\.2 | 85\.1 | 90\.0 | 91\.8 | 89\.6 | 80\.1 | 93\.0 | 90\.1 |
|
||||
|
||||
| Model | \#Param | Speed\-up | MNLI | SST\-2 | STS\-B |
|
||||
|---------------|---------|-----------|-------|--------|--------|
|
||||
| BERT\-base | 108M | | 84\.5 | 92\.1 | 88\.9 |
|
||||
| \+PABEE | 108M | 1\.62x | 83\.6 | 92\.0 | 88\.7 |
|
||||
| ALBERT\-large | 18M | | 86\.4 | 94\.9 | 90\.4 |
|
||||
| \+PABEE | 18M | 2\.42x | 86\.8 | 95\.2 | 90\.6 |
|
||||
|
||||
|
||||
## Citation
|
||||
If you find this resource useful, please consider citing the following paper:
|
||||
```bibtex
|
||||
@misc{zhou2020bert,
|
||||
title={BERT Loses Patience: Fast and Robust Inference with Early Exit},
|
||||
author={Wangchunshu Zhou and Canwen Xu and Tao Ge and Julian McAuley and Ke Xu and Furu Wei},
|
||||
year={2020},
|
||||
eprint={2006.04152},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,316 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Google AI, Google Brain, the HuggingFace Inc. team and Microsoft Corporation.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch ALBERT model with Patience-based Early Exit. """
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from transformers.models.albert.modeling_albert import (
|
||||
ALBERT_INPUTS_DOCSTRING,
|
||||
ALBERT_START_DOCSTRING,
|
||||
AlbertModel,
|
||||
AlbertPreTrainedModel,
|
||||
AlbertTransformer,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlbertTransformerWithPabee(AlbertTransformer):
|
||||
def adaptive_forward(self, hidden_states, current_layer, attention_mask=None, head_mask=None):
|
||||
if current_layer == 0:
|
||||
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
||||
else:
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
|
||||
|
||||
# Index of the hidden group
|
||||
group_idx = int(current_layer / (self.config.num_hidden_layers / self.config.num_hidden_groups))
|
||||
|
||||
layer_group_output = self.albert_layer_groups[group_idx](
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
||||
)
|
||||
hidden_states = layer_group_output[0]
|
||||
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare ALBERT Model transformer with PABEE outputting raw hidden-states without any specific head on top.",
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class AlbertModelWithPabee(AlbertModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.encoder = AlbertTransformerWithPabee(config)
|
||||
|
||||
self.init_weights()
|
||||
self.patience = 0
|
||||
self.inference_instances_num = 0
|
||||
self.inference_layers_num = 0
|
||||
|
||||
self.regression_threshold = 0
|
||||
|
||||
def set_regression_threshold(self, threshold):
|
||||
self.regression_threshold = threshold
|
||||
|
||||
def set_patience(self, patience):
|
||||
self.patience = patience
|
||||
|
||||
def reset_stats(self):
|
||||
self.inference_instances_num = 0
|
||||
self.inference_layers_num = 0
|
||||
|
||||
def log_stats(self):
|
||||
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
||||
message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
|
||||
print(message)
|
||||
|
||||
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_dropout=None,
|
||||
output_layers=None,
|
||||
regression=False,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during pre-training.
|
||||
|
||||
This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = embedding_output
|
||||
|
||||
if self.training:
|
||||
res = []
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
encoder_outputs = self.encoder.adaptive_forward(
|
||||
encoder_outputs,
|
||||
current_layer=i,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
)
|
||||
|
||||
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
||||
logits = output_layers[i](output_dropout(pooled_output))
|
||||
res.append(logits)
|
||||
elif self.patience == 0: # Use all layers for inference
|
||||
encoder_outputs = self.encoder(encoder_outputs, extended_attention_mask, head_mask=head_mask)
|
||||
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
||||
res = [output_layers[self.config.num_hidden_layers - 1](pooled_output)]
|
||||
else:
|
||||
patient_counter = 0
|
||||
patient_result = None
|
||||
calculated_layer_num = 0
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
calculated_layer_num += 1
|
||||
encoder_outputs = self.encoder.adaptive_forward(
|
||||
encoder_outputs,
|
||||
current_layer=i,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
)
|
||||
|
||||
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
||||
logits = output_layers[i](pooled_output)
|
||||
if regression:
|
||||
labels = logits.detach()
|
||||
if patient_result is not None:
|
||||
patient_labels = patient_result.detach()
|
||||
if (patient_result is not None) and torch.abs(patient_result - labels) < self.regression_threshold:
|
||||
patient_counter += 1
|
||||
else:
|
||||
patient_counter = 0
|
||||
else:
|
||||
labels = logits.detach().argmax(dim=1)
|
||||
if patient_result is not None:
|
||||
patient_labels = patient_result.detach().argmax(dim=1)
|
||||
if (patient_result is not None) and torch.all(labels.eq(patient_labels)):
|
||||
patient_counter += 1
|
||||
else:
|
||||
patient_counter = 0
|
||||
|
||||
patient_result = logits
|
||||
if patient_counter == self.patience:
|
||||
break
|
||||
res = [patient_result]
|
||||
self.inference_layers_num += calculated_layer_num
|
||||
self.inference_instances_num += 1
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Albert Model transformer with PABEE and a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class AlbertForSequenceClassificationWithPabee(AlbertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.albert = AlbertModelWithPabee(config)
|
||||
self.dropout = nn.Dropout(config.classifier_dropout_prob)
|
||||
self.classifiers = nn.ModuleList(
|
||||
[nn.Linear(config.hidden_size, self.config.num_labels) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
||||
loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import AlbertTokenizer
|
||||
from pabee import AlbertForSequenceClassificationWithPabee
|
||||
import torch
|
||||
|
||||
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
||||
model = AlbertForSequenceClassificationWithPabee.from_pretrained('albert-base-v2')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
logits = self.albert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_dropout=self.dropout,
|
||||
output_layers=self.classifiers,
|
||||
regression=self.num_labels == 1,
|
||||
)
|
||||
|
||||
outputs = (logits[-1],)
|
||||
|
||||
if labels is not None:
|
||||
total_loss = None
|
||||
total_weights = 0
|
||||
for ix, logits_item in enumerate(logits):
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits_item.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits_item.view(-1, self.num_labels), labels.view(-1))
|
||||
if total_loss is None:
|
||||
total_loss = loss
|
||||
else:
|
||||
total_loss += loss * (ix + 1)
|
||||
total_weights += ix + 1
|
||||
outputs = (total_loss / total_weights,) + outputs
|
||||
|
||||
return outputs
|
||||
@@ -0,0 +1,342 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and Microsoft Corporation.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model with Patience-based Early Exit. """
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
BERT_START_DOCSTRING,
|
||||
BertEncoder,
|
||||
BertModel,
|
||||
BertPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BertEncoderWithPabee(BertEncoder):
|
||||
def adaptive_forward(self, hidden_states, current_layer, attention_mask=None, head_mask=None):
|
||||
layer_outputs = self.layer[current_layer](hidden_states, attention_mask, head_mask[current_layer])
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Bert Model transformer with PABEE outputting raw hidden-states without any specific head on top.",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class BertModelWithPabee(BertModel):
|
||||
"""
|
||||
|
||||
The model can behave as an encoder (with only self-attention) as well
|
||||
as a decoder, in which case a layer of cross-attention is added between
|
||||
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
|
||||
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
|
||||
To behave as an decoder the model needs to be initialized with the
|
||||
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
|
||||
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
|
||||
|
||||
.. _`Attention is all you need`:
|
||||
https://arxiv.org/abs/1706.03762
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.encoder = BertEncoderWithPabee(config)
|
||||
|
||||
self.init_weights()
|
||||
self.patience = 0
|
||||
self.inference_instances_num = 0
|
||||
self.inference_layers_num = 0
|
||||
|
||||
self.regression_threshold = 0
|
||||
|
||||
def set_regression_threshold(self, threshold):
|
||||
self.regression_threshold = threshold
|
||||
|
||||
def set_patience(self, patience):
|
||||
self.patience = patience
|
||||
|
||||
def reset_stats(self):
|
||||
self.inference_instances_num = 0
|
||||
self.inference_layers_num = 0
|
||||
|
||||
def log_stats(self):
|
||||
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
||||
message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
|
||||
print(message)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_dropout=None,
|
||||
output_layers=None,
|
||||
regression=False,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during pre-training.
|
||||
|
||||
This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = embedding_output
|
||||
|
||||
if self.training:
|
||||
res = []
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
encoder_outputs = self.encoder.adaptive_forward(
|
||||
encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask
|
||||
)
|
||||
|
||||
pooled_output = self.pooler(encoder_outputs)
|
||||
logits = output_layers[i](output_dropout(pooled_output))
|
||||
res.append(logits)
|
||||
elif self.patience == 0: # Use all layers for inference
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
pooled_output = self.pooler(encoder_outputs[0])
|
||||
res = [output_layers[self.config.num_hidden_layers - 1](pooled_output)]
|
||||
else:
|
||||
patient_counter = 0
|
||||
patient_result = None
|
||||
calculated_layer_num = 0
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
calculated_layer_num += 1
|
||||
encoder_outputs = self.encoder.adaptive_forward(
|
||||
encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask
|
||||
)
|
||||
|
||||
pooled_output = self.pooler(encoder_outputs)
|
||||
logits = output_layers[i](pooled_output)
|
||||
if regression:
|
||||
labels = logits.detach()
|
||||
if patient_result is not None:
|
||||
patient_labels = patient_result.detach()
|
||||
if (patient_result is not None) and torch.abs(patient_result - labels) < self.regression_threshold:
|
||||
patient_counter += 1
|
||||
else:
|
||||
patient_counter = 0
|
||||
else:
|
||||
labels = logits.detach().argmax(dim=1)
|
||||
if patient_result is not None:
|
||||
patient_labels = patient_result.detach().argmax(dim=1)
|
||||
if (patient_result is not None) and torch.all(labels.eq(patient_labels)):
|
||||
patient_counter += 1
|
||||
else:
|
||||
patient_counter = 0
|
||||
|
||||
patient_result = logits
|
||||
if patient_counter == self.patience:
|
||||
break
|
||||
res = [patient_result]
|
||||
self.inference_layers_num += calculated_layer_num
|
||||
self.inference_instances_num += 1
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bert Model transformer with PABEE and a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class BertForSequenceClassificationWithPabee(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.bert = BertModelWithPabee(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifiers = nn.ModuleList(
|
||||
[nn.Linear(config.hidden_size, self.config.num_labels) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import BertTokenizer, BertForSequenceClassification
|
||||
from pabee import BertForSequenceClassificationWithPabee
|
||||
import torch
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = BertForSequenceClassificationWithPabee.from_pretrained('bert-base-uncased')
|
||||
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
logits = self.bert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_dropout=self.dropout,
|
||||
output_layers=self.classifiers,
|
||||
regression=self.num_labels == 1,
|
||||
)
|
||||
|
||||
outputs = (logits[-1],)
|
||||
|
||||
if labels is not None:
|
||||
total_loss = None
|
||||
total_weights = 0
|
||||
for ix, logits_item in enumerate(logits):
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits_item.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits_item.view(-1, self.num_labels), labels.view(-1))
|
||||
if total_loss is None:
|
||||
total_loss = loss
|
||||
else:
|
||||
total_loss += loss * (ix + 1)
|
||||
total_weights += ix + 1
|
||||
outputs = (total_loss / total_weights,) + outputs
|
||||
|
||||
return outputs
|
||||
@@ -0,0 +1 @@
|
||||
transformers == 3.5.1
|
||||
750
examples/research_projects/bert-loses-patience/run_glue_with_pabee.py
Executable file
750
examples/research_projects/bert-loses-patience/run_glue_with_pabee.py
Executable file
@@ -0,0 +1,750 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and Microsoft Corporation.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Training and inference using the library models for sequence classification on GLUE (Bert, Albert) with PABEE."""
|
||||
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import transformers
|
||||
from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
|
||||
from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AlbertConfig,
|
||||
AlbertTokenizer,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import glue_compute_metrics as compute_metrics
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
from transformers import glue_output_modes as output_modes
|
||||
from transformers import glue_processors as processors
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForSequenceClassificationWithPabee, BertTokenizer),
|
||||
"albert": (AlbertConfig, AlbertForSequenceClassificationWithPabee, AlbertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to gobal_step of last saved checkpoint from model path
|
||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", global_step)
|
||||
logger.info(
|
||||
" Will skip the first %d steps in the first epoch",
|
||||
steps_trained_in_current_epoch,
|
||||
)
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained,
|
||||
int(args.num_train_epochs),
|
||||
desc="Epoch",
|
||||
disable=args.local_rank not in [-1, 0],
|
||||
)
|
||||
set_seed(args) # Added here for reproductibility
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3],
|
||||
}
|
||||
inputs["token_type_ids"] = batch[2]
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
logs = {}
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
eval_key = "eval_{}".format(key)
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||
learning_rate_scalar = scheduler.get_lr()[0]
|
||||
logs["learning_rate"] = learning_rate_scalar
|
||||
logs["loss"] = loss_scalar
|
||||
logging_loss = tr_loss
|
||||
|
||||
for key, value in logs.items():
|
||||
tb_writer.add_scalar(key, value, global_step)
|
||||
print(json.dumps({**logs, **{"step": global_step}}))
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix="", patience=0):
|
||||
|
||||
if args.model_type == "albert":
|
||||
model.albert.set_regression_threshold(args.regression_threshold)
|
||||
model.albert.set_patience(patience)
|
||||
model.albert.reset_stats()
|
||||
elif args.model_type == "bert":
|
||||
model.bert.set_regression_threshold(args.regression_threshold)
|
||||
model.bert.set_patience(patience)
|
||||
model.bert.reset_stats()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3],
|
||||
}
|
||||
inputs["token_type_ids"] = batch[2]
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif args.output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
print(" %s = %s" % (key, str(result[key])))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
if args.eval_all_checkpoints and patience != 0:
|
||||
if args.model_type == "albert":
|
||||
model.albert.log_stats()
|
||||
elif args.model_type == "bert":
|
||||
model.bert.log_stats()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
processor = processors[task]()
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = (
|
||||
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
label_list=label_list,
|
||||
max_length=args.max_seq_length,
|
||||
output_mode=output_mode,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
if output_mode == "classification":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--patience",
|
||||
default="0",
|
||||
type=str,
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--regression_threshold",
|
||||
default=0,
|
||||
type=float,
|
||||
required=False,
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training",
|
||||
action="store_true",
|
||||
help="Run evaluation during training at each logging step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case",
|
||||
action="store_true",
|
||||
help="Set this flag if you are using an uncased model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--per_gpu_train_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
default=5e-5,
|
||||
type=float,
|
||||
help="The initial learning rate for Adam.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs",
|
||||
default=3.0,
|
||||
type=float,
|
||||
help="Total number of training epochs to perform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Save checkpoint every X updates steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir",
|
||||
action="store_true",
|
||||
help="Overwrite the content of the output directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache",
|
||||
action="store_true",
|
||||
help="Overwrite the cached training and evaluation sets",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="For distributed training: local_rank",
|
||||
)
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name]()
|
||||
args.output_mode = output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
if args.patience != "0" and args.per_gpu_eval_batch_size != 1:
|
||||
raise ValueError("The eval batch size must be 1 with PABEE inference on.")
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
print("Total Model Parameters:", sum(param.numel() for param in model.parameters()))
|
||||
output_layers_param_num = sum(param.numel() for param in model.classifiers.parameters())
|
||||
print("Output Layers Parameters:", output_layers_param_num)
|
||||
single_output_layer_param_num = sum(param.numel() for param in model.classifiers[0].parameters())
|
||||
print(
|
||||
"Added Output Layers Parameters:",
|
||||
output_layers_param_num - single_output_layer_param_num,
|
||||
)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
patience_list = [int(x) for x in args.patience.split(",")]
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
|
||||
print(f"Evaluation for checkpoint {prefix}")
|
||||
for patience in patience_list:
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix, patience=patience)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,51 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_glue_with_pabee
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_setup_file():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f")
|
||||
args = parser.parse_args()
|
||||
return args.f
|
||||
|
||||
|
||||
class PabeeTests(TestCasePlus):
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_run_glue(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_glue_with_pabee.py
|
||||
--model_type albert
|
||||
--model_name_or_path albert-base-v2
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--task_name mrpc
|
||||
--do_train
|
||||
--do_eval
|
||||
--per_gpu_train_batch_size=2
|
||||
--per_gpu_eval_batch_size=1
|
||||
--learning_rate=2e-5
|
||||
--max_steps=50
|
||||
--warmup_steps=2
|
||||
--seed=42
|
||||
--max_seq_length=128
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_glue_with_pabee.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.75)
|
||||
61
examples/research_projects/bertabs/README.md
Normal file
61
examples/research_projects/bertabs/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Text Summarization with Pretrained Encoders
|
||||
|
||||
This folder contains part of the code necessary to reproduce the results on abstractive summarization from the article [Text Summarization with Pretrained Encoders](https://arxiv.org/pdf/1908.08345.pdf) by [Yang Liu](https://nlp-yang.github.io/) and [Mirella Lapata](https://homepages.inf.ed.ac.uk/mlap/). It can also be used to summarize any document.
|
||||
|
||||
The original code can be found on the Yang Liu's [github repository](https://github.com/nlpyang/PreSumm).
|
||||
|
||||
The model is loaded with the pre-trained weights for the abstractive summarization model trained on the CNN/Daily Mail dataset with an extractive and then abstractive tasks.
|
||||
|
||||
## Setup
|
||||
|
||||
```
|
||||
git clone https://github.com/huggingface/transformers && cd transformers
|
||||
pip install .
|
||||
pip install nltk py-rouge
|
||||
cd examples/seq2seq/bertabs
|
||||
```
|
||||
|
||||
## Reproduce the authors' ROUGE score
|
||||
|
||||
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
|
||||
|
||||
```bash
|
||||
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
|
||||
```
|
||||
|
||||
And move all the stories to the same folder. We will refer as `$DATA_PATH` the path to where you uncompressed both archive. Then run the following in the same folder as `run_summarization.py`:
|
||||
|
||||
```bash
|
||||
python run_summarization.py \
|
||||
--documents_dir $DATA_PATH \
|
||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||
--no_cuda false \
|
||||
--batch_size 4 \
|
||||
--min_length 50 \
|
||||
--max_length 200 \
|
||||
--beam_size 5 \
|
||||
--alpha 0.95 \
|
||||
--block_trigram true \
|
||||
--compute_rouge true
|
||||
```
|
||||
|
||||
The scripts executes on GPU if one is available and if `no_cuda` is not set to `true`. Inference on multiple GPUs is not supported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes 30 hours to compute with a single Tesla V100 GPU and a batch size of 10 (300,000 texts to summarize).
|
||||
|
||||
## Summarize any text
|
||||
|
||||
Put the documents that you would like to summarize in a folder (the path to which is referred to as `$DATA_PATH` below) and run the following in the same folder as `run_summarization.py`:
|
||||
|
||||
```bash
|
||||
python run_summarization.py \
|
||||
--documents_dir $DATA_PATH \
|
||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||
--no_cuda false \
|
||||
--batch_size 4 \
|
||||
--min_length 50 \
|
||||
--max_length 200 \
|
||||
--beam_size 5 \
|
||||
--alpha 0.95 \
|
||||
--block_trigram true \
|
||||
```
|
||||
|
||||
You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py` and tell it where to fetch the reference summaries.
|
||||
0
examples/research_projects/bertabs/__init__.py
Normal file
0
examples/research_projects/bertabs/__init__.py
Normal file
97
examples/research_projects/bertabs/configuration_bertabs.py
Normal file
97
examples/research_projects/bertabs/configuration_bertabs.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BertAbs configuration """
|
||||
import logging
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BERTABS_FINETUNED_CONFIG_MAP = {
|
||||
"bertabs-finetuned-cnndm": "https://huggingface.co/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class BertAbsConfig(PretrainedConfig):
|
||||
r"""Class to store the configuration of the BertAbs model.
|
||||
|
||||
Arguments:
|
||||
vocab_size: int
|
||||
Number of tokens in the vocabulary.
|
||||
max_pos: int
|
||||
The maximum sequence length that this model will be used with.
|
||||
enc_layer: int
|
||||
The numner of hidden layers in the Transformer encoder.
|
||||
enc_hidden_size: int
|
||||
The size of the encoder's layers.
|
||||
enc_heads: int
|
||||
The number of attention heads for each attention layer in the encoder.
|
||||
enc_ff_size: int
|
||||
The size of the encoder's feed-forward layers.
|
||||
enc_dropout: int
|
||||
The dropout probability for all fully connected layers in the
|
||||
embeddings, layers, pooler and also the attention probabilities in
|
||||
the encoder.
|
||||
dec_layer: int
|
||||
The numner of hidden layers in the decoder.
|
||||
dec_hidden_size: int
|
||||
The size of the decoder's layers.
|
||||
dec_heads: int
|
||||
The number of attention heads for each attention layer in the decoder.
|
||||
dec_ff_size: int
|
||||
The size of the decoder's feed-forward layers.
|
||||
dec_dropout: int
|
||||
The dropout probability for all fully connected layers in the
|
||||
embeddings, layers, pooler and also the attention probabilities in
|
||||
the decoder.
|
||||
"""
|
||||
|
||||
model_type = "bertabs"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
max_pos=512,
|
||||
enc_layers=6,
|
||||
enc_hidden_size=512,
|
||||
enc_heads=8,
|
||||
enc_ff_size=512,
|
||||
enc_dropout=0.2,
|
||||
dec_layers=6,
|
||||
dec_hidden_size=768,
|
||||
dec_heads=8,
|
||||
dec_ff_size=2048,
|
||||
dec_dropout=0.2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_pos = max_pos
|
||||
|
||||
self.enc_layers = enc_layers
|
||||
self.enc_hidden_size = enc_hidden_size
|
||||
self.enc_heads = enc_heads
|
||||
self.enc_ff_size = enc_ff_size
|
||||
self.enc_dropout = enc_dropout
|
||||
|
||||
self.dec_layers = dec_layers
|
||||
self.dec_hidden_size = dec_hidden_size
|
||||
self.dec_heads = dec_heads
|
||||
self.dec_ff_size = dec_ff_size
|
||||
self.dec_dropout = dec_dropout
|
||||
@@ -0,0 +1,185 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Convert BertExtAbs's checkpoints.
|
||||
|
||||
The script looks like it is doing something trivial but it is not. The "weights"
|
||||
proposed by the authors are actually the entire model pickled. We need to load
|
||||
the model within the original codebase to be able to only save its `state_dict`.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from model_bertabs import BertAbsSummarizer
|
||||
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||
|
||||
|
||||
BertAbsConfig = namedtuple(
|
||||
"BertAbsConfig",
|
||||
[
|
||||
"temp_dir",
|
||||
"large",
|
||||
"use_bert_emb",
|
||||
"finetune_bert",
|
||||
"encoder",
|
||||
"share_emb",
|
||||
"max_pos",
|
||||
"enc_layers",
|
||||
"enc_hidden_size",
|
||||
"enc_heads",
|
||||
"enc_ff_size",
|
||||
"enc_dropout",
|
||||
"dec_layers",
|
||||
"dec_hidden_size",
|
||||
"dec_heads",
|
||||
"dec_ff_size",
|
||||
"dec_dropout",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
"""Copy/paste and tweak the pre-trained weights provided by the creators
|
||||
of BertAbs for the internal architecture.
|
||||
"""
|
||||
|
||||
# Instantiate the authors' model with the pre-trained weights
|
||||
config = BertAbsConfig(
|
||||
temp_dir=".",
|
||||
finetune_bert=False,
|
||||
large=False,
|
||||
share_emb=True,
|
||||
use_bert_emb=False,
|
||||
encoder="bert",
|
||||
max_pos=512,
|
||||
enc_layers=6,
|
||||
enc_hidden_size=512,
|
||||
enc_heads=8,
|
||||
enc_ff_size=512,
|
||||
enc_dropout=0.2,
|
||||
dec_layers=6,
|
||||
dec_hidden_size=768,
|
||||
dec_heads=8,
|
||||
dec_ff_size=2048,
|
||||
dec_dropout=0.2,
|
||||
)
|
||||
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
|
||||
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
|
||||
original.eval()
|
||||
|
||||
new_model = BertAbsSummarizer(config, torch.device("cpu"))
|
||||
new_model.eval()
|
||||
|
||||
# -------------------
|
||||
# Convert the weights
|
||||
# -------------------
|
||||
|
||||
logging.info("convert the model")
|
||||
new_model.bert.load_state_dict(original.bert.state_dict())
|
||||
new_model.decoder.load_state_dict(original.decoder.state_dict())
|
||||
new_model.generator.load_state_dict(original.generator.state_dict())
|
||||
|
||||
# ----------------------------------
|
||||
# Make sure the outpus are identical
|
||||
# ----------------------------------
|
||||
|
||||
logging.info("Make sure that the models' outputs are identical")
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# prepare the model inputs
|
||||
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
|
||||
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
|
||||
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
|
||||
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
|
||||
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
|
||||
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
|
||||
|
||||
# failsafe to make sure the weights reset does not affect the
|
||||
# loaded weights.
|
||||
assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0
|
||||
|
||||
# forward pass
|
||||
src = encoder_input_ids
|
||||
tgt = decoder_input_ids
|
||||
segs = token_type_ids = None
|
||||
clss = None
|
||||
mask_src = encoder_attention_mask = None
|
||||
mask_tgt = decoder_attention_mask = None
|
||||
mask_cls = None
|
||||
|
||||
# The original model does not apply the geneator layer immediatly but rather in
|
||||
# the beam search (where it combines softmax + linear layer). Since we already
|
||||
# apply the softmax in our generation process we only apply the linear layer here.
|
||||
# We make sure that the outputs of the full stack are identical
|
||||
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
|
||||
output_original_generator = original.generator(output_original_model)
|
||||
|
||||
output_converted_model = new_model(
|
||||
encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
|
||||
)[0]
|
||||
output_converted_generator = new_model.generator(output_converted_model)
|
||||
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
|
||||
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
|
||||
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
|
||||
|
||||
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
|
||||
if are_identical:
|
||||
logging.info("all weights are equal up to 1e-3")
|
||||
else:
|
||||
raise ValueError("the weights are different. The new model is likely different from the original one.")
|
||||
|
||||
# The model has been saved with torch.save(model) and this is bound to the exact
|
||||
# directory structure. We save the state_dict instead.
|
||||
logging.info("saving the model's state dictionary")
|
||||
torch.save(
|
||||
new_model.state_dict(), "./bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bertabs_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path the official PyTorch dump.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_bertabs_checkpoints(
|
||||
args.bertabs_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
)
|
||||
1058
examples/research_projects/bertabs/modeling_bertabs.py
Normal file
1058
examples/research_projects/bertabs/modeling_bertabs.py
Normal file
File diff suppressed because it is too large
Load Diff
5
examples/research_projects/bertabs/requirements.txt
Normal file
5
examples/research_projects/bertabs/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
transformers == 3.5.1
|
||||
|
||||
# For ROUGE
|
||||
nltk
|
||||
py-rouge
|
||||
347
examples/research_projects/bertabs/run_summarization.py
Normal file
347
examples/research_projects/bertabs/run_summarization.py
Normal file
@@ -0,0 +1,347 @@
|
||||
#! /usr/bin/python3
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from modeling_bertabs import BertAbs, build_predictor
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from .utils_summarization import (
|
||||
CNNDMDataset,
|
||||
build_mask,
|
||||
compute_token_type_ids,
|
||||
encode_for_summarization,
|
||||
truncate_or_pad,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
|
||||
model = BertAbs.from_pretrained("remi/bertabs-finetuned-extractive-abstractive-summarization")
|
||||
model.to(args.device)
|
||||
model.eval()
|
||||
|
||||
symbols = {
|
||||
"BOS": tokenizer.vocab["[unused0]"],
|
||||
"EOS": tokenizer.vocab["[unused1]"],
|
||||
"PAD": tokenizer.vocab["[PAD]"],
|
||||
}
|
||||
|
||||
if args.compute_rouge:
|
||||
reference_summaries = []
|
||||
generated_summaries = []
|
||||
|
||||
import nltk
|
||||
|
||||
import rouge
|
||||
|
||||
nltk.download("punkt")
|
||||
rouge_evaluator = rouge.Rouge(
|
||||
metrics=["rouge-n", "rouge-l"],
|
||||
max_n=2,
|
||||
limit_length=True,
|
||||
length_limit=args.beam_size,
|
||||
length_limit_type="words",
|
||||
apply_avg=True,
|
||||
apply_best=False,
|
||||
alpha=0.5, # Default F1_score
|
||||
weight_factor=1.2,
|
||||
stemming=True,
|
||||
)
|
||||
|
||||
# these (unused) arguments are defined to keep the compatibility
|
||||
# with the legacy code and will be deleted in a next iteration.
|
||||
args.result_path = ""
|
||||
args.temp_dir = ""
|
||||
|
||||
data_iterator = build_data_iterator(args, tokenizer)
|
||||
predictor = build_predictor(args, tokenizer, symbols, model)
|
||||
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Number examples = %d", len(data_iterator.dataset))
|
||||
logger.info(" Batch size = %d", args.batch_size)
|
||||
logger.info("")
|
||||
logger.info("***** Beam Search parameters *****")
|
||||
logger.info(" Beam size = %d", args.beam_size)
|
||||
logger.info(" Minimum length = %d", args.min_length)
|
||||
logger.info(" Maximum length = %d", args.max_length)
|
||||
logger.info(" Alpha (length penalty) = %.2f", args.alpha)
|
||||
logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))
|
||||
|
||||
for batch in tqdm(data_iterator):
|
||||
batch_data = predictor.translate_batch(batch)
|
||||
translations = predictor.from_batch(batch_data)
|
||||
summaries = [format_summary(t) for t in translations]
|
||||
save_summaries(summaries, args.summaries_output_dir, batch.document_names)
|
||||
|
||||
if args.compute_rouge:
|
||||
reference_summaries += batch.tgt_str
|
||||
generated_summaries += summaries
|
||||
|
||||
if args.compute_rouge:
|
||||
scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
|
||||
str_scores = format_rouge_scores(scores)
|
||||
save_rouge_scores(str_scores)
|
||||
print(str_scores)
|
||||
|
||||
|
||||
def save_summaries(summaries, path, original_document_name):
|
||||
"""Write the summaries in fies that are prefixed by the original
|
||||
files' name with the `_summary` appended.
|
||||
|
||||
Attributes:
|
||||
original_document_names: List[string]
|
||||
Name of the document that was summarized.
|
||||
path: string
|
||||
Path were the summaries will be written
|
||||
summaries: List[string]
|
||||
The summaries that we produced.
|
||||
"""
|
||||
for summary, document_name in zip(summaries, original_document_name):
|
||||
# Prepare the summary file's name
|
||||
if "." in document_name:
|
||||
bare_document_name = ".".join(document_name.split(".")[:-1])
|
||||
extension = document_name.split(".")[-1]
|
||||
name = bare_document_name + "_summary." + extension
|
||||
else:
|
||||
name = document_name + "_summary"
|
||||
|
||||
file_path = os.path.join(path, name)
|
||||
with open(file_path, "w") as output:
|
||||
output.write(summary)
|
||||
|
||||
|
||||
def format_summary(translation):
|
||||
"""Transforms the output of the `from_batch` function
|
||||
into nicely formatted summaries.
|
||||
"""
|
||||
raw_summary, _, _ = translation
|
||||
summary = (
|
||||
raw_summary.replace("[unused0]", "")
|
||||
.replace("[unused3]", "")
|
||||
.replace("[PAD]", "")
|
||||
.replace("[unused1]", "")
|
||||
.replace(r" +", " ")
|
||||
.replace(" [unused2] ", ". ")
|
||||
.replace("[unused2]", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def format_rouge_scores(scores):
|
||||
return """\n
|
||||
****** ROUGE SCORES ******
|
||||
|
||||
** ROUGE 1
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}
|
||||
|
||||
** ROUGE 2
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}
|
||||
|
||||
** ROUGE L
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}""".format(
|
||||
scores["rouge-1"]["f"],
|
||||
scores["rouge-1"]["p"],
|
||||
scores["rouge-1"]["r"],
|
||||
scores["rouge-2"]["f"],
|
||||
scores["rouge-2"]["p"],
|
||||
scores["rouge-2"]["r"],
|
||||
scores["rouge-l"]["f"],
|
||||
scores["rouge-l"]["p"],
|
||||
scores["rouge-l"]["r"],
|
||||
)
|
||||
|
||||
|
||||
def save_rouge_scores(str_scores):
|
||||
with open("rouge_scores.txt", "w") as output:
|
||||
output.write(str_scores)
|
||||
|
||||
|
||||
#
|
||||
# LOAD the dataset
|
||||
#
|
||||
|
||||
|
||||
def build_data_iterator(args, tokenizer):
|
||||
dataset = load_and_cache_examples(args, tokenizer)
|
||||
sampler = SequentialSampler(dataset)
|
||||
|
||||
def collate_fn(data):
|
||||
return collate(data, tokenizer, block_size=512, device=args.device)
|
||||
|
||||
iterator = DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
return iterator
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer):
|
||||
dataset = CNNDMDataset(args.documents_dir)
|
||||
return dataset
|
||||
|
||||
|
||||
def collate(data, tokenizer, block_size, device):
|
||||
"""Collate formats the data passed to the data loader.
|
||||
|
||||
In particular we tokenize the data batch after batch to avoid keeping them
|
||||
all in memory. We output the data as a namedtuple to fit the original BertAbs's
|
||||
API.
|
||||
"""
|
||||
data = [x for x in data if not len(x[1]) == 0] # remove empty_files
|
||||
names = [name for name, _, _ in data]
|
||||
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
||||
|
||||
encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
|
||||
encoded_stories = torch.tensor(
|
||||
[truncate_or_pad(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
|
||||
)
|
||||
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
||||
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
||||
|
||||
batch = Batch(
|
||||
document_names=names,
|
||||
batch_size=len(encoded_stories),
|
||||
src=encoded_stories.to(device),
|
||||
segs=encoder_token_type_ids.to(device),
|
||||
mask_src=encoder_mask.to(device),
|
||||
tgt_str=summaries,
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def decode_summary(summary_tokens, tokenizer):
|
||||
"""Decode the summary and return it in a format
|
||||
suitable for evaluation.
|
||||
"""
|
||||
summary_tokens = summary_tokens.to("cpu").numpy()
|
||||
summary = tokenizer.decode(summary_tokens)
|
||||
sentences = summary.split(".")
|
||||
sentences = [s + "." for s in sentences]
|
||||
return sentences
|
||||
|
||||
|
||||
def main():
|
||||
"""The main function defines the interface with the users."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--documents_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The folder where the documents to summarize are located.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summaries_output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compute_rouge",
|
||||
default=False,
|
||||
type=bool,
|
||||
required=False,
|
||||
help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.",
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--no_cuda",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
# BEAM SEARCH arguments
|
||||
parser.add_argument(
|
||||
"--min_length",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Minimum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
default=200,
|
||||
type=int,
|
||||
help="Maixmum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=5,
|
||||
type=int,
|
||||
help="The number of beams to start with for each example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
default=0.95,
|
||||
type=float,
|
||||
help="The value of alpha for the length penalty in the beam search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_trigram",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Select device (distibuted not available)
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
|
||||
# Check the existence of directories
|
||||
if not args.summaries_output_dir:
|
||||
args.summaries_output_dir = args.documents_dir
|
||||
|
||||
if not documents_dir_is_valid(args.documents_dir):
|
||||
raise FileNotFoundError(
|
||||
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
|
||||
)
|
||||
os.makedirs(args.summaries_output_dir, exist_ok=True)
|
||||
|
||||
evaluate(args)
|
||||
|
||||
|
||||
def documents_dir_is_valid(path):
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
|
||||
file_list = os.listdir(path)
|
||||
if len(file_list) == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,98 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .utils_summarization import build_mask, compute_token_type_ids, process_story, truncate_or_pad
|
||||
|
||||
|
||||
class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.block_size = 10
|
||||
|
||||
def test_fit_to_block_sequence_too_small(self):
|
||||
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
sequence = [1, 2, 3, 4]
|
||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_fit_exactly(self):
|
||||
""" Do nothing if the sequence is the right size. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_too_big(self):
|
||||
""" Truncate the sequence if it is too long. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
"""Processing a story with no highlights returns an empty list for the summary."""
|
||||
raw_story = """It was the year of Our Lord one thousand seven hundred and
|
||||
seventy-five.\n\nSpiritual revelations were conceded to England at that
|
||||
favoured period, as at this."""
|
||||
_, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_empty_story(self):
|
||||
"""An empty story returns an empty collection of lines."""
|
||||
raw_story = ""
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(story_lines, [])
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_story_with_missing_period(self):
|
||||
raw_story = (
|
||||
"It was the year of Our Lord one thousand seven hundred and "
|
||||
"seventy-five\n\nSpiritual revelations were conceded to England "
|
||||
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
|
||||
)
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
|
||||
expected_story_lines = [
|
||||
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
|
||||
"Spiritual revelations were conceded to England at that favoured period, as at this.",
|
||||
]
|
||||
self.assertEqual(expected_story_lines, story_lines)
|
||||
|
||||
expected_summary_lines = ["It was the best of times."]
|
||||
self.assertEqual(expected_summary_lines, summary_lines)
|
||||
|
||||
def test_build_mask_no_padding(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4])
|
||||
expected = torch.tensor([1, 1, 1, 1])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 23).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask_with_padding_equal_to_one(self):
|
||||
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
|
||||
|
||||
def test_compute_token_type_ids(self):
|
||||
separator = 101
|
||||
batch = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]])
|
||||
expected = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]])
|
||||
|
||||
result = compute_token_type_ids(batch, separator)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
167
examples/research_projects/bertabs/utils_summarization.py
Normal file
167
examples/research_projects/bertabs/utils_summarization.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# ------------
|
||||
# Data loading
|
||||
# ------------
|
||||
|
||||
|
||||
class CNNDMDataset(Dataset):
|
||||
"""Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
The class will process the documents that are located in the specified
|
||||
folder. The preprocessing will work on any document that is reasonably
|
||||
formatted. On the CNN/DailyMail dataset it will extract both the story
|
||||
and the summary.
|
||||
|
||||
CNN/Daily News:
|
||||
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
|
||||
stored in different files; the summary appears at the end of the story as
|
||||
sentences that are prefixed by the special `@highlight` line. To process
|
||||
the data, untar both datasets in the same folder, and pass the path to this
|
||||
folder as the "data_dir argument. The formatting code was inspired by [2].
|
||||
|
||||
[1] https://cs.nyu.edu/~kcho/
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
|
||||
def __init__(self, path="", prefix="train"):
|
||||
"""We initialize the class by listing all the documents to summarize.
|
||||
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
|
||||
"""
|
||||
assert os.path.isdir(path)
|
||||
|
||||
self.documents = []
|
||||
story_filenames_list = os.listdir(path)
|
||||
for story_filename in story_filenames_list:
|
||||
if "summary" in story_filename:
|
||||
continue
|
||||
path_to_story = os.path.join(path, story_filename)
|
||||
if not os.path.isfile(path_to_story):
|
||||
continue
|
||||
self.documents.append(path_to_story)
|
||||
|
||||
def __len__(self):
|
||||
""" Returns the number of documents. """
|
||||
return len(self.documents)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
document_path = self.documents[idx]
|
||||
document_name = document_path.split("/")[-1]
|
||||
with open(document_path, encoding="utf-8") as source:
|
||||
raw_story = source.read()
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
return document_name, story_lines, summary_lines
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
"""Extract the story and summary from a story file.
|
||||
|
||||
Arguments:
|
||||
raw_story (str): content of the story file as an utf-8 encoded string.
|
||||
|
||||
Raises:
|
||||
IndexError: If the story is empty or contains no highlights.
|
||||
"""
|
||||
nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
|
||||
|
||||
# for some unknown reason some lines miss a period, add it
|
||||
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
||||
|
||||
# gather article lines
|
||||
story_lines = []
|
||||
lines = deque(nonempty_lines)
|
||||
while True:
|
||||
try:
|
||||
element = lines.popleft()
|
||||
if element.startswith("@highlight"):
|
||||
break
|
||||
story_lines.append(element)
|
||||
except IndexError:
|
||||
# if "@highlight" is absent from the file we pop
|
||||
# all elements until there is None, raising an exception.
|
||||
return story_lines, []
|
||||
|
||||
# gather summary lines
|
||||
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
|
||||
|
||||
return story_lines, summary_lines
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u2019", ")"]
|
||||
if line.startswith("@highlight"):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + "."
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Encoding and preprocessing
|
||||
# --------------------------
|
||||
|
||||
|
||||
def truncate_or_pad(sequence, block_size, pad_token_id):
|
||||
"""Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter we append padding token to the right of the sequence.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
sequence.extend([pad_token_id] * (block_size - len(sequence)))
|
||||
return sequence
|
||||
|
||||
|
||||
def build_mask(sequence, pad_token_id):
|
||||
"""Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1."""
|
||||
mask = torch.ones_like(sequence)
|
||||
idx_pad_tokens = sequence == pad_token_id
|
||||
mask[idx_pad_tokens] = 0
|
||||
return mask
|
||||
|
||||
|
||||
def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
"""Encode the story and summary lines, and join them
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
||||
story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
|
||||
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
||||
summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]
|
||||
|
||||
return story_token_ids, summary_token_ids
|
||||
|
||||
|
||||
def compute_token_type_ids(batch, separator_token_id):
|
||||
"""Segment embeddings as described in [1]
|
||||
|
||||
The values {0,1} were found in the repository [2].
|
||||
|
||||
Attributes:
|
||||
batch: torch.Tensor, size [batch_size, block_size]
|
||||
Batch of input.
|
||||
separator_token_id: int
|
||||
The value of the token that separates the segments.
|
||||
|
||||
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
|
||||
arXiv preprint arXiv:1908.08345 (2019).
|
||||
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
|
||||
"""
|
||||
batch_embeddings = []
|
||||
for sequence in batch:
|
||||
sentence_num = -1
|
||||
embeddings = []
|
||||
for s in sequence:
|
||||
if s == separator_token_id:
|
||||
sentence_num += 1
|
||||
embeddings.append(sentence_num % 2)
|
||||
batch_embeddings.append(embeddings)
|
||||
return torch.tensor(batch_embeddings)
|
||||
1
examples/research_projects/bertology/requirements.txt
Normal file
1
examples/research_projects/bertology/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
transformers == 3.5.1
|
||||
449
examples/research_projects/bertology/run_bertology.py
Normal file
449
examples/research_projects/bertology/run_bertology.py
Normal file
@@ -0,0 +1,449 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2018 CMU and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Bertology: this script shows how you can explore the internals of the models in the library to:
|
||||
- compute the entropy of the head attentions
|
||||
- compute the importance of each head
|
||||
- prune (remove) the low importance head.
|
||||
Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, SequentialSampler, Subset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
GlueDataset,
|
||||
default_data_collator,
|
||||
glue_compute_metrics,
|
||||
glue_output_modes,
|
||||
glue_processors,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def entropy(p):
|
||||
""" Compute the entropy of a probability distribution """
|
||||
plogp = p * torch.log(p)
|
||||
plogp[p == 0] = 0
|
||||
return -plogp.sum(dim=-1)
|
||||
|
||||
|
||||
def print_2d_tensor(tensor):
|
||||
""" Print a 2D tensor """
|
||||
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
||||
for row in range(len(tensor)):
|
||||
if tensor.dtype != torch.long:
|
||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data))
|
||||
else:
|
||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))
|
||||
|
||||
|
||||
def compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
|
||||
):
|
||||
"""This method shows how to compute:
|
||||
- head attention entropy
|
||||
- head importance scores according to http://arxiv.org/abs/1905.10650
|
||||
"""
|
||||
# Prepare our tensors
|
||||
n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
||||
head_importance = torch.zeros(n_layers, n_heads).to(args.device)
|
||||
attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
|
||||
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(n_layers, n_heads).to(args.device)
|
||||
|
||||
head_mask.requires_grad_(requires_grad=True)
|
||||
# If actually pruned attention multi-head, set head mask to None to avoid shape mismatch
|
||||
if actually_pruned:
|
||||
head_mask = None
|
||||
|
||||
preds = None
|
||||
labels = None
|
||||
tot_tokens = 0.0
|
||||
|
||||
for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = v.to(args.device)
|
||||
|
||||
# Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
|
||||
outputs = model(**inputs, head_mask=head_mask)
|
||||
loss, logits, all_attentions = (
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
outputs[-1],
|
||||
) # Loss and logits are the first, attention the last
|
||||
loss.backward() # Backpropagate to populate the gradients in the head mask
|
||||
|
||||
if compute_entropy:
|
||||
for layer, attn in enumerate(all_attentions):
|
||||
masked_entropy = entropy(attn.detach()) * inputs["attention_mask"].float().unsqueeze(1)
|
||||
attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()
|
||||
|
||||
if compute_importance:
|
||||
head_importance += head_mask.grad.abs().detach()
|
||||
|
||||
# Also store our logits/labels if we want to compute metrics afterwards
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
labels = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
labels = np.append(labels, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
tot_tokens += inputs["attention_mask"].float().detach().sum().data
|
||||
|
||||
# Normalize
|
||||
attn_entropy /= tot_tokens
|
||||
head_importance /= tot_tokens
|
||||
# Layerwise importance normalization
|
||||
if not args.dont_normalize_importance_by_layer:
|
||||
exponent = 2
|
||||
norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
|
||||
head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
|
||||
|
||||
if not args.dont_normalize_global_importance:
|
||||
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
|
||||
|
||||
# Print/save matrices
|
||||
np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
|
||||
np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())
|
||||
|
||||
logger.info("Attention entropies")
|
||||
print_2d_tensor(attn_entropy)
|
||||
logger.info("Head importance scores")
|
||||
print_2d_tensor(head_importance)
|
||||
logger.info("Head ranked by importance scores")
|
||||
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
||||
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
|
||||
head_importance.numel(), device=args.device
|
||||
)
|
||||
head_ranks = head_ranks.view_as(head_importance)
|
||||
print_2d_tensor(head_ranks)
|
||||
|
||||
return attn_entropy, head_importance, preds, labels
|
||||
|
||||
|
||||
def mask_heads(args, model, eval_dataloader):
|
||||
"""This method shows how to mask head (set some heads to zero), to test the effect on the network,
|
||||
based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
original_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
|
||||
|
||||
new_head_mask = torch.ones_like(head_importance)
|
||||
num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))
|
||||
|
||||
current_score = original_score
|
||||
while current_score >= original_score * args.masking_threshold:
|
||||
head_mask = new_head_mask.clone() # save current head mask
|
||||
# heads from least important to most - keep only not-masked heads
|
||||
head_importance[head_mask == 0.0] = float("Inf")
|
||||
current_heads_to_mask = head_importance.view(-1).sort()[1]
|
||||
|
||||
if len(current_heads_to_mask) <= num_to_mask:
|
||||
break
|
||||
|
||||
# mask heads
|
||||
current_heads_to_mask = current_heads_to_mask[:num_to_mask]
|
||||
logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
|
||||
new_head_mask = new_head_mask.view(-1)
|
||||
new_head_mask[current_heads_to_mask] = 0.0
|
||||
new_head_mask = new_head_mask.view_as(head_mask)
|
||||
new_head_mask = new_head_mask.clone().detach()
|
||||
print_2d_tensor(new_head_mask)
|
||||
|
||||
# Compute metric and head importance again
|
||||
_, head_importance, preds, labels = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
logger.info(
|
||||
"Masking: current score: %f, remaining heads %d (%.1f percents)",
|
||||
current_score,
|
||||
new_head_mask.sum(),
|
||||
new_head_mask.sum() / new_head_mask.numel() * 100,
|
||||
)
|
||||
|
||||
logger.info("Final head mask")
|
||||
print_2d_tensor(head_mask)
|
||||
np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())
|
||||
|
||||
return head_mask
|
||||
|
||||
|
||||
def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
"""This method shows how to prune head (remove heads weights) based on
|
||||
the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
# Try pruning and test time speedup
|
||||
# Pruning is like masking but we actually remove the masked weights
|
||||
before_time = datetime.now()
|
||||
_, _, preds, labels = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
score_masking = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
original_time = datetime.now() - before_time
|
||||
|
||||
original_num_params = sum(p.numel() for p in model.parameters())
|
||||
heads_to_prune = dict(
|
||||
(layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
|
||||
)
|
||||
|
||||
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
||||
model.prune_heads(heads_to_prune)
|
||||
pruned_num_params = sum(p.numel() for p in model.parameters())
|
||||
|
||||
before_time = datetime.now()
|
||||
_, _, preds, labels = compute_heads_importance(
|
||||
args,
|
||||
model,
|
||||
eval_dataloader,
|
||||
compute_entropy=False,
|
||||
compute_importance=False,
|
||||
head_mask=None,
|
||||
actually_pruned=True,
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
new_time = datetime.now() - before_time
|
||||
|
||||
logger.info(
|
||||
"Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
|
||||
original_num_params,
|
||||
pruned_num_params,
|
||||
pruned_num_params / original_num_params * 100,
|
||||
)
|
||||
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
||||
logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time / new_time * 100)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(glue_processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained config name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dont_normalize_global_importance",
|
||||
action="store_true",
|
||||
help="Don't normalize all importance scores between 0 and 1",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_threshold",
|
||||
default=0.9,
|
||||
type=float,
|
||||
help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
|
||||
)
|
||||
parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")
|
||||
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, sequences shorter padded.",
|
||||
)
|
||||
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup devices and distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
else:
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
args.device = torch.device("cuda", args.local_rank)
|
||||
args.n_gpu = 1
|
||||
torch.distributed.init_process_group(backend="nccl") # Initializes the distributed backend
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Set seeds
|
||||
set_seed(args.seed)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in glue_processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = glue_processors[args.task_name]()
|
||||
args.output_mode = glue_output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
output_attentions=True,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
# Distributed and parallel training
|
||||
model.to(args.device)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
elif args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Print/save training arguments
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Prepare dataset for the GLUE task
|
||||
eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
|
||||
if args.data_subset > 0:
|
||||
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
|
||||
)
|
||||
|
||||
# Compute head entropy and importance score
|
||||
compute_heads_importance(args, model, eval_dataloader)
|
||||
|
||||
# Try head masking (set heads to zero until the score goes under a threshole)
|
||||
# and head pruning (remove masked heads and see the effect on the network)
|
||||
if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
|
||||
head_mask = mask_heads(args, model, eval_dataloader)
|
||||
prune_heads(args, model, eval_dataloader, head_mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
54
examples/research_projects/deebert/README.md
Normal file
54
examples/research_projects/deebert/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# DeeBERT: Early Exiting for *BERT
|
||||
|
||||
This is the code base for the paper [DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference](https://www.aclweb.org/anthology/2020.acl-main.204/), modified from its [original code base](https://github.com/castorini/deebert).
|
||||
|
||||
The original code base also has information for downloading sample models that we have trained in advance.
|
||||
|
||||
## Usage
|
||||
|
||||
There are three scripts in the folder which can be run directly.
|
||||
|
||||
In each script, there are several things to modify before running:
|
||||
|
||||
* `PATH_TO_DATA`: path to the GLUE dataset.
|
||||
* `--output_dir`: path for saving fine-tuned models. Default: `./saved_models`.
|
||||
* `--plot_data_dir`: path for saving evaluation results. Default: `./results`. Results are printed to stdout and also saved to `npy` files in this directory to facilitate plotting figures and further analyses.
|
||||
* `MODEL_TYPE`: bert or roberta
|
||||
* `MODEL_SIZE`: base or large
|
||||
* `DATASET`: SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
#### train_deebert.sh
|
||||
|
||||
This is for fine-tuning DeeBERT models.
|
||||
|
||||
#### eval_deebert.sh
|
||||
|
||||
This is for evaluating each exit layer for fine-tuned DeeBERT models.
|
||||
|
||||
#### entropy_eval.sh
|
||||
|
||||
This is for evaluating fine-tuned DeeBERT models, given a number of different early exit entropy thresholds.
|
||||
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
Please cite our paper if you find the resource useful:
|
||||
```
|
||||
@inproceedings{xin-etal-2020-deebert,
|
||||
title = "{D}ee{BERT}: Dynamic Early Exiting for Accelerating {BERT} Inference",
|
||||
author = "Xin, Ji and
|
||||
Tang, Raphael and
|
||||
Lee, Jaejun and
|
||||
Yu, Yaoliang and
|
||||
Lin, Jimmy",
|
||||
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
|
||||
month = jul,
|
||||
year = "2020",
|
||||
address = "Online",
|
||||
publisher = "Association for Computational Linguistics",
|
||||
url = "https://www.aclweb.org/anthology/2020.acl-main.204",
|
||||
pages = "2246--2251",
|
||||
}
|
||||
```
|
||||
|
||||
33
examples/research_projects/deebert/entropy_eval.sh
Executable file
33
examples/research_projects/deebert/entropy_eval.sh
Executable file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
PATH_TO_DATA=/h/xinji/projects/GLUE
|
||||
|
||||
MODEL_TYPE=bert # bert or roberta
|
||||
MODEL_SIZE=base # base or large
|
||||
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
|
||||
if [ $MODEL_TYPE = 'bert' ]
|
||||
then
|
||||
MODEL_NAME=${MODEL_NAME}-uncased
|
||||
fi
|
||||
|
||||
ENTROPIES="0 0.1 0.2 0.3 0.4 0.5 0.6 0.7"
|
||||
|
||||
for ENTROPY in $ENTROPIES; do
|
||||
python -u run_glue_deebert.py \
|
||||
--model_type $MODEL_TYPE \
|
||||
--model_name_or_path ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--task_name $DATASET \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $PATH_TO_DATA/$DATASET \
|
||||
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--plot_data_dir ./results/ \
|
||||
--max_seq_length 128 \
|
||||
--early_exit_entropy $ENTROPY \
|
||||
--eval_highway \
|
||||
--overwrite_cache \
|
||||
--per_gpu_eval_batch_size=1
|
||||
done
|
||||
30
examples/research_projects/deebert/eval_deebert.sh
Executable file
30
examples/research_projects/deebert/eval_deebert.sh
Executable file
@@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
PATH_TO_DATA=/h/xinji/projects/GLUE
|
||||
|
||||
MODEL_TYPE=bert # bert or roberta
|
||||
MODEL_SIZE=base # base or large
|
||||
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
|
||||
if [ $MODEL_TYPE = 'bert' ]
|
||||
then
|
||||
MODEL_NAME=${MODEL_NAME}-uncased
|
||||
fi
|
||||
|
||||
|
||||
python -u run_glue_deebert.py \
|
||||
--model_type $MODEL_TYPE \
|
||||
--model_name_or_path ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--task_name $DATASET \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $PATH_TO_DATA/$DATASET \
|
||||
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--plot_data_dir ./results/ \
|
||||
--max_seq_length 128 \
|
||||
--eval_each_highway \
|
||||
--eval_highway \
|
||||
--overwrite_cache \
|
||||
--per_gpu_eval_batch_size=1
|
||||
1
examples/research_projects/deebert/requirements.txt
Normal file
1
examples/research_projects/deebert/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
transformers == 3.5.1
|
||||
730
examples/research_projects/deebert/run_glue_deebert.py
Normal file
730
examples/research_projects/deebert/run_glue_deebert.py
Normal file
@@ -0,0 +1,730 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import transformers
|
||||
from src.modeling_highway_bert import DeeBertForSequenceClassification
|
||||
from src.modeling_highway_roberta import DeeRobertaForSequenceClassification
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import glue_compute_metrics as compute_metrics
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
from transformers import glue_output_modes as output_modes
|
||||
from transformers import glue_processors as processors
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, DeeBertForSequenceClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, DeeRobertaForSequenceClassification, RobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def get_wanted_result(result):
|
||||
if "spearmanr" in result:
|
||||
print_result = result["spearmanr"]
|
||||
elif "f1" in result:
|
||||
print_result = result["f1"]
|
||||
elif "mcc" in result:
|
||||
print_result = result["mcc"]
|
||||
elif "acc" in result:
|
||||
print_result = result["acc"]
|
||||
else:
|
||||
raise ValueError("Primary metric unclear in the results")
|
||||
return print_result
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, train_highway=False):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
if train_highway:
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if ("highway" in n) and (not any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in model.named_parameters() if ("highway" in n) and (any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
else:
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if ("highway" not in n) and (not any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if ("highway" not in n) and (any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
inputs["train_highway"] = train_highway
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix="", output_layer=-1, eval_highway=False):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
exit_layer_counter = {(i + 1): 0 for i in range(model.num_layers)}
|
||||
st = time.time()
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
if output_layer >= 0:
|
||||
inputs["output_layer"] = output_layer
|
||||
outputs = model(**inputs)
|
||||
if eval_highway:
|
||||
exit_layer_counter[outputs[-1]] += 1
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
eval_time = time.time() - st
|
||||
logger.info("Eval time: {}".format(eval_time))
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif args.output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
|
||||
if eval_highway:
|
||||
logger.info("Exit layer counter: {}".format(exit_layer_counter))
|
||||
actual_cost = sum([l * c for l, c in exit_layer_counter.items()])
|
||||
full_cost = len(eval_dataloader) * model.num_layers
|
||||
logger.info("Expected saving: {}".format(actual_cost / full_cost))
|
||||
if args.early_exit_entropy >= 0:
|
||||
save_fname = (
|
||||
args.plot_data_dir
|
||||
+ "/"
|
||||
+ args.model_name_or_path[2:]
|
||||
+ "/entropy_{}.npy".format(args.early_exit_entropy)
|
||||
)
|
||||
if not os.path.exists(os.path.dirname(save_fname)):
|
||||
os.makedirs(os.path.dirname(save_fname))
|
||||
print_result = get_wanted_result(result)
|
||||
np.save(save_fname, np.array([exit_layer_counter, eval_time, actual_cost / full_cost, print_result]))
|
||||
logger.info("Entropy={}\tResult={:.2f}".format(args.early_exit_entropy, 100 * print_result))
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
processor = processors[task]()
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta"]:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = (
|
||||
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
label_list=label_list,
|
||||
max_length=args.max_seq_length,
|
||||
output_mode=output_mode,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
|
||||
if features[0].token_type_ids is None:
|
||||
# For RoBERTa (a potential bug!)
|
||||
all_token_type_ids = torch.tensor([[0] * args.max_seq_length for f in features], dtype=torch.long)
|
||||
else:
|
||||
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
if output_mode == "classification":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot_data_dir",
|
||||
default="./plotting/",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The directory to store data for plotting figures.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
parser.add_argument("--eval_each_highway", action="store_true", help="Set this flag to evaluate each highway.")
|
||||
parser.add_argument(
|
||||
"--eval_after_first_stage",
|
||||
action="store_true",
|
||||
help="Set this flag to evaluate after training only bert (not highway).",
|
||||
)
|
||||
parser.add_argument("--eval_highway", action="store_true", help="Set this flag if it's evaluating highway models")
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--early_exit_entropy", default=-1, type=float, help="Entropy threshold for early exit.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name]()
|
||||
args.output_mode = output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.model_type == "bert":
|
||||
model.bert.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
model.bert.init_highway_pooler()
|
||||
elif args.model_type == "roberta":
|
||||
model.roberta.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
model.roberta.init_highway_pooler()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
if args.eval_after_first_stage:
|
||||
result = evaluate(args, model, tokenizer, prefix="")
|
||||
print_result = get_wanted_result(result)
|
||||
|
||||
train(args, train_dataset, model, tokenizer, train_highway=True)
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
if args.model_type == "bert":
|
||||
model.bert.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
elif args.model_type == "roberta":
|
||||
model.roberta.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix, eval_highway=args.eval_highway)
|
||||
print_result = get_wanted_result(result)
|
||||
logger.info("Result: {}".format(print_result))
|
||||
if args.eval_each_highway:
|
||||
last_layer_results = print_result
|
||||
each_layer_results = []
|
||||
for i in range(model.num_layers):
|
||||
logger.info("\n")
|
||||
_result = evaluate(
|
||||
args, model, tokenizer, prefix=prefix, output_layer=i, eval_highway=args.eval_highway
|
||||
)
|
||||
if i + 1 < model.num_layers:
|
||||
each_layer_results.append(get_wanted_result(_result))
|
||||
each_layer_results.append(last_layer_results)
|
||||
save_fname = args.plot_data_dir + "/" + args.model_name_or_path[2:] + "/each_layer.npy"
|
||||
if not os.path.exists(os.path.dirname(save_fname)):
|
||||
os.makedirs(os.path.dirname(save_fname))
|
||||
np.save(save_fname, np.array(each_layer_results))
|
||||
info_str = "Score of each layer:"
|
||||
for i in range(model.num_layers):
|
||||
info_str += " {:.2f}".format(100 * each_layer_results[i])
|
||||
logger.info(info_str)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
examples/research_projects/deebert/src/__init__.py
Normal file
0
examples/research_projects/deebert/src/__init__.py
Normal file
396
examples/research_projects/deebert/src/modeling_highway_bert.py
Normal file
396
examples/research_projects/deebert/src/modeling_highway_bert.py
Normal file
@@ -0,0 +1,396 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
BERT_START_DOCSTRING,
|
||||
BertEmbeddings,
|
||||
BertLayer,
|
||||
BertPooler,
|
||||
BertPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
def entropy(x):
|
||||
"""Calculate entropy of a pre-softmax logit Tensor"""
|
||||
exp_x = torch.exp(x)
|
||||
A = torch.sum(exp_x, dim=1) # sum of exp(x_i)
|
||||
B = torch.sum(x * exp_x, dim=1) # sum of x_i * exp(x_i)
|
||||
return torch.log(A) - B / A
|
||||
|
||||
|
||||
class DeeBertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.highway = nn.ModuleList([BertHighway(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
self.early_exit_entropy = [-1 for _ in range(config.num_hidden_layers)]
|
||||
|
||||
def set_early_exit_entropy(self, x):
|
||||
if (type(x) is float) or (type(x) is int):
|
||||
for i in range(len(self.early_exit_entropy)):
|
||||
self.early_exit_entropy[i] = x
|
||||
else:
|
||||
self.early_exit_entropy = x
|
||||
|
||||
def init_highway_pooler(self, pooler):
|
||||
loaded_model = pooler.state_dict()
|
||||
for highway in self.highway:
|
||||
for name, param in highway.pooler.state_dict().items():
|
||||
param.copy_(loaded_model[name])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
all_highway_exits = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
current_outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
current_outputs = current_outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
current_outputs = current_outputs + (all_attentions,)
|
||||
|
||||
highway_exit = self.highway[i](current_outputs)
|
||||
# logits, pooled_output
|
||||
|
||||
if not self.training:
|
||||
highway_logits = highway_exit[0]
|
||||
highway_entropy = entropy(highway_logits)
|
||||
highway_exit = highway_exit + (highway_entropy,) # logits, hidden_states(?), entropy
|
||||
all_highway_exits = all_highway_exits + (highway_exit,)
|
||||
|
||||
if highway_entropy < self.early_exit_entropy[i]:
|
||||
new_output = (highway_logits,) + current_outputs[1:] + (all_highway_exits,)
|
||||
raise HighwayException(new_output, i + 1)
|
||||
else:
|
||||
all_highway_exits = all_highway_exits + (highway_exit,)
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
|
||||
outputs = outputs + (all_highway_exits,)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions), all highway exits
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The Bert Model transformer with early exiting (DeeBERT). ",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class DeeBertModel(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.encoder = DeeBertEncoder(config)
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_highway_pooler(self):
|
||||
self.encoder.init_highway_pooler(self.pooler)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during pre-training.
|
||||
|
||||
This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
|
||||
Tuple of each early exit's results (total length: number of layers)
|
||||
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
|
||||
"""
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(input_shape, device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_attention_mask.dim() == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
if encoder_attention_mask.dim() == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
||||
1:
|
||||
] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions), highway exits
|
||||
|
||||
|
||||
class HighwayException(Exception):
|
||||
def __init__(self, message, exit_layer):
|
||||
self.message = message
|
||||
self.exit_layer = exit_layer # start from 1!
|
||||
|
||||
|
||||
class BertHighway(nn.Module):
|
||||
"""A module to provide a shortcut
|
||||
from (the output of one non-final BertLayer in BertEncoder) to (cross-entropy computation in BertForSequenceClassification)
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.pooler = BertPooler(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
def forward(self, encoder_outputs):
|
||||
# Pooler
|
||||
pooler_input = encoder_outputs[0]
|
||||
pooler_output = self.pooler(pooler_input)
|
||||
# "return" pooler_output
|
||||
|
||||
# BertModel
|
||||
bmodel_output = (pooler_input, pooler_output) + encoder_outputs[1:]
|
||||
# "return" bmodel_output
|
||||
|
||||
# Dropout and classification
|
||||
pooled_output = bmodel_output[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
return logits, pooled_output
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bert Model (with early exiting - DeeBERT) with a classifier on top,
|
||||
also takes care of multi-layer training. """,
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class DeeBertForSequenceClassification(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.bert = DeeBertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_layer=-1,
|
||||
train_highway=False,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
|
||||
Tuple of each early exit's results (total length: number of layers)
|
||||
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
|
||||
"""
|
||||
|
||||
exit_layer = self.num_layers
|
||||
try:
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
# sequence_output, pooled_output, (hidden_states), (attentions), highway exits
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
except HighwayException as e:
|
||||
outputs = e.message
|
||||
exit_layer = e.exit_layer
|
||||
logits = outputs[0]
|
||||
|
||||
if not self.training:
|
||||
original_entropy = entropy(logits)
|
||||
highway_entropy = []
|
||||
highway_logits_all = []
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
# work with highway exits
|
||||
highway_losses = []
|
||||
for highway_exit in outputs[-1]:
|
||||
highway_logits = highway_exit[0]
|
||||
if not self.training:
|
||||
highway_logits_all.append(highway_logits)
|
||||
highway_entropy.append(highway_exit[2])
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
highway_losses.append(highway_loss)
|
||||
|
||||
if train_highway:
|
||||
outputs = (sum(highway_losses[:-1]),) + outputs
|
||||
# exclude the final highway, of course
|
||||
else:
|
||||
outputs = (loss,) + outputs
|
||||
if not self.training:
|
||||
outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
|
||||
if output_layer >= 0:
|
||||
outputs = (
|
||||
(outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
|
||||
) # use the highway of the last layer
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions), (highway_exits)
|
||||
@@ -0,0 +1,156 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers import RobertaConfig
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from transformers.models.roberta.modeling_roberta import (
|
||||
ROBERTA_INPUTS_DOCSTRING,
|
||||
ROBERTA_START_DOCSTRING,
|
||||
RobertaEmbeddings,
|
||||
)
|
||||
|
||||
from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayException, entropy
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The RoBERTa Model transformer with early exiting (DeeRoBERTa). ",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaModel(DeeBertModel):
|
||||
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.embeddings = RobertaEmbeddings(config)
|
||||
self.init_weights()
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""RoBERTa Model (with early exiting - DeeRoBERTa) with a classifier on top,
|
||||
also takes care of multi-layer training. """,
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.roberta = DeeRobertaModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_layer=-1,
|
||||
train_highway=False,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
|
||||
Tuple of each early exit's results (total length: number of layers)
|
||||
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
|
||||
"""
|
||||
|
||||
exit_layer = self.num_layers
|
||||
try:
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
except HighwayException as e:
|
||||
outputs = e.message
|
||||
exit_layer = e.exit_layer
|
||||
logits = outputs[0]
|
||||
|
||||
if not self.training:
|
||||
original_entropy = entropy(logits)
|
||||
highway_entropy = []
|
||||
highway_logits_all = []
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
# work with highway exits
|
||||
highway_losses = []
|
||||
for highway_exit in outputs[-1]:
|
||||
highway_logits = highway_exit[0]
|
||||
if not self.training:
|
||||
highway_logits_all.append(highway_logits)
|
||||
highway_entropy.append(highway_exit[2])
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
highway_losses.append(highway_loss)
|
||||
|
||||
if train_highway:
|
||||
outputs = (sum(highway_losses[:-1]),) + outputs
|
||||
# exclude the final highway, of course
|
||||
else:
|
||||
outputs = (loss,) + outputs
|
||||
if not self.training:
|
||||
outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
|
||||
if output_layer >= 0:
|
||||
outputs = (
|
||||
(outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
|
||||
) # use the highway of the last layer
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions), entropy
|
||||
99
examples/research_projects/deebert/test_glue_deebert.py
Normal file
99
examples/research_projects/deebert/test_glue_deebert.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_glue_deebert
|
||||
from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me, slow
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_setup_file():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f")
|
||||
args = parser.parse_args()
|
||||
return args.f
|
||||
|
||||
|
||||
class DeeBertTests(unittest.TestCase):
|
||||
def setup(self) -> None:
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
@slow
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_glue_deebert_train(self):
|
||||
|
||||
train_args = """
|
||||
run_glue_deebert.py
|
||||
--model_type roberta
|
||||
--model_name_or_path roberta-base
|
||||
--task_name MRPC
|
||||
--do_train
|
||||
--do_eval
|
||||
--do_lower_case
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--max_seq_length 128
|
||||
--per_gpu_eval_batch_size=1
|
||||
--per_gpu_train_batch_size=8
|
||||
--learning_rate 2e-4
|
||||
--num_train_epochs 3
|
||||
--overwrite_output_dir
|
||||
--seed 42
|
||||
--output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--plot_data_dir ./examples/deebert/results/
|
||||
--save_steps 0
|
||||
--overwrite_cache
|
||||
--eval_after_first_stage
|
||||
""".split()
|
||||
with patch.object(sys, "argv", train_args):
|
||||
result = run_glue_deebert.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.666)
|
||||
|
||||
eval_args = """
|
||||
run_glue_deebert.py
|
||||
--model_type roberta
|
||||
--model_name_or_path ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--task_name MRPC
|
||||
--do_eval
|
||||
--do_lower_case
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--plot_data_dir ./examples/deebert/results/
|
||||
--max_seq_length 128
|
||||
--eval_each_highway
|
||||
--eval_highway
|
||||
--overwrite_cache
|
||||
--per_gpu_eval_batch_size=1
|
||||
""".split()
|
||||
with patch.object(sys, "argv", eval_args):
|
||||
result = run_glue_deebert.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.666)
|
||||
|
||||
entropy_eval_args = """
|
||||
run_glue_deebert.py
|
||||
--model_type roberta
|
||||
--model_name_or_path ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--task_name MRPC
|
||||
--do_eval
|
||||
--do_lower_case
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--plot_data_dir ./examples/deebert/results/
|
||||
--max_seq_length 128
|
||||
--early_exit_entropy 0.1
|
||||
--eval_highway
|
||||
--overwrite_cache
|
||||
--per_gpu_eval_batch_size=1
|
||||
""".split()
|
||||
with patch.object(sys, "argv", entropy_eval_args):
|
||||
result = run_glue_deebert.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.666)
|
||||
38
examples/research_projects/deebert/train_deebert.sh
Executable file
38
examples/research_projects/deebert/train_deebert.sh
Executable file
@@ -0,0 +1,38 @@
|
||||
#!/bin/bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
PATH_TO_DATA=/h/xinji/projects/GLUE
|
||||
|
||||
MODEL_TYPE=bert # bert or roberta
|
||||
MODEL_SIZE=base # base or large
|
||||
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
|
||||
EPOCHS=10
|
||||
if [ $MODEL_TYPE = 'bert' ]
|
||||
then
|
||||
EPOCHS=3
|
||||
MODEL_NAME=${MODEL_NAME}-uncased
|
||||
fi
|
||||
|
||||
|
||||
python -u run_glue_deebert.py \
|
||||
--model_type $MODEL_TYPE \
|
||||
--model_name_or_path $MODEL_NAME \
|
||||
--task_name $DATASET \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $PATH_TO_DATA/$DATASET \
|
||||
--max_seq_length 128 \
|
||||
--per_gpu_eval_batch_size=1 \
|
||||
--per_gpu_train_batch_size=8 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs $EPOCHS \
|
||||
--overwrite_output_dir \
|
||||
--seed 42 \
|
||||
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--plot_data_dir ./results/ \
|
||||
--save_steps 0 \
|
||||
--overwrite_cache \
|
||||
--eval_after_first_stage
|
||||
193
examples/research_projects/distillation/README.md
Normal file
193
examples/research_projects/distillation/README.md
Normal file
@@ -0,0 +1,193 @@
|
||||
# Distil*
|
||||
|
||||
Author: @VictorSanh
|
||||
|
||||
This folder contains the original code used to train Distil* as well as examples showcasing how to use DistilBERT, DistilRoBERTa and DistilGPT2.
|
||||
|
||||
**January 20, 2020 - Bug fixing** We have recently discovered and fixed [a bug](https://github.com/huggingface/transformers/commit/48cbf267c988b56c71a2380f748a3e6092ccaed3) in the evaluation of our `run_*.py` scripts that caused the reported metrics to be over-estimated on average. We have updated all the metrics with the latest runs.
|
||||
|
||||
**December 6, 2019 - Update** We release **DistilmBERT**: 92% of `bert-base-multilingual-cased` on XNLI. The model supports 104 different languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
|
||||
|
||||
**November 19, 2019 - Update** We release German **DistilBERT**: 98.8% of `bert-base-german-dbmdz-cased` on NER tasks.
|
||||
|
||||
**October 23, 2019 - Update** We release **DistilRoBERTa**: 95% of `RoBERTa-base`'s performance on GLUE, twice as fast as RoBERTa while being 35% smaller.
|
||||
|
||||
**October 3, 2019 - Update** We release our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108) explaining our approach on **DistilBERT**. It includes updated results and further experiments. We applied the same method to GPT2 and release the weights of **DistilGPT2**. DistilGPT2 is two times faster and 33% smaller than GPT2. **The paper supersedes our [previous blogpost](https://medium.com/huggingface/distilbert-8cf3380435b5) with a different distillation loss and better performances. Please use the paper as a reference when comparing/reporting results on DistilBERT.**
|
||||
|
||||
**September 19, 2019 - Update:** We fixed bugs in the code and released an updated version of the weights trained with a modification of the distillation loss. DistilBERT now reaches 99% of `BERT-base`'s performance on GLUE, and 86.9 F1 score on SQuAD v1.1 dev set (compared to 88.5 for `BERT-base`). We will publish a formal write-up of our approach in the near future!
|
||||
|
||||
|
||||
## What is Distil*
|
||||
|
||||
Distil* is a class of compressed models that started with DistilBERT. DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on Bert architecture. It has 40% less parameters than `bert-base-uncased`, runs 60% faster while preserving 97% of BERT's performances as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique to compress a large model called the teacher into a smaller model called the student. By distillating Bert, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option to put large-scaled trained Transformer model into production.
|
||||
|
||||
We have applied the same method to other Transformer architectures and released the weights:
|
||||
- GPT2: on the [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) benchmark, GPT2 reaches a perplexity on the test set of 16.3 compared to 21.1 for **DistilGPT2** (after fine-tuning on the train set).
|
||||
- RoBERTa: **DistilRoBERTa** reaches 95% of `RoBERTa-base`'s performance on GLUE while being twice faster and 35% smaller.
|
||||
- German BERT: **German DistilBERT** reaches 99% of `bert-base-german-dbmdz-cased`'s performance on German NER (CoNLL-2003).
|
||||
- Multilingual BERT: **DistilmBERT** reaches 92% of Multilingual BERT's performance on XNLI while being twice faster and 25% smaller. The model supports 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
|
||||
|
||||
For more information on DistilBERT, please refer to our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108).
|
||||
|
||||
Here are the results on the dev sets of GLUE:
|
||||
|
||||
| Model | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2| STS-B| WNLI |
|
||||
| :---: | :---: | :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---: |
|
||||
| BERT-base-uncased | **79.5** | 56.3 | 84.7 | 88.6 | 91.8 | 89.6 | 69.3 | 92.7 | 89.0 | 53.5 |
|
||||
| DistilBERT-base-uncased | **77.0** | 51.3 | 82.1 | 87.5 | 89.2 | 88.5 | 59.9 | 91.3 | 86.9 | 56.3 |
|
||||
| BERT-base-cased | **78.2** | 58.2 | 83.9 | 87.8 | 91.0 | 89.2 | 66.1 | 91.7 | 89.2 | 46.5 |
|
||||
| DistilBERT-base-cased | **75.9** | 47.2 | 81.5 | 85.6 | 88.2 | 87.8 | 60.6 | 90.4 | 85.5 | 56.3 |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| RoBERTa-base (reported) | **83.2**/**86.4**<sup>2</sup> | 63.6 | 87.6 | 90.2 | 92.8 | 91.9 | 78.7 | 94.8 | 91.2 | 57.7<sup>3</sup> |
|
||||
| DistilRoBERTa<sup>1</sup> | **79.0**/**82.3**<sup>2</sup> | 59.3 | 84.0 | 86.6 | 90.8 | 89.4 | 67.9 | 92.5 | 88.3 | 52.1 |
|
||||
|
||||
<sup>1</sup> We did not use the MNLI checkpoint for fine-tuning but directly perform transfer learning on the pre-trained DistilRoBERTa.
|
||||
|
||||
<sup>2</sup> Macro-score computed without WNLI.
|
||||
|
||||
<sup>3</sup> We compute this score ourselves for completeness.
|
||||
|
||||
Here are the results on the *test* sets for 6 of the languages available in XNLI. The results are computed in the zero shot setting (trained on the English portion and evaluated on the target language portion):
|
||||
|
||||
| Model | English | Spanish | Chinese | German | Arabic | Urdu |
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
|
||||
| mBERT base cased (computed) | 82.1 | 74.6 | 69.1 | 72.3 | 66.4 | 58.5 |
|
||||
| mBERT base uncased (reported)| 81.4 | 74.3 | 63.8 | 70.5 | 62.1 | 58.3 |
|
||||
| DistilmBERT | 78.2 | 69.1 | 64.0 | 66.3 | 59.1 | 54.7 |
|
||||
|
||||
## Setup
|
||||
|
||||
This part of the library has only be tested with Python3.6+. There are few specific dependencies to install before launching a distillation, you can install them with the command `pip install -r requirements.txt`.
|
||||
|
||||
**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0).
|
||||
|
||||
|
||||
## How to use DistilBERT
|
||||
|
||||
Transformers includes five pre-trained Distil* models, currently only provided for English and German (we are investigating the possibility to train and release a multilingual version of DistilBERT):
|
||||
|
||||
- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain Bert (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of Bert. The model has 6 layers, 768 dimension and 12 heads, totalizing 66M parameters.
|
||||
- `distilbert-base-uncased-distilled-squad`: A finetuned version of `distilbert-base-uncased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches a F1 score of 86.9 on the dev set (for comparison, Bert `bert-base-uncased` version reaches a 88.5 F1 score).
|
||||
- `distilbert-base-cased`: DistilBERT English language model pretrained on the same data used to pretrain Bert (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-cased` version of Bert. The model has 6 layers, 768 dimension and 12 heads, totalizing 65M parameters.
|
||||
- `distilbert-base-cased-distilled-squad`: A finetuned version of `distilbert-base-cased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches a F1 score of 87.1 on the dev set (for comparison, Bert `bert-base-cased` version reaches a 88.7 F1 score).
|
||||
- `distilbert-base-german-cased`: DistilBERT German language model pretrained on 1/2 of the data used to pretrain Bert using distillation with the supervision of the `bert-base-german-dbmdz-cased` version of German DBMDZ Bert. For NER tasks the model reaches a F1 score of 83.49 on the CoNLL-2003 test set (for comparison, `bert-base-german-dbmdz-cased` reaches a 84.52 F1 score), and a F1 score of 85.23 on the GermEval 2014 test set (`bert-base-german-dbmdz-cased` reaches a 86.89 F1 score).
|
||||
- `distilgpt2`: DistilGPT2 English language model pretrained with the supervision of `gpt2` (the smallest version of GPT2) on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset. The model has 6 layers, 768 dimension and 12 heads, totalizing 82M parameters (compared to 124M parameters for GPT2). On average, DistilGPT2 is two times faster than GPT2.
|
||||
- `distilroberta-base`: DistilRoBERTa English language model pretrained with the supervision of `roberta-base` solely on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset (it is ~4 times less training data than the teacher RoBERTa). The model has 6 layers, 768 dimension and 12 heads, totalizing 82M parameters (compared to 125M parameters for RoBERTa-base). On average DistilRoBERTa is twice as fast as Roberta-base.
|
||||
- `distilbert-base-multilingual-cased`: DistilmBERT multilingual model pretrained with the supervision of `bert-base-multilingual-cased` on the concatenation of Wikipedia in 104 different languages. The model supports the 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages). The model has 6 layers, 768 dimension and 12 heads, totalizing 134M parameters (compared to 177M parameters for mBERT-base). On average DistilmBERT is twice as fast as mBERT-base.
|
||||
|
||||
Using DistilBERT is very similar to using BERT. DistilBERT share the same tokenizer as BERT's `bert-base-uncased` even though we provide a link to this tokenizer under the `DistilBertTokenizer` name to have a consistent naming between the library models.
|
||||
|
||||
```python
|
||||
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
||||
model = DistilBertModel.from_pretrained('distilbert-base-cased')
|
||||
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
```
|
||||
|
||||
Similarly, using the other Distil* models simply consists in calling the base classes with a different pretrained checkpoint:
|
||||
- DistilBERT uncased: `model = DistilBertModel.from_pretrained('distilbert-base-uncased')`
|
||||
- DistilGPT2: `model = GPT2Model.from_pretrained('distilgpt2')`
|
||||
- DistilRoBERTa: `model = RobertaModel.from_pretrained('distilroberta-base')`
|
||||
- DistilmBERT: `model = DistilBertModel.from_pretrained('distilbert-base-multilingual-cased')`
|
||||
|
||||
|
||||
## How to train Distil*
|
||||
|
||||
In the following, we will explain how you can train DistilBERT.
|
||||
|
||||
### A. Preparing the data
|
||||
|
||||
The weights we release are trained using a concatenation of Toronto Book Corpus and English Wikipedia (same training data as the English version of BERT).
|
||||
|
||||
To avoid processing the data several time, we do it once and for all before the training. From now on, will suppose that you have a text file `dump.txt` which contains one sequence per line (a sequence being composed of one of several coherent sentences).
|
||||
|
||||
First, we will binarize the data, i.e. tokenize the data and convert each token in an index in our model's vocabulary.
|
||||
|
||||
```bash
|
||||
python scripts/binarized_data.py \
|
||||
--file_path data/dump.txt \
|
||||
--tokenizer_type bert \
|
||||
--tokenizer_name bert-base-uncased \
|
||||
--dump_file data/binarized_text
|
||||
```
|
||||
|
||||
Our implementation of masked language modeling loss follows [XLM](https://github.com/facebookresearch/XLM)'s one and smooths the probability of masking with a factor that put more emphasis on rare words. Thus we count the occurrences of each tokens in the data:
|
||||
|
||||
```bash
|
||||
python scripts/token_counts.py \
|
||||
--data_file data/binarized_text.bert-base-uncased.pickle \
|
||||
--token_counts_dump data/token_counts.bert-base-uncased.pickle \
|
||||
--vocab_size 30522
|
||||
```
|
||||
|
||||
### B. Training
|
||||
|
||||
Training with distillation is really simple once you have pre-processed the data:
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--student_type distilbert \
|
||||
--student_config training_configs/distilbert-base-uncased.json \
|
||||
--teacher_type bert \
|
||||
--teacher_name bert-base-uncased \
|
||||
--alpha_ce 5.0 --alpha_mlm 2.0 --alpha_cos 1.0 --alpha_clm 0.0 --mlm \
|
||||
--freeze_pos_embs \
|
||||
--dump_path serialization_dir/my_first_training \
|
||||
--data_file data/binarized_text.bert-base-uncased.pickle \
|
||||
--token_counts data/token_counts.bert-base-uncased.pickle \
|
||||
--force # overwrites the `dump_path` if it already exists.
|
||||
```
|
||||
|
||||
By default, this will launch a training on a single GPU (even if more are available on the cluster). Other parameters are available in the command line, please look in `train.py` or run `python train.py --help` to list them.
|
||||
|
||||
We highly encourage you to use distributed training for training DistilBERT as the training corpus is quite large. Here's an example that runs a distributed training on a single node having 4 GPUs:
|
||||
|
||||
```bash
|
||||
export NODE_RANK=0
|
||||
export N_NODES=1
|
||||
|
||||
export N_GPU_NODE=4
|
||||
export WORLD_SIZE=4
|
||||
export MASTER_PORT=<AN_OPEN_PORT>
|
||||
export MASTER_ADDR=<I.P.>
|
||||
|
||||
pkill -f 'python -u train.py'
|
||||
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node=$N_GPU_NODE \
|
||||
--nnodes=$N_NODES \
|
||||
--node_rank $NODE_RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
train.py \
|
||||
--force \
|
||||
--gpus $WORLD_SIZE \
|
||||
--student_type distilbert \
|
||||
--student_config training_configs/distilbert-base-uncased.json \
|
||||
--teacher_type bert \
|
||||
--teacher_name bert-base-uncased \
|
||||
--alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --mlm \
|
||||
--freeze_pos_embs \
|
||||
--dump_path serialization_dir/my_first_training \
|
||||
--data_file data/binarized_text.bert-base-uncased.pickle \
|
||||
--token_counts data/token_counts.bert-base-uncased.pickle
|
||||
```
|
||||
|
||||
**Tips:** Starting distilled training with good initialization of the model weights is crucial to reach decent performance. In our experiments, we initialized our model from a few layers of the teacher (Bert) itself! Please refer to `scripts/extract.py` and `scripts/extract_distilbert.py` to create a valid initialization checkpoint and use `--student_pretrained_weights` argument to use this initialization for the distilled training!
|
||||
|
||||
Happy distillation!
|
||||
|
||||
## Citation
|
||||
|
||||
If you find the resource useful, you should cite the following paper:
|
||||
|
||||
```
|
||||
@inproceedings{sanh2019distilbert,
|
||||
title={DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter},
|
||||
author={Sanh, Victor and Debut, Lysandre and Chaumond, Julien and Wolf, Thomas},
|
||||
booktitle={NeurIPS EMC^2 Workshop},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
603
examples/research_projects/distillation/distiller.py
Normal file
603
examples/research_projects/distillation/distiller.py
Normal file
@@ -0,0 +1,603 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" The distiller to distil the student.
|
||||
Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
||||
from lm_seqs_dataset import LmSeqsDataset
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
from utils import logger
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
class Distiller:
|
||||
def __init__(
|
||||
self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
|
||||
):
|
||||
logger.info("Initializing Distiller")
|
||||
self.params = params
|
||||
self.dump_path = params.dump_path
|
||||
self.multi_gpu = params.multi_gpu
|
||||
self.fp16 = params.fp16
|
||||
|
||||
self.student = student
|
||||
self.teacher = teacher
|
||||
|
||||
self.student_config = student.config
|
||||
self.vocab_size = student.config.vocab_size
|
||||
|
||||
if params.n_gpu <= 1:
|
||||
sampler = RandomSampler(dataset)
|
||||
else:
|
||||
sampler = DistributedSampler(dataset)
|
||||
|
||||
if params.group_by_size:
|
||||
groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
|
||||
sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
|
||||
else:
|
||||
sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
|
||||
|
||||
self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)
|
||||
|
||||
self.temperature = params.temperature
|
||||
assert self.temperature > 0.0
|
||||
|
||||
self.alpha_ce = params.alpha_ce
|
||||
self.alpha_mlm = params.alpha_mlm
|
||||
self.alpha_clm = params.alpha_clm
|
||||
self.alpha_mse = params.alpha_mse
|
||||
self.alpha_cos = params.alpha_cos
|
||||
|
||||
self.mlm = params.mlm
|
||||
if self.mlm:
|
||||
logger.info("Using MLM loss for LM step.")
|
||||
self.mlm_mask_prop = params.mlm_mask_prop
|
||||
assert 0.0 <= self.mlm_mask_prop <= 1.0
|
||||
assert params.word_mask + params.word_keep + params.word_rand == 1.0
|
||||
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
|
||||
self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
|
||||
self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
|
||||
if self.fp16:
|
||||
self.pred_probs = self.pred_probs.half()
|
||||
self.token_probs = self.token_probs.half()
|
||||
else:
|
||||
logger.info("Using CLM loss for LM step.")
|
||||
|
||||
self.epoch = 0
|
||||
self.n_iter = 0
|
||||
self.n_total_iter = 0
|
||||
self.n_sequences_epoch = 0
|
||||
self.total_loss_epoch = 0
|
||||
self.last_loss = 0
|
||||
self.last_loss_ce = 0
|
||||
self.last_loss_mlm = 0
|
||||
self.last_loss_clm = 0
|
||||
if self.alpha_mse > 0.0:
|
||||
self.last_loss_mse = 0
|
||||
if self.alpha_cos > 0.0:
|
||||
self.last_loss_cos = 0
|
||||
self.last_log = 0
|
||||
|
||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
||||
if self.alpha_mse > 0.0:
|
||||
self.mse_loss_fct = nn.MSELoss(reduction="sum")
|
||||
if self.alpha_cos > 0.0:
|
||||
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
|
||||
|
||||
logger.info("--- Initializing model optimizer")
|
||||
assert params.gradient_accumulation_steps >= 1
|
||||
self.num_steps_epoch = len(self.dataloader)
|
||||
num_train_optimization_steps = (
|
||||
int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
|
||||
)
|
||||
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
|
||||
],
|
||||
"weight_decay": params.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
logger.info(
|
||||
"------ Number of trainable parameters (student): %i"
|
||||
% sum([p.numel() for p in self.student.parameters() if p.requires_grad])
|
||||
)
|
||||
logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
|
||||
self.optimizer = AdamW(
|
||||
optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
|
||||
)
|
||||
|
||||
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
|
||||
)
|
||||
|
||||
if self.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
|
||||
self.student, self.optimizer = amp.initialize(
|
||||
self.student, self.optimizer, opt_level=self.params.fp16_opt_level
|
||||
)
|
||||
self.teacher = self.teacher.half()
|
||||
|
||||
if self.multi_gpu:
|
||||
if self.fp16:
|
||||
from apex.parallel import DistributedDataParallel
|
||||
|
||||
logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
|
||||
self.student = DistributedDataParallel(self.student)
|
||||
else:
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
|
||||
self.student = DistributedDataParallel(
|
||||
self.student,
|
||||
device_ids=[params.local_rank],
|
||||
output_device=params.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
self.is_master = params.is_master
|
||||
if self.is_master:
|
||||
logger.info("--- Initializing Tensorboard")
|
||||
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
|
||||
self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
|
||||
self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)
|
||||
|
||||
def prepare_batch_mlm(self, batch):
|
||||
"""
|
||||
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.
|
||||
|
||||
Input:
|
||||
------
|
||||
batch: `Tuple`
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
|
||||
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
|
||||
|
||||
Output:
|
||||
-------
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||
mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
|
||||
"""
|
||||
token_ids, lengths = batch
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
assert token_ids.size(0) == lengths.size(0)
|
||||
|
||||
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
|
||||
|
||||
bs, max_seq_len = token_ids.size()
|
||||
mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
||||
|
||||
x_prob = self.token_probs[token_ids.flatten()]
|
||||
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
|
||||
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
|
||||
pred_mask = torch.zeros(
|
||||
bs * max_seq_len, dtype=torch.bool, device=token_ids.device
|
||||
) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
|
||||
pred_mask[tgt_ids] = 1
|
||||
pred_mask = pred_mask.view(bs, max_seq_len)
|
||||
|
||||
pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0
|
||||
|
||||
# mask a number of words == 0 [8] (faster with fp16)
|
||||
if self.fp16:
|
||||
n1 = pred_mask.sum().item()
|
||||
if n1 > 8:
|
||||
pred_mask = pred_mask.view(-1)
|
||||
n2 = max(n1 % 8, 8 * (n1 // 8))
|
||||
if n2 != n1:
|
||||
pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
|
||||
pred_mask = pred_mask.view(bs, max_seq_len)
|
||||
assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
|
||||
|
||||
_token_ids_real = token_ids[pred_mask]
|
||||
_token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
|
||||
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
|
||||
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
|
||||
_token_ids = (
|
||||
_token_ids_mask * (probs == 0).long()
|
||||
+ _token_ids_real * (probs == 1).long()
|
||||
+ _token_ids_rand * (probs == 2).long()
|
||||
)
|
||||
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
|
||||
|
||||
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
|
||||
# sanity checks
|
||||
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
|
||||
|
||||
return token_ids, attn_mask, mlm_labels
|
||||
|
||||
def prepare_batch_clm(self, batch):
|
||||
"""
|
||||
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
|
||||
|
||||
Input:
|
||||
------
|
||||
batch: `Tuple`
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
|
||||
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
|
||||
|
||||
Output:
|
||||
-------
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||
clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
|
||||
"""
|
||||
token_ids, lengths = batch
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
assert token_ids.size(0) == lengths.size(0)
|
||||
|
||||
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
|
||||
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
||||
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
|
||||
# sanity checks
|
||||
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
|
||||
|
||||
return token_ids, attn_mask, clm_labels
|
||||
|
||||
def round_batch(self, x: torch.tensor, lengths: torch.tensor):
|
||||
"""
|
||||
For float16 only.
|
||||
Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
|
||||
|
||||
Input:
|
||||
------
|
||||
x: `torch.tensor(bs, seq_length)` - The token ids.
|
||||
lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch.
|
||||
|
||||
Output:
|
||||
-------
|
||||
x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
|
||||
lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths.
|
||||
"""
|
||||
if not self.fp16 or len(lengths) < 8:
|
||||
return x, lengths
|
||||
|
||||
# number of sentences == 0 [8]
|
||||
bs1 = len(lengths)
|
||||
bs2 = 8 * (bs1 // 8)
|
||||
assert bs2 > 0 and bs2 % 8 == 0
|
||||
if bs1 != bs2:
|
||||
idx = torch.randperm(bs1)[:bs2]
|
||||
lengths = lengths[idx]
|
||||
slen = lengths.max().item()
|
||||
x = x[idx, :slen]
|
||||
else:
|
||||
idx = None
|
||||
|
||||
# sequence length == 0 [8]
|
||||
ml1 = x.size(1)
|
||||
if ml1 % 8 != 0:
|
||||
pad = 8 - (ml1 % 8)
|
||||
ml2 = ml1 + pad
|
||||
if self.mlm:
|
||||
pad_id = self.params.special_tok_ids["pad_token"]
|
||||
else:
|
||||
pad_id = self.params.special_tok_ids["unk_token"]
|
||||
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
|
||||
x = torch.cat([x, padding_tensor], 1)
|
||||
assert x.size() == (bs2, ml2)
|
||||
|
||||
assert x.size(0) % 8 == 0
|
||||
assert x.size(1) % 8 == 0
|
||||
return x, lengths
|
||||
|
||||
def train(self):
|
||||
"""
|
||||
The real training loop.
|
||||
"""
|
||||
if self.is_master:
|
||||
logger.info("Starting training")
|
||||
self.last_log = time.time()
|
||||
self.student.train()
|
||||
self.teacher.eval()
|
||||
|
||||
for _ in range(self.params.n_epoch):
|
||||
if self.is_master:
|
||||
logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
|
||||
if self.multi_gpu:
|
||||
torch.distributed.barrier()
|
||||
|
||||
iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
|
||||
for batch in iter_bar:
|
||||
if self.params.n_gpu > 0:
|
||||
batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)
|
||||
|
||||
if self.mlm:
|
||||
token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
|
||||
else:
|
||||
token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
|
||||
self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
|
||||
|
||||
iter_bar.update()
|
||||
iter_bar.set_postfix(
|
||||
{"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
|
||||
)
|
||||
iter_bar.close()
|
||||
|
||||
if self.is_master:
|
||||
logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
|
||||
self.end_epoch()
|
||||
|
||||
if self.is_master:
|
||||
logger.info("Save very last checkpoint as `pytorch_model.bin`.")
|
||||
self.save_checkpoint(checkpoint_name="pytorch_model.bin")
|
||||
logger.info("Training is finished")
|
||||
|
||||
def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
|
||||
"""
|
||||
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
|
||||
and possibly a parameter update (depending on the gradient accumulation).
|
||||
|
||||
Input:
|
||||
------
|
||||
input_ids: `torch.tensor(bs, seq_length)` - The token ids.
|
||||
attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
|
||||
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
|
||||
"""
|
||||
if self.mlm:
|
||||
s_logits, s_hidden_states = self.student(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
) # (bs, seq_length, voc_size)
|
||||
with torch.no_grad():
|
||||
t_logits, t_hidden_states = self.teacher(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
) # (bs, seq_length, voc_size)
|
||||
else:
|
||||
s_logits, _, s_hidden_states = self.student(
|
||||
input_ids=input_ids, attention_mask=None
|
||||
) # (bs, seq_length, voc_size)
|
||||
with torch.no_grad():
|
||||
t_logits, _, t_hidden_states = self.teacher(
|
||||
input_ids=input_ids, attention_mask=None
|
||||
) # (bs, seq_length, voc_size)
|
||||
assert s_logits.size() == t_logits.size()
|
||||
|
||||
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
||||
# https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
|
||||
if self.params.restrict_ce_to_mask:
|
||||
mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
|
||||
else:
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
|
||||
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
assert t_logits_slct.size() == s_logits_slct.size()
|
||||
|
||||
loss_ce = (
|
||||
self.ce_loss_fct(
|
||||
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||
)
|
||||
* (self.temperature) ** 2
|
||||
)
|
||||
loss = self.alpha_ce * loss_ce
|
||||
|
||||
if self.alpha_mlm > 0.0:
|
||||
loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
|
||||
loss += self.alpha_mlm * loss_mlm
|
||||
if self.alpha_clm > 0.0:
|
||||
shift_logits = s_logits[..., :-1, :].contiguous()
|
||||
shift_labels = lm_labels[..., 1:].contiguous()
|
||||
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
loss += self.alpha_clm * loss_clm
|
||||
|
||||
if self.alpha_mse > 0.0:
|
||||
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
|
||||
0
|
||||
) # Reproducing batchmean reduction
|
||||
loss += self.alpha_mse * loss_mse
|
||||
if self.alpha_cos > 0.0:
|
||||
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
|
||||
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
|
||||
assert s_hidden_states.size() == t_hidden_states.size()
|
||||
dim = s_hidden_states.size(-1)
|
||||
|
||||
s_hidden_states_slct = torch.masked_select(s_hidden_states, mask) # (bs * seq_length * dim)
|
||||
s_hidden_states_slct = s_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
|
||||
t_hidden_states_slct = torch.masked_select(t_hidden_states, mask) # (bs * seq_length * dim)
|
||||
t_hidden_states_slct = t_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
|
||||
|
||||
target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
|
||||
loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
|
||||
loss += self.alpha_cos * loss_cos
|
||||
|
||||
self.total_loss_epoch += loss.item()
|
||||
self.last_loss = loss.item()
|
||||
self.last_loss_ce = loss_ce.item()
|
||||
if self.alpha_mlm > 0.0:
|
||||
self.last_loss_mlm = loss_mlm.item()
|
||||
if self.alpha_clm > 0.0:
|
||||
self.last_loss_clm = loss_clm.item()
|
||||
if self.alpha_mse > 0.0:
|
||||
self.last_loss_mse = loss_mse.item()
|
||||
if self.alpha_cos > 0.0:
|
||||
self.last_loss_cos = loss_cos.item()
|
||||
|
||||
self.optimize(loss)
|
||||
|
||||
self.n_sequences_epoch += input_ids.size(0)
|
||||
|
||||
def optimize(self, loss):
|
||||
"""
|
||||
Normalization on the loss (gradient accumulation or distributed training), followed by
|
||||
backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
|
||||
Also update the metrics for tensorboard.
|
||||
"""
|
||||
# Check for NaN
|
||||
if (loss != loss).data.any():
|
||||
logger.error("NaN detected")
|
||||
exit()
|
||||
|
||||
if self.multi_gpu:
|
||||
loss = loss.mean()
|
||||
if self.params.gradient_accumulation_steps > 1:
|
||||
loss = loss / self.params.gradient_accumulation_steps
|
||||
|
||||
if self.fp16:
|
||||
from apex import amp
|
||||
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
self.iter()
|
||||
if self.n_iter % self.params.gradient_accumulation_steps == 0:
|
||||
if self.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.scheduler.step()
|
||||
|
||||
def iter(self):
|
||||
"""
|
||||
Update global counts, write to tensorboard and save checkpoint.
|
||||
"""
|
||||
self.n_iter += 1
|
||||
self.n_total_iter += 1
|
||||
|
||||
if self.n_total_iter % self.params.log_interval == 0:
|
||||
self.log_tensorboard()
|
||||
self.last_log = time.time()
|
||||
if self.n_total_iter % self.params.checkpoint_interval == 0:
|
||||
self.save_checkpoint()
|
||||
|
||||
def log_tensorboard(self):
|
||||
"""
|
||||
Log into tensorboard. Only by the master process.
|
||||
"""
|
||||
if not self.is_master:
|
||||
return
|
||||
|
||||
for param_name, param in self.student.named_parameters():
|
||||
self.tensorboard.add_scalar(
|
||||
tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
|
||||
)
|
||||
if param.grad is None:
|
||||
continue
|
||||
self.tensorboard.add_scalar(
|
||||
tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
|
||||
)
|
||||
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/cum_avg_loss_epoch",
|
||||
scalar_value=self.total_loss_epoch / self.n_iter,
|
||||
global_step=self.n_total_iter,
|
||||
)
|
||||
self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_mlm > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_clm > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_mse > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_cos > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
|
||||
)
|
||||
|
||||
self.tensorboard.add_scalar(
|
||||
tag="global/memory_usage",
|
||||
scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
|
||||
global_step=self.n_total_iter,
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
|
||||
)
|
||||
|
||||
def end_epoch(self):
|
||||
"""
|
||||
Finally arrived at the end of epoch (full pass on dataset).
|
||||
Do some tensorboard logging and checkpoint saving.
|
||||
"""
|
||||
logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")
|
||||
|
||||
if self.is_master:
|
||||
self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
|
||||
self.tensorboard.add_scalar(
|
||||
tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
|
||||
)
|
||||
|
||||
self.epoch += 1
|
||||
self.n_sequences_epoch = 0
|
||||
self.n_iter = 0
|
||||
self.total_loss_epoch = 0
|
||||
|
||||
def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
|
||||
"""
|
||||
Save the current state. Only by the master process.
|
||||
"""
|
||||
if not self.is_master:
|
||||
return
|
||||
mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
|
||||
mdl_to_save.config.save_pretrained(self.dump_path)
|
||||
state_dict = mdl_to_save.state_dict()
|
||||
torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
|
||||
108
examples/research_projects/distillation/grouped_batch_sampler.py
Normal file
108
examples/research_projects/distillation/grouped_batch_sampler.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Adapted from PyTorch Vision (https://github.com/pytorch/vision/blob/master/references/detection/group_by_aspect_ratio.py)
|
||||
"""
|
||||
import bisect
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data.sampler import BatchSampler, Sampler
|
||||
|
||||
from utils import logger
|
||||
|
||||
|
||||
def _quantize(x, bins):
|
||||
bins = copy.deepcopy(bins)
|
||||
bins = sorted(bins)
|
||||
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
|
||||
return quantized
|
||||
|
||||
|
||||
def create_lengths_groups(lengths, k=0):
|
||||
bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
|
||||
groups = _quantize(lengths, bins)
|
||||
# count number of elements per group
|
||||
counts = np.unique(groups, return_counts=True)[1]
|
||||
fbins = [0] + bins + [np.inf]
|
||||
logger.info("Using {} as bins for aspect lengths quantization".format(fbins))
|
||||
logger.info("Count of instances per bin: {}".format(counts))
|
||||
return groups
|
||||
|
||||
|
||||
class GroupedBatchSampler(BatchSampler):
|
||||
"""
|
||||
Wraps another sampler to yield a mini-batch of indices.
|
||||
It enforces that the batch only contain elements from the same group.
|
||||
It also tries to provide mini-batches which follows an ordering which is
|
||||
as close as possible to the ordering from the original sampler.
|
||||
Arguments:
|
||||
sampler (Sampler): Base sampler.
|
||||
group_ids (list[int]): If the sampler produces indices in range [0, N),
|
||||
`group_ids` must be a list of `N` ints which contains the group id of each sample.
|
||||
The group ids must be a continuous set of integers starting from
|
||||
0, i.e. they must be in the range [0, num_groups).
|
||||
batch_size (int): Size of mini-batch.
|
||||
"""
|
||||
|
||||
def __init__(self, sampler, group_ids, batch_size):
|
||||
if not isinstance(sampler, Sampler):
|
||||
raise ValueError(
|
||||
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
||||
)
|
||||
self.sampler = sampler
|
||||
self.group_ids = group_ids
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __iter__(self):
|
||||
buffer_per_group = defaultdict(list)
|
||||
samples_per_group = defaultdict(list)
|
||||
|
||||
num_batches = 0
|
||||
for idx in self.sampler:
|
||||
group_id = self.group_ids[idx]
|
||||
buffer_per_group[group_id].append(idx)
|
||||
samples_per_group[group_id].append(idx)
|
||||
if len(buffer_per_group[group_id]) == self.batch_size:
|
||||
yield buffer_per_group[group_id] # TODO
|
||||
num_batches += 1
|
||||
del buffer_per_group[group_id]
|
||||
assert len(buffer_per_group[group_id]) < self.batch_size
|
||||
|
||||
# now we have run out of elements that satisfy
|
||||
# the group criteria, let's return the remaining
|
||||
# elements so that the size of the sampler is
|
||||
# deterministic
|
||||
expected_num_batches = len(self)
|
||||
num_remaining = expected_num_batches - num_batches
|
||||
if num_remaining > 0:
|
||||
# for the remaining batches, group the batches by similar lengths
|
||||
batch_idx = []
|
||||
for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
|
||||
batch_idx.extend(idxs)
|
||||
if len(batch_idx) >= self.batch_size:
|
||||
yield batch_idx[: self.batch_size]
|
||||
batch_idx = batch_idx[self.batch_size :]
|
||||
num_remaining -= 1
|
||||
if len(batch_idx) > 0:
|
||||
yield batch_idx
|
||||
num_remaining -= 1
|
||||
assert num_remaining == 0
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return the number of mini-batches rather than the number of samples.
|
||||
"""
|
||||
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
||||
166
examples/research_projects/distillation/lm_seqs_dataset.py
Normal file
166
examples/research_projects/distillation/lm_seqs_dataset.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Dataset to distilled models
|
||||
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from utils import logger
|
||||
|
||||
|
||||
class LmSeqsDataset(Dataset):
|
||||
"""Custom Dataset wrapping language modeling sequences.
|
||||
|
||||
Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.
|
||||
|
||||
Input:
|
||||
------
|
||||
params: `NameSpace` parameters
|
||||
data: `List[np.array[int]]
|
||||
"""
|
||||
|
||||
def __init__(self, params, data):
|
||||
self.params = params
|
||||
|
||||
self.token_ids = np.array(data)
|
||||
self.lengths = np.array([len(t) for t in data])
|
||||
|
||||
self.check()
|
||||
self.remove_long_sequences()
|
||||
self.remove_empty_sequences()
|
||||
self.remove_unknown_sequences()
|
||||
self.check()
|
||||
self.print_statistics()
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.token_ids[index], self.lengths[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.lengths)
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
Some sanity checks
|
||||
"""
|
||||
assert len(self.token_ids) == len(self.lengths)
|
||||
assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))
|
||||
|
||||
def remove_long_sequences(self):
|
||||
"""
|
||||
Sequences that are too long are split by chunk of max_model_input_size.
|
||||
"""
|
||||
max_len = self.params.max_model_input_size
|
||||
indices = self.lengths > max_len
|
||||
logger.info(f"Splitting {sum(indices)} too long sequences.")
|
||||
|
||||
def divide_chunks(l, n):
|
||||
return [l[i : i + n] for i in range(0, len(l), n)]
|
||||
|
||||
new_tok_ids = []
|
||||
new_lengths = []
|
||||
if self.params.mlm:
|
||||
cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
|
||||
else:
|
||||
cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]
|
||||
|
||||
for seq_, len_ in zip(self.token_ids, self.lengths):
|
||||
assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
|
||||
if len_ <= max_len:
|
||||
new_tok_ids.append(seq_)
|
||||
new_lengths.append(len_)
|
||||
else:
|
||||
sub_seqs = []
|
||||
for sub_s in divide_chunks(seq_, max_len - 2):
|
||||
if sub_s[0] != cls_id:
|
||||
sub_s = np.insert(sub_s, 0, cls_id)
|
||||
if sub_s[-1] != sep_id:
|
||||
sub_s = np.insert(sub_s, len(sub_s), sep_id)
|
||||
assert len(sub_s) <= max_len
|
||||
assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s
|
||||
sub_seqs.append(sub_s)
|
||||
|
||||
new_tok_ids.extend(sub_seqs)
|
||||
new_lengths.extend([len(l) for l in sub_seqs])
|
||||
|
||||
self.token_ids = np.array(new_tok_ids)
|
||||
self.lengths = np.array(new_lengths)
|
||||
|
||||
def remove_empty_sequences(self):
|
||||
"""
|
||||
Too short sequences are simply removed. This could be tuned.
|
||||
"""
|
||||
init_size = len(self)
|
||||
indices = self.lengths > 11
|
||||
self.token_ids = self.token_ids[indices]
|
||||
self.lengths = self.lengths[indices]
|
||||
new_size = len(self)
|
||||
logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
|
||||
|
||||
def remove_unknown_sequences(self):
|
||||
"""
|
||||
Remove sequences with a (too) high level of unknown tokens.
|
||||
"""
|
||||
if "unk_token" not in self.params.special_tok_ids:
|
||||
return
|
||||
else:
|
||||
unk_token_id = self.params.special_tok_ids["unk_token"]
|
||||
init_size = len(self)
|
||||
unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
|
||||
indices = (unk_occs / self.lengths) < 0.5
|
||||
self.token_ids = self.token_ids[indices]
|
||||
self.lengths = self.lengths[indices]
|
||||
new_size = len(self)
|
||||
logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).")
|
||||
|
||||
def print_statistics(self):
|
||||
"""
|
||||
Print some statistics on the corpus. Only the master process.
|
||||
"""
|
||||
if not self.params.is_master:
|
||||
return
|
||||
logger.info(f"{len(self)} sequences")
|
||||
# data_len = sum(self.lengths)
|
||||
# nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
|
||||
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
|
||||
|
||||
# unk_idx = self.params.special_tok_ids['unk_token']
|
||||
# nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
|
||||
# logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')
|
||||
|
||||
def batch_sequences(self, batch):
|
||||
"""
|
||||
Do the padding and transform into torch.tensor.
|
||||
"""
|
||||
token_ids = [t[0] for t in batch]
|
||||
lengths = [t[1] for t in batch]
|
||||
assert len(token_ids) == len(lengths)
|
||||
|
||||
# Max for paddings
|
||||
max_seq_len_ = max(lengths)
|
||||
|
||||
# Pad token ids
|
||||
if self.params.mlm:
|
||||
pad_idx = self.params.special_tok_ids["pad_token"]
|
||||
else:
|
||||
pad_idx = self.params.special_tok_ids["unk_token"]
|
||||
tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
|
||||
assert len(tk_) == len(token_ids)
|
||||
assert all(len(t) == max_seq_len_ for t in tk_)
|
||||
|
||||
tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
|
||||
lg_t = torch.tensor(lengths) # (bs)
|
||||
return tk_t, lg_t
|
||||
7
examples/research_projects/distillation/requirements.txt
Normal file
7
examples/research_projects/distillation/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
transformers
|
||||
|
||||
gitpython==3.0.2
|
||||
tensorboard>=1.14.0
|
||||
tensorboardX==1.8
|
||||
psutil==5.6.6
|
||||
scipy>=1.4.1
|
||||
@@ -0,0 +1,872 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" This is the exact same script as `examples/question-answering/run_squad.py` (as of 2020, January 8th) with an additional and optional step of distillation."""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import timeit
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertForQuestionAnswering,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForQuestionAnswering,
|
||||
DistilBertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForQuestionAnswering,
|
||||
RobertaTokenizer,
|
||||
XLMConfig,
|
||||
XLMForQuestionAnswering,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
squad_convert_examples_to_features,
|
||||
)
|
||||
from transformers.data.metrics.squad_metrics import (
|
||||
compute_predictions_log_probs,
|
||||
compute_predictions_logits,
|
||||
squad_evaluate,
|
||||
)
|
||||
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def to_list(tensor):
|
||||
return tensor.detach().cpu().tolist()
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 1
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
try:
|
||||
# set global_step to gobal_step of last saved checkpoint from model path
|
||||
checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
|
||||
global_step = int(checkpoint_suffix)
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", global_step)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
except ValueError:
|
||||
logger.info(" Starting fine-tuning.")
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||
)
|
||||
# Added here for reproductibility
|
||||
set_seed(args)
|
||||
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
model.train()
|
||||
if teacher is not None:
|
||||
teacher.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"start_positions": batch[3],
|
||||
"end_positions": batch[4],
|
||||
}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||
if args.version_2_with_negative:
|
||||
inputs.update({"is_impossible": batch[7]})
|
||||
outputs = model(**inputs)
|
||||
loss, start_logits_stu, end_logits_stu = outputs
|
||||
|
||||
# Distillation loss
|
||||
if teacher is not None:
|
||||
if "token_type_ids" not in inputs:
|
||||
inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
|
||||
with torch.no_grad():
|
||||
start_logits_tea, end_logits_tea = teacher(
|
||||
input_ids=inputs["input_ids"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
assert start_logits_tea.size() == start_logits_stu.size()
|
||||
assert end_logits_tea.size() == end_logits_stu.size()
|
||||
|
||||
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
loss_start = (
|
||||
loss_fct(
|
||||
F.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
F.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_end = (
|
||||
loss_fct(
|
||||
F.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
F.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_ce = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
# Log metrics
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Only evaluate when single GPU otherwise metrics may not average well
|
||||
if args.local_rank == -1 and args.evaluate_during_training:
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
|
||||
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(dataset)
|
||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu evaluate
|
||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
|
||||
all_results = []
|
||||
start_time = timeit.default_timer()
|
||||
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM don't use segment_ids
|
||||
example_indices = batch[3]
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
|
||||
output = [to_list(output[i]) for output in outputs]
|
||||
|
||||
# Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
|
||||
# models only use two.
|
||||
if len(output) >= 5:
|
||||
start_logits = output[0]
|
||||
start_top_index = output[1]
|
||||
end_logits = output[2]
|
||||
end_top_index = output[3]
|
||||
cls_logits = output[4]
|
||||
|
||||
result = SquadResult(
|
||||
unique_id,
|
||||
start_logits,
|
||||
end_logits,
|
||||
start_top_index=start_top_index,
|
||||
end_top_index=end_top_index,
|
||||
cls_logits=cls_logits,
|
||||
)
|
||||
|
||||
else:
|
||||
start_logits, end_logits = output
|
||||
result = SquadResult(unique_id, start_logits, end_logits)
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
evalTime = timeit.default_timer() - start_time
|
||||
logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
|
||||
|
||||
# Compute predictions
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
|
||||
if args.version_2_with_negative:
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
else:
|
||||
output_null_log_odds_file = None
|
||||
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
predictions = compute_predictions_log_probs(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
model.config.start_n_top,
|
||||
model.config.end_n_top,
|
||||
args.version_2_with_negative,
|
||||
tokenizer,
|
||||
args.verbose_logging,
|
||||
)
|
||||
else:
|
||||
predictions = compute_predictions_logits(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
args.do_lower_case,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
args.verbose_logging,
|
||||
args.version_2_with_negative,
|
||||
args.null_score_diff_threshold,
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
# Compute the F1 and exact scores.
|
||||
results = squad_evaluate(examples, predictions)
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
torch.distributed.barrier()
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
input_file = args.predict_file if evaluate else args.train_file
|
||||
cached_features_file = os.path.join(
|
||||
os.path.dirname(input_file),
|
||||
"cached_distillation_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features_and_dataset = torch.load(cached_features_file)
|
||||
|
||||
try:
|
||||
features, dataset, examples = (
|
||||
features_and_dataset["features"],
|
||||
features_and_dataset["dataset"],
|
||||
features_and_dataset["examples"],
|
||||
)
|
||||
except KeyError:
|
||||
raise DeprecationWarning(
|
||||
"You seem to be loading features from an older version of this script please delete the "
|
||||
"file %s in order for it to be created again" % cached_features_file
|
||||
)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_file)
|
||||
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
||||
if evaluate:
|
||||
examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
|
||||
else:
|
||||
examples = processor.get_train_examples(args.data_dir, filename=args.train_file)
|
||||
|
||||
features, dataset = squad_convert_examples_to_features(
|
||||
examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate,
|
||||
return_dataset="pt",
|
||||
threads=args.threads,
|
||||
)
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
torch.distributed.barrier()
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.",
|
||||
)
|
||||
|
||||
# Distillation parameters (optional)
|
||||
parser.add_argument(
|
||||
"--teacher_type",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--teacher_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_ce", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_squad", default=0.5, type=float, help="True SQuAD loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The input data dir. Should contain the .json files for the task."
|
||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The input training file. If a data dir is specified, will look for the file there"
|
||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predict_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The input evaluation file. If a data dir is specified, will look for the file there"
|
||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--version_2_with_negative",
|
||||
action="store_true",
|
||||
help="If true, the SQuAD examples contain some that do not have an answer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--null_score_diff_threshold",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=384,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--doc_stride",
|
||||
default=128,
|
||||
type=int,
|
||||
help="When splitting up a long document into chunks, how much stride to take between chunks.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_query_length",
|
||||
default=64,
|
||||
type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument(
|
||||
"--n_best_size",
|
||||
default=20,
|
||||
type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_answer_length",
|
||||
default=30,
|
||||
type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose_logging",
|
||||
action="store_true",
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.",
|
||||
)
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
|
||||
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
# Make sure only the first process in distributed training will download model & vocab
|
||||
torch.distributed.barrier()
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.teacher_type is not None:
|
||||
assert args.teacher_name_or_path is not None
|
||||
assert args.alpha_ce > 0.0
|
||||
assert args.alpha_ce + args.alpha_squad > 0.0
|
||||
assert args.teacher_type != "distilbert", "We constraint teachers not to be of type DistilBERT."
|
||||
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
|
||||
teacher_config = teacher_config_class.from_pretrained(
|
||||
args.teacher_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
)
|
||||
teacher = teacher_model_class.from_pretrained(
|
||||
args.teacher_name_or_path, config=teacher_config, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
)
|
||||
teacher.to(args.device)
|
||||
else:
|
||||
teacher = None
|
||||
|
||||
if args.local_rank == 0:
|
||||
# Make sure only the first process in distributed training will download model & vocab
|
||||
torch.distributed.barrier()
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
|
||||
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
|
||||
# remove the need for this code, but it is still valid.
|
||||
if args.fp16:
|
||||
try:
|
||||
import apex
|
||||
|
||||
apex.amp.register_half_function(torch, "einsum")
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
# Save the trained model and the tokenizer
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
if args.do_train:
|
||||
logger.info("Loading checkpoints saved during training for evaluation")
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,96 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocessing script before distillation.
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import BertTokenizer, GPT2Tokenizer, RobertaTokenizer
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)."
|
||||
)
|
||||
parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
|
||||
parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
|
||||
parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
|
||||
parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
|
||||
if args.tokenizer_type == "bert":
|
||||
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
|
||||
bos = tokenizer.special_tokens_map["cls_token"] # `[CLS]`
|
||||
sep = tokenizer.special_tokens_map["sep_token"] # `[SEP]`
|
||||
elif args.tokenizer_type == "roberta":
|
||||
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
|
||||
bos = tokenizer.special_tokens_map["cls_token"] # `<s>`
|
||||
sep = tokenizer.special_tokens_map["sep_token"] # `</s>`
|
||||
elif args.tokenizer_type == "gpt2":
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
|
||||
bos = tokenizer.special_tokens_map["bos_token"] # `<|endoftext|>`
|
||||
sep = tokenizer.special_tokens_map["eos_token"] # `<|endoftext|>`
|
||||
|
||||
logger.info(f"Loading text from {args.file_path}")
|
||||
with open(args.file_path, "r", encoding="utf8") as fp:
|
||||
data = fp.readlines()
|
||||
|
||||
logger.info("Start encoding")
|
||||
logger.info(f"{len(data)} examples to process.")
|
||||
|
||||
rslt = []
|
||||
iter = 0
|
||||
interval = 10000
|
||||
start = time.time()
|
||||
for text in data:
|
||||
text = f"{bos} {text.strip()} {sep}"
|
||||
token_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
rslt.append(token_ids)
|
||||
|
||||
iter += 1
|
||||
if iter % interval == 0:
|
||||
end = time.time()
|
||||
logger.info(f"{iter} examples processed. - {(end-start):.2f}s/{interval}expl")
|
||||
start = time.time()
|
||||
logger.info("Finished binarization")
|
||||
logger.info(f"{len(data)} examples processed.")
|
||||
|
||||
dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
|
||||
vocab_size = tokenizer.vocab_size
|
||||
if vocab_size < (1 << 16):
|
||||
rslt_ = [np.uint16(d) for d in rslt]
|
||||
else:
|
||||
rslt_ = [np.int32(d) for d in rslt]
|
||||
random.shuffle(rslt_)
|
||||
logger.info(f"Dump to {dp_file}")
|
||||
with open(dp_file, "wb") as handle:
|
||||
pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
102
examples/research_projects/distillation/scripts/extract.py
Normal file
102
examples/research_projects/distillation/scripts/extract.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocessing script before training the distilled model.
|
||||
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import GPT2LMHeadModel, RobertaForMaskedLM
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
|
||||
)
|
||||
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
|
||||
parser.add_argument("--model_name", default="roberta-large", type=str)
|
||||
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
|
||||
parser.add_argument("--vocab_transform", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "roberta":
|
||||
model = RobertaForMaskedLM.from_pretrained(args.model_name)
|
||||
prefix = "roberta"
|
||||
elif args.model_type == "gpt2":
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name)
|
||||
prefix = "transformer"
|
||||
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
|
||||
# Embeddings #
|
||||
if args.model_type == "gpt2":
|
||||
for param_name in ["wte.weight", "wpe.weight"]:
|
||||
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
|
||||
else:
|
||||
for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
|
||||
param_name = f"{prefix}.embeddings.{w}.weight"
|
||||
compressed_sd[param_name] = state_dict[param_name]
|
||||
for w in ["weight", "bias"]:
|
||||
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
|
||||
compressed_sd[param_name] = state_dict[param_name]
|
||||
|
||||
# Transformer Blocks #
|
||||
std_idx = 0
|
||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||
if args.model_type == "gpt2":
|
||||
for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
|
||||
f"{prefix}.h.{teacher_idx}.{layer}.{w}"
|
||||
]
|
||||
compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
|
||||
else:
|
||||
for layer in [
|
||||
"attention.self.query",
|
||||
"attention.self.key",
|
||||
"attention.self.value",
|
||||
"attention.output.dense",
|
||||
"attention.output.LayerNorm",
|
||||
"intermediate.dense",
|
||||
"output.dense",
|
||||
"output.LayerNorm",
|
||||
]:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
# Language Modeling Head ###s
|
||||
if args.model_type == "roberta":
|
||||
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
|
||||
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
|
||||
if args.vocab_transform:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
|
||||
compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
|
||||
elif args.model_type == "gpt2":
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
|
||||
compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]
|
||||
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
|
||||
torch.save(compressed_sd, args.dump_checkpoint)
|
||||
@@ -0,0 +1,92 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocessing script before training DistilBERT.
|
||||
Specific to BERT -> DistilBERT.
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
|
||||
)
|
||||
parser.add_argument("--model_type", default="bert", choices=["bert"])
|
||||
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
|
||||
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
|
||||
parser.add_argument("--vocab_transform", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "bert":
|
||||
model = BertForMaskedLM.from_pretrained(args.model_name)
|
||||
prefix = "bert"
|
||||
else:
|
||||
raise ValueError('args.model_type should be "bert".')
|
||||
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
|
||||
for w in ["word_embeddings", "position_embeddings"]:
|
||||
compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
|
||||
|
||||
std_idx = 0
|
||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
|
||||
]
|
||||
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
|
||||
]
|
||||
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"]
|
||||
compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"]
|
||||
if args.vocab_transform:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
|
||||
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
|
||||
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
|
||||
torch.save(compressed_sd, args.dump_checkpoint)
|
||||
@@ -0,0 +1,56 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocessing script before training the distilled model.
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import pickle
|
||||
from collections import Counter
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
|
||||
)
|
||||
parser.add_argument("--vocab_size", default=30522, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(f"Loading data from {args.data_file}")
|
||||
with open(args.data_file, "rb") as fp:
|
||||
data = pickle.load(fp)
|
||||
|
||||
logger.info("Counting occurences for MLM.")
|
||||
counter = Counter()
|
||||
for tk_ids in data:
|
||||
counter.update(tk_ids)
|
||||
counts = [0] * args.vocab_size
|
||||
for k, v in counter.items():
|
||||
counts[k] = v
|
||||
|
||||
logger.info(f"Dump to {args.token_counts_dump}")
|
||||
with open(args.token_counts_dump, "wb") as handle:
|
||||
pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
322
examples/research_projects/distillation/train.py
Normal file
322
examples/research_projects/distillation/train.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Training the distilled model.
|
||||
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from distiller import Distiller
|
||||
from lm_seqs_dataset import LmSeqsDataset
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
BertForMaskedLM,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertTokenizer,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForMaskedLM,
|
||||
RobertaTokenizer,
|
||||
)
|
||||
from utils import git_log, init_gpu_params, logger, set_seed
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||
}
|
||||
|
||||
|
||||
def sanity_checks(args):
|
||||
"""
|
||||
A bunch of args sanity checks to perform even starting...
|
||||
"""
|
||||
assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0)
|
||||
assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0)
|
||||
if args.mlm:
|
||||
assert os.path.isfile(args.token_counts)
|
||||
assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"])
|
||||
else:
|
||||
assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"])
|
||||
|
||||
assert args.teacher_type == args.student_type or (
|
||||
args.student_type == "distilbert" and args.teacher_type == "bert"
|
||||
)
|
||||
assert os.path.isfile(args.student_config)
|
||||
if args.student_pretrained_weights is not None:
|
||||
assert os.path.isfile(args.student_pretrained_weights)
|
||||
|
||||
if args.freeze_token_type_embds:
|
||||
assert args.student_type in ["roberta"]
|
||||
|
||||
assert args.alpha_ce >= 0.0
|
||||
assert args.alpha_mlm >= 0.0
|
||||
assert args.alpha_clm >= 0.0
|
||||
assert args.alpha_mse >= 0.0
|
||||
assert args.alpha_cos >= 0.0
|
||||
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0
|
||||
|
||||
|
||||
def freeze_pos_embeddings(student, args):
|
||||
if args.student_type == "roberta":
|
||||
student.roberta.embeddings.position_embeddings.weight.requires_grad = False
|
||||
elif args.student_type == "gpt2":
|
||||
student.transformer.wpe.weight.requires_grad = False
|
||||
|
||||
|
||||
def freeze_token_type_embeddings(student, args):
|
||||
if args.student_type == "roberta":
|
||||
student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Training")
|
||||
parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")
|
||||
|
||||
parser.add_argument(
|
||||
"--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--student_type",
|
||||
type=str,
|
||||
choices=["distilbert", "roberta", "gpt2"],
|
||||
required=True,
|
||||
help="The student type (DistilBERT, RoBERTa).",
|
||||
)
|
||||
parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.")
|
||||
parser.add_argument(
|
||||
"--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa)."
|
||||
)
|
||||
parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.")
|
||||
|
||||
parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax temperature.")
|
||||
parser.add_argument(
|
||||
"--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_mlm",
|
||||
default=0.0,
|
||||
type=float,
|
||||
help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.",
|
||||
)
|
||||
parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
|
||||
parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
|
||||
parser.add_argument(
|
||||
"--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlm_mask_prop",
|
||||
default=0.15,
|
||||
type=float,
|
||||
help="Proportion of tokens for which we need to make a prediction.",
|
||||
)
|
||||
parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.")
|
||||
parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.")
|
||||
parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.")
|
||||
parser.add_argument(
|
||||
"--mlm_smoothing",
|
||||
default=0.7,
|
||||
type=float,
|
||||
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
|
||||
)
|
||||
parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")
|
||||
|
||||
parser.add_argument(
|
||||
"--restrict_ce_to_mask",
|
||||
action="store_true",
|
||||
help="If true, compute the distilation loss only the [MLM] prediction distribution.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_pos_embs",
|
||||
action="store_true",
|
||||
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_token_type_embds",
|
||||
action="store_true",
|
||||
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
|
||||
)
|
||||
|
||||
parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.")
|
||||
parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
|
||||
parser.add_argument(
|
||||
"--group_by_size",
|
||||
action="store_false",
|
||||
help="If true, group sequences that have similar length into the same batch. Default is true.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Gradient accumulation for larger training batches.",
|
||||
)
|
||||
parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs in the node.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
|
||||
parser.add_argument("--seed", type=int, default=56, help="Random seed")
|
||||
|
||||
parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
|
||||
parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
|
||||
args = parser.parse_args()
|
||||
sanity_checks(args)
|
||||
|
||||
# ARGS #
|
||||
init_gpu_params(args)
|
||||
set_seed(args)
|
||||
if args.is_master:
|
||||
if os.path.exists(args.dump_path):
|
||||
if not args.force:
|
||||
raise ValueError(
|
||||
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it"
|
||||
"Use `--force` if you want to overwrite it"
|
||||
)
|
||||
else:
|
||||
shutil.rmtree(args.dump_path)
|
||||
|
||||
if not os.path.exists(args.dump_path):
|
||||
os.makedirs(args.dump_path)
|
||||
logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
|
||||
|
||||
# SAVE PARAMS #
|
||||
logger.info(f"Param: {args}")
|
||||
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
git_log(args.dump_path)
|
||||
|
||||
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
|
||||
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
|
||||
|
||||
# TOKENIZER #
|
||||
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
|
||||
special_tok_ids = {}
|
||||
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
|
||||
idx = tokenizer.all_special_tokens.index(tok_symbol)
|
||||
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
|
||||
logger.info(f"Special tokens {special_tok_ids}")
|
||||
args.special_tok_ids = special_tok_ids
|
||||
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
|
||||
|
||||
# DATA LOADER #
|
||||
logger.info(f"Loading data from {args.data_file}")
|
||||
with open(args.data_file, "rb") as fp:
|
||||
data = pickle.load(fp)
|
||||
|
||||
if args.mlm:
|
||||
logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
|
||||
with open(args.token_counts, "rb") as fp:
|
||||
counts = pickle.load(fp)
|
||||
|
||||
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
|
||||
for idx in special_tok_ids.values():
|
||||
token_probs[idx] = 0.0 # do not predict special tokens
|
||||
token_probs = torch.from_numpy(token_probs)
|
||||
else:
|
||||
token_probs = None
|
||||
|
||||
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
|
||||
logger.info("Data loader created.")
|
||||
|
||||
# STUDENT #
|
||||
logger.info(f"Loading student config from {args.student_config}")
|
||||
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
|
||||
stu_architecture_config.output_hidden_states = True
|
||||
|
||||
if args.student_pretrained_weights is not None:
|
||||
logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
|
||||
student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config)
|
||||
else:
|
||||
student = student_model_class(stu_architecture_config)
|
||||
|
||||
if args.n_gpu > 0:
|
||||
student.to(f"cuda:{args.local_rank}")
|
||||
logger.info("Student loaded.")
|
||||
|
||||
# TEACHER #
|
||||
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
|
||||
if args.n_gpu > 0:
|
||||
teacher.to(f"cuda:{args.local_rank}")
|
||||
logger.info(f"Teacher loaded from {args.teacher_name}.")
|
||||
|
||||
# FREEZING #
|
||||
if args.freeze_pos_embs:
|
||||
freeze_pos_embeddings(student, args)
|
||||
if args.freeze_token_type_embds:
|
||||
freeze_token_type_embeddings(student, args)
|
||||
|
||||
# SANITY CHECKS #
|
||||
assert student.config.vocab_size == teacher.config.vocab_size
|
||||
assert student.config.hidden_size == teacher.config.hidden_size
|
||||
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
|
||||
if args.mlm:
|
||||
assert token_probs.size(0) == stu_architecture_config.vocab_size
|
||||
|
||||
# DISTILLER #
|
||||
torch.cuda.empty_cache()
|
||||
distiller = Distiller(
|
||||
params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
|
||||
)
|
||||
distiller.train()
|
||||
logger.info("Let's go get some drinks.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"activation": "gelu",
|
||||
"attention_dropout": 0.1,
|
||||
"dim": 768,
|
||||
"dropout": 0.1,
|
||||
"hidden_dim": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"max_position_embeddings": 512,
|
||||
"n_heads": 12,
|
||||
"n_layers": 6,
|
||||
"sinusoidal_pos_embds": true,
|
||||
"tie_weights_": true,
|
||||
"vocab_size": 28996
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"activation": "gelu",
|
||||
"attention_dropout": 0.1,
|
||||
"dim": 768,
|
||||
"dropout": 0.1,
|
||||
"hidden_dim": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"max_position_embeddings": 512,
|
||||
"n_heads": 12,
|
||||
"n_layers": 6,
|
||||
"sinusoidal_pos_embds": true,
|
||||
"tie_weights_": true,
|
||||
"vocab_size": 119547
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"activation": "gelu",
|
||||
"attention_dropout": 0.1,
|
||||
"dim": 768,
|
||||
"dropout": 0.1,
|
||||
"hidden_dim": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"max_position_embeddings": 512,
|
||||
"n_heads": 12,
|
||||
"n_layers": 6,
|
||||
"sinusoidal_pos_embds": true,
|
||||
"tie_weights_": true,
|
||||
"vocab_size": 30522
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_epsilon": 0.00001,
|
||||
"n_ctx": 1024,
|
||||
"n_embd": 768,
|
||||
"n_head": 12,
|
||||
"n_layer": 6,
|
||||
"n_positions": 1024,
|
||||
"vocab_size": 50257
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"vocab_size": 50265,
|
||||
"hidden_size": 768,
|
||||
"num_hidden_layers": 6,
|
||||
"num_attention_heads": 12,
|
||||
"intermediate_size": 3072,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"max_position_embeddings": 514,
|
||||
"type_vocab_size": 1,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_eps": 0.00001
|
||||
}
|
||||
133
examples/research_projects/distillation/utils.py
Normal file
133
examples/research_projects/distillation/utils.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Utils to train DistilBERT
|
||||
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
|
||||
import git
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def git_log(folder_path: str):
|
||||
"""
|
||||
Log commit info.
|
||||
"""
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
repo_infos = {
|
||||
"repo_id": str(repo),
|
||||
"repo_sha": str(repo.head.object.hexsha),
|
||||
"repo_branch": str(repo.active_branch),
|
||||
}
|
||||
|
||||
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
|
||||
json.dump(repo_infos, f, indent=4)
|
||||
|
||||
|
||||
def init_gpu_params(params):
|
||||
"""
|
||||
Handle single and multi-GPU / multi-node.
|
||||
"""
|
||||
if params.n_gpu <= 0:
|
||||
params.local_rank = 0
|
||||
params.master_port = -1
|
||||
params.is_master = True
|
||||
params.multi_gpu = False
|
||||
return
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
logger.info("Initializing GPUs")
|
||||
if params.n_gpu > 1:
|
||||
assert params.local_rank != -1
|
||||
|
||||
params.world_size = int(os.environ["WORLD_SIZE"])
|
||||
params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
|
||||
params.global_rank = int(os.environ["RANK"])
|
||||
|
||||
# number of nodes / node ID
|
||||
params.n_nodes = params.world_size // params.n_gpu_per_node
|
||||
params.node_id = params.global_rank // params.n_gpu_per_node
|
||||
params.multi_gpu = True
|
||||
|
||||
assert params.n_nodes == int(os.environ["N_NODES"])
|
||||
assert params.node_id == int(os.environ["NODE_RANK"])
|
||||
|
||||
# local job (single GPU)
|
||||
else:
|
||||
assert params.local_rank == -1
|
||||
|
||||
params.n_nodes = 1
|
||||
params.node_id = 0
|
||||
params.local_rank = 0
|
||||
params.global_rank = 0
|
||||
params.world_size = 1
|
||||
params.n_gpu_per_node = 1
|
||||
params.multi_gpu = False
|
||||
|
||||
# sanity checks
|
||||
assert params.n_nodes >= 1
|
||||
assert 0 <= params.node_id < params.n_nodes
|
||||
assert 0 <= params.local_rank <= params.global_rank < params.world_size
|
||||
assert params.world_size == params.n_nodes * params.n_gpu_per_node
|
||||
|
||||
# define whether this is the master process / if we are in multi-node distributed mode
|
||||
params.is_master = params.node_id == 0 and params.local_rank == 0
|
||||
params.multi_node = params.n_nodes > 1
|
||||
|
||||
# summary
|
||||
PREFIX = f"--- Global rank: {params.global_rank} - "
|
||||
logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes)
|
||||
logger.info(PREFIX + "Node ID : %i" % params.node_id)
|
||||
logger.info(PREFIX + "Local rank : %i" % params.local_rank)
|
||||
logger.info(PREFIX + "World size : %i" % params.world_size)
|
||||
logger.info(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node)
|
||||
logger.info(PREFIX + "Master : %s" % str(params.is_master))
|
||||
logger.info(PREFIX + "Multi-node : %s" % str(params.multi_node))
|
||||
logger.info(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu))
|
||||
logger.info(PREFIX + "Hostname : %s" % socket.gethostname())
|
||||
|
||||
# set GPU device
|
||||
torch.cuda.set_device(params.local_rank)
|
||||
|
||||
# initialize multi-GPU
|
||||
if params.multi_gpu:
|
||||
logger.info("Initializing PyTorch distributed")
|
||||
torch.distributed.init_process_group(
|
||||
init_method="env://",
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
"""
|
||||
Set the random seed.
|
||||
"""
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
7
examples/research_projects/longform-qa/README.md
Normal file
7
examples/research_projects/longform-qa/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# Long Form Question Answering
|
||||
|
||||
Author: @yjernite
|
||||
|
||||
This folder contains the code for the Long Form Question answering [demo](http://35.226.96.115:8080/) as well as methods to train and use a fully end-to-end Long Form Question Answering system using the [🤗transformers](https://github.com/huggingface/transformers) and [🤗datasets](https://github.com/huggingface/datasets) libraries.
|
||||
|
||||
You can use these methods to train your own system by following along the associate [notebook](https://github.com/huggingface/notebooks/blob/master/longform-qa/Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb) or [blog post](https://yjernite.github.io/lfqa.html).
|
||||
351
examples/research_projects/longform-qa/eli5_app.py
Normal file
351
examples/research_projects/longform-qa/eli5_app.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import datasets
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
import faiss
|
||||
import transformers
|
||||
from eli5_utils import (
|
||||
embed_questions_for_retrieval,
|
||||
make_qa_s2s_model,
|
||||
qa_s2s_generate,
|
||||
query_es_index,
|
||||
query_qa_dense_index,
|
||||
)
|
||||
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
MODEL_TYPE = "bart"
|
||||
LOAD_DENSE_INDEX = True
|
||||
|
||||
|
||||
@st.cache(allow_output_mutation=True)
|
||||
def load_models():
|
||||
if LOAD_DENSE_INDEX:
|
||||
qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased")
|
||||
qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to("cuda:0")
|
||||
_ = qar_model.eval()
|
||||
else:
|
||||
qar_tokenizer, qar_model = (None, None)
|
||||
if MODEL_TYPE == "bart":
|
||||
s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
|
||||
s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to("cuda:0")
|
||||
save_dict = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
|
||||
s2s_model.load_state_dict(save_dict["model"])
|
||||
_ = s2s_model.eval()
|
||||
else:
|
||||
s2s_tokenizer, s2s_model = make_qa_s2s_model(
|
||||
model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device="cuda:0"
|
||||
)
|
||||
return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)
|
||||
|
||||
|
||||
@st.cache(allow_output_mutation=True)
|
||||
def load_indexes():
|
||||
if LOAD_DENSE_INDEX:
|
||||
faiss_res = faiss.StandardGpuResources()
|
||||
wiki40b_passages = datasets.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
|
||||
wiki40b_passage_reps = np.memmap(
|
||||
"wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
|
||||
dtype="float32",
|
||||
mode="r",
|
||||
shape=(wiki40b_passages.num_rows, 128),
|
||||
)
|
||||
wiki40b_index_flat = faiss.IndexFlatIP(128)
|
||||
wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
|
||||
wiki40b_gpu_index_flat.add(wiki40b_passage_reps) # TODO fix for larger GPU
|
||||
else:
|
||||
wiki40b_passages, wiki40b_gpu_index_flat = (None, None)
|
||||
es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
|
||||
return (wiki40b_passages, wiki40b_gpu_index_flat, es_client)
|
||||
|
||||
|
||||
@st.cache(allow_output_mutation=True)
|
||||
def load_train_data():
|
||||
eli5 = datasets.load_dataset("eli5", name="LFQA_reddit")
|
||||
eli5_train = eli5["train_eli5"]
|
||||
eli5_train_q_reps = np.memmap(
|
||||
"eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128)
|
||||
)
|
||||
eli5_train_q_index = faiss.IndexFlatIP(128)
|
||||
eli5_train_q_index.add(eli5_train_q_reps)
|
||||
return (eli5_train, eli5_train_q_index)
|
||||
|
||||
|
||||
passages, gpu_dense_index, es_client = load_indexes()
|
||||
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
|
||||
eli5_train, eli5_train_q_index = load_train_data()
|
||||
|
||||
|
||||
def find_nearest_training(question, n_results=10):
|
||||
q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
|
||||
D, I = eli5_train_q_index.search(q_rep, n_results)
|
||||
nn_examples = [eli5_train[int(i)] for i in I[0]]
|
||||
return nn_examples
|
||||
|
||||
|
||||
def make_support(question, source="wiki40b", method="dense", n_results=10):
|
||||
if source == "none":
|
||||
support_doc, hit_lst = (" <P> ".join(["" for _ in range(11)]).strip(), [])
|
||||
else:
|
||||
if method == "dense":
|
||||
support_doc, hit_lst = query_qa_dense_index(
|
||||
question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
|
||||
)
|
||||
else:
|
||||
support_doc, hit_lst = query_es_index(
|
||||
question,
|
||||
es_client,
|
||||
index_name="english_wiki40b_snippets_100w",
|
||||
n_results=n_results,
|
||||
)
|
||||
support_list = [
|
||||
(res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
|
||||
]
|
||||
question_doc = "question: {} context: {}".format(question, support_doc)
|
||||
return question_doc, support_list
|
||||
|
||||
|
||||
@st.cache(
|
||||
hash_funcs={
|
||||
torch.Tensor: (lambda _: None),
|
||||
transformers.models.bart.tokenization_bart.BartTokenizer: (lambda _: None),
|
||||
}
|
||||
)
|
||||
def answer_question(
|
||||
question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
|
||||
):
|
||||
with torch.no_grad():
|
||||
answer = qa_s2s_generate(
|
||||
question_doc,
|
||||
s2s_model,
|
||||
s2s_tokenizer,
|
||||
num_answers=1,
|
||||
num_beams=n_beams,
|
||||
min_len=min_len,
|
||||
max_len=max_len,
|
||||
do_sample=sampling,
|
||||
temp=temp,
|
||||
top_p=top_p,
|
||||
top_k=None,
|
||||
max_input_length=1024,
|
||||
device="cuda:0",
|
||||
)[0]
|
||||
return (answer, support_list)
|
||||
|
||||
|
||||
st.title("Long Form Question Answering with ELI5")
|
||||
|
||||
# Start sidebar
|
||||
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
|
||||
header_full = """
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
.img-container {
|
||||
padding-left: 90px;
|
||||
padding-right: 90px;
|
||||
padding-top: 50px;
|
||||
padding-bottom: 50px;
|
||||
background-color: #f0f3f9;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<span class="img-container"> <!-- Inline parent element -->
|
||||
%s
|
||||
</span>
|
||||
</body>
|
||||
</html>
|
||||
""" % (
|
||||
header_html,
|
||||
)
|
||||
st.sidebar.markdown(
|
||||
header_full,
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
# Long Form QA with ELI5 and Wikipedia
|
||||
description = """
|
||||
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
|
||||
First, a document retriever fetches a set of relevant Wikipedia passages given the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
|
||||
a pre-processed fixed snapshot of Wikipedia.
|
||||
"""
|
||||
st.sidebar.markdown(description, unsafe_allow_html=True)
|
||||
|
||||
action_list = [
|
||||
"Answer the question",
|
||||
"View the retrieved document only",
|
||||
"View the most similar ELI5 question and answer",
|
||||
"Show me everything, please!",
|
||||
]
|
||||
demo_options = st.sidebar.checkbox("Demo options")
|
||||
if demo_options:
|
||||
action_st = st.sidebar.selectbox(
|
||||
"",
|
||||
action_list,
|
||||
index=3,
|
||||
)
|
||||
action = action_list.index(action_st)
|
||||
show_type = st.sidebar.selectbox(
|
||||
"",
|
||||
["Show full text of passages", "Show passage section titles"],
|
||||
index=0,
|
||||
)
|
||||
show_passages = show_type == "Show full text of passages"
|
||||
else:
|
||||
action = 3
|
||||
show_passages = True
|
||||
|
||||
retrieval_options = st.sidebar.checkbox("Retrieval options")
|
||||
if retrieval_options:
|
||||
retriever_info = """
|
||||
### Information retriever options
|
||||
|
||||
The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between a question and passage embedding
|
||||
trained using the [ELI5](https://arxiv.org/abs/1907.09190) questions-answer pairs.
|
||||
The answer is then generated by sequence to sequence model which takes the question and retrieved document as input.
|
||||
"""
|
||||
st.sidebar.markdown(retriever_info)
|
||||
wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
|
||||
index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])
|
||||
else:
|
||||
wiki_source = "wiki40b"
|
||||
index_type = "dense"
|
||||
|
||||
sampled = "beam"
|
||||
n_beams = 2
|
||||
min_len = 64
|
||||
max_len = 256
|
||||
top_p = None
|
||||
temp = None
|
||||
generate_options = st.sidebar.checkbox("Generation options")
|
||||
if generate_options:
|
||||
generate_info = """
|
||||
### Answer generation options
|
||||
|
||||
The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
|
||||
weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can use the model for greedy decoding with
|
||||
**beam** search, or **sample** from the decoder's output probabilities.
|
||||
"""
|
||||
st.sidebar.markdown(generate_info)
|
||||
sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
|
||||
min_len = st.sidebar.slider(
|
||||
"Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
|
||||
)
|
||||
max_len = st.sidebar.slider(
|
||||
"Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
|
||||
)
|
||||
if sampled == "beam":
|
||||
n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
|
||||
else:
|
||||
top_p = st.sidebar.slider(
|
||||
"Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
|
||||
)
|
||||
temp = st.sidebar.slider(
|
||||
"Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
|
||||
)
|
||||
n_beams = None
|
||||
|
||||
# start main text
|
||||
questions_list = [
|
||||
"<MY QUESTION>",
|
||||
"How do people make chocolate?",
|
||||
"Why do we get a fever when we are sick?",
|
||||
"How can different animals perceive different colors?",
|
||||
"What is natural language processing?",
|
||||
"What's the best way to treat a sunburn?",
|
||||
"What exactly are vitamins ?",
|
||||
"How does nuclear energy provide electricity?",
|
||||
"What's the difference between viruses and bacteria?",
|
||||
"Why are flutes classified as woodwinds when most of them are made out of metal ?",
|
||||
"Why do people like drinking coffee even though it tastes so bad?",
|
||||
"What happens when wine ages? How does it make the wine taste better?",
|
||||
"If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
|
||||
"How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
|
||||
"How does New Zealand have so many large bird predators?",
|
||||
]
|
||||
question_s = st.selectbox(
|
||||
"What would you like to ask? ---- select <MY QUESTION> to enter a new query",
|
||||
questions_list,
|
||||
index=1,
|
||||
)
|
||||
if question_s == "<MY QUESTION>":
|
||||
question = st.text_input("Enter your question here:", "")
|
||||
else:
|
||||
question = question_s
|
||||
|
||||
if st.button("Show me!"):
|
||||
if action in [0, 1, 3]:
|
||||
if index_type == "mixed":
|
||||
_, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
|
||||
_, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
|
||||
support_list = []
|
||||
for res_d, res_s in zip(support_list_dense, support_list_sparse):
|
||||
if tuple(res_d) not in support_list:
|
||||
support_list += [tuple(res_d)]
|
||||
if tuple(res_s) not in support_list:
|
||||
support_list += [tuple(res_s)]
|
||||
support_list = support_list[:10]
|
||||
question_doc = "<P> " + " <P> ".join([res[-1] for res in support_list])
|
||||
else:
|
||||
question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
|
||||
if action in [0, 3]:
|
||||
answer, support_list = answer_question(
|
||||
question_doc,
|
||||
s2s_model,
|
||||
s2s_tokenizer,
|
||||
min_len=min_len,
|
||||
max_len=int(max_len),
|
||||
sampling=(sampled == "sampled"),
|
||||
n_beams=n_beams,
|
||||
top_p=top_p,
|
||||
temp=temp,
|
||||
)
|
||||
st.markdown("### The model generated answer is:")
|
||||
st.write(answer)
|
||||
if action in [0, 1, 3] and wiki_source != "none":
|
||||
st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
|
||||
for i, res in enumerate(support_list):
|
||||
wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
|
||||
sec_titles = res[1].strip()
|
||||
if sec_titles == "":
|
||||
sections = "[{}]({})".format(res[0], wiki_url)
|
||||
else:
|
||||
sec_list = sec_titles.split(" & ")
|
||||
sections = " & ".join(
|
||||
["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
|
||||
)
|
||||
st.markdown(
|
||||
"{0:02d} - **Article**: {1:<18} <br> _Section_: {2}".format(i + 1, res[0], sections),
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
if show_passages:
|
||||
st.write(
|
||||
'> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
|
||||
)
|
||||
if action in [2, 3]:
|
||||
nn_train_list = find_nearest_training(question)
|
||||
train_exple = nn_train_list[0]
|
||||
st.markdown(
|
||||
"--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
|
||||
)
|
||||
answers_st = [
|
||||
"{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
|
||||
for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
|
||||
if i == 0 or sc > 2
|
||||
]
|
||||
st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))
|
||||
|
||||
|
||||
disclaimer = """
|
||||
---
|
||||
|
||||
**Disclaimer**
|
||||
|
||||
*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
|
||||
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
|
||||
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
|
||||
"""
|
||||
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)
|
||||
687
examples/research_projects/longform-qa/eli5_utils.py
Normal file
687
examples/research_projects/longform-qa/eli5_utils.py
Normal file
@@ -0,0 +1,687 @@
|
||||
import functools
|
||||
import math
|
||||
import os # noqa: F401
|
||||
from random import choice, randint
|
||||
from time import time
|
||||
|
||||
import datasets # noqa: F401
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from elasticsearch import Elasticsearch # noqa: F401
|
||||
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
import faiss # noqa: F401
|
||||
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
pd.set_option("display.max_colwidth", None)
|
||||
|
||||
|
||||
###############
|
||||
# Sparse index
|
||||
###############
|
||||
def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_kilt_snippets_100w"):
|
||||
index_config = {
|
||||
"settings": {
|
||||
"number_of_shards": 1,
|
||||
"analysis": {"analyzer": {"stop_standard": {"type": "standard", " stopwords": "_english_"}}},
|
||||
},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"article_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
|
||||
"section_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
|
||||
"passage_text": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
|
||||
}
|
||||
},
|
||||
}
|
||||
es_client.indices.create(index=index_name, body=index_config)
|
||||
number_of_docs = passages_dset.num_rows
|
||||
progress = tqdm(unit="docs", total=number_of_docs)
|
||||
successes = 0
|
||||
|
||||
def passage_generator():
|
||||
for passage in passages_dset:
|
||||
yield passage
|
||||
|
||||
# create the ES index
|
||||
for ok, action in streaming_bulk(
|
||||
client=es_client,
|
||||
index=index_name,
|
||||
actions=passage_generator(),
|
||||
):
|
||||
progress.update(1)
|
||||
successes += ok
|
||||
print("Indexed %d documents" % (successes,))
|
||||
|
||||
|
||||
def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_100w", n_results=10, min_length=20):
|
||||
q = question.lower()
|
||||
banned = ["how", "why", "what", "where", "which", "do", "does", "is", "?", "eli5", "eli5:"]
|
||||
q = " ".join([w for w in q.split() if w not in banned])
|
||||
response = es_client.search(
|
||||
index=index_name,
|
||||
body={
|
||||
"query": {
|
||||
"multi_match": {
|
||||
"query": q,
|
||||
"fields": ["article_title", "section_title", "passage_text^2"],
|
||||
"type": "cross_fields",
|
||||
}
|
||||
},
|
||||
"size": 2 * n_results,
|
||||
},
|
||||
)
|
||||
hits = response["hits"]["hits"]
|
||||
support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
|
||||
res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
|
||||
for r, hit in zip(res_list, hits):
|
||||
r["passage_id"] = hit["_id"]
|
||||
r["score"] = hit["_score"]
|
||||
r["passage_text"] = hit["_source"]["passage_text"]
|
||||
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
|
||||
return support_doc, res_list
|
||||
|
||||
|
||||
###############
|
||||
# ELI5 retriever training
|
||||
###############
|
||||
class ELI5DatasetQARetriver(Dataset):
|
||||
def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):
|
||||
self.data = examples_array
|
||||
self.answer_thres = extra_answer_threshold
|
||||
self.min_length = min_answer_length
|
||||
self.training = training
|
||||
self.n_samples = self.data.num_rows if n_samples is None else n_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.n_samples
|
||||
|
||||
def make_example(self, idx):
|
||||
example = self.data[idx]
|
||||
question = example["title"]
|
||||
if self.training:
|
||||
answers = [a for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"]))]
|
||||
answer_tab = choice(answers).split(" ")
|
||||
start_idx = randint(0, max(0, len(answer_tab) - self.min_length))
|
||||
answer_span = " ".join(answer_tab[start_idx:])
|
||||
else:
|
||||
answer_span = example["answers"]["text"][0]
|
||||
return (question, answer_span)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.make_example(idx % self.data.num_rows)
|
||||
|
||||
|
||||
class RetrievalQAEmbedder(torch.nn.Module):
|
||||
def __init__(self, sent_encoder, dim):
|
||||
super(RetrievalQAEmbedder, self).__init__()
|
||||
self.sent_encoder = sent_encoder
|
||||
self.output_dim = 128
|
||||
self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
|
||||
self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
|
||||
self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
|
||||
|
||||
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
|
||||
# reproduces BERT forward pass with checkpointing
|
||||
if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
|
||||
return self.sent_encoder(input_ids, attention_mask=attention_mask)[1]
|
||||
else:
|
||||
# prepare implicit variables
|
||||
device = input_ids.device
|
||||
input_shape = input_ids.size()
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
|
||||
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
|
||||
attention_mask, input_shape, device
|
||||
)
|
||||
|
||||
# define function for checkpointing
|
||||
def partial_encode(*inputs):
|
||||
encoder_outputs = self.sent_encoder.encoder(
|
||||
inputs[0],
|
||||
attention_mask=inputs[1],
|
||||
head_mask=head_mask,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.sent_encoder.pooler(sequence_output)
|
||||
return pooled_output
|
||||
|
||||
# run embedding layer on everything at once
|
||||
embedding_output = self.sent_encoder.embeddings(
|
||||
input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
|
||||
)
|
||||
# run encoding and pooling on one mini-batch at a time
|
||||
pooled_output_list = []
|
||||
for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
|
||||
b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
|
||||
b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
|
||||
pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
|
||||
pooled_output_list.append(pooled_output)
|
||||
return torch.cat(pooled_output_list, dim=0)
|
||||
|
||||
def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):
|
||||
q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)
|
||||
return self.project_q(q_reps)
|
||||
|
||||
def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):
|
||||
a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)
|
||||
return self.project_a(a_reps)
|
||||
|
||||
def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):
|
||||
device = q_ids.device
|
||||
q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)
|
||||
a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)
|
||||
compare_scores = torch.mm(q_reps, a_reps.t())
|
||||
loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
|
||||
loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
|
||||
loss = (loss_qa + loss_aq) / 2
|
||||
return loss
|
||||
|
||||
|
||||
def make_qa_retriever_model(model_name="google/bert_uncased_L-8_H-512_A-8", from_file=None, device="cuda:0"):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
bert_model = AutoModel.from_pretrained(model_name).to(device)
|
||||
# run bert_model on a dummy batch to get output dimension
|
||||
d_ids = torch.LongTensor(
|
||||
[[bert_model.config.bos_token_id if bert_model.config.bos_token_id is not None else 1]]
|
||||
).to(device)
|
||||
d_mask = torch.LongTensor([[1]]).to(device)
|
||||
sent_dim = bert_model(d_ids, attention_mask=d_mask)[1].shape[-1]
|
||||
qa_embedder = RetrievalQAEmbedder(bert_model, sent_dim).to(device)
|
||||
if from_file is not None:
|
||||
param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states
|
||||
qa_embedder.load_state_dict(param_dict["model"])
|
||||
return tokenizer, qa_embedder
|
||||
|
||||
|
||||
def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"):
|
||||
q_ls = [q for q, a in qa_list]
|
||||
a_ls = [a for q, a in qa_list]
|
||||
q_toks = tokenizer(q_ls, max_length=max_len, padding="max_length", truncation=True)
|
||||
q_ids, q_mask = (
|
||||
torch.LongTensor(q_toks["input_ids"]).to(device),
|
||||
torch.LongTensor(q_toks["attention_mask"]).to(device),
|
||||
)
|
||||
a_toks = tokenizer(a_ls, max_length=max_len, padding="max_length", truncation=True)
|
||||
a_ids, a_mask = (
|
||||
torch.LongTensor(a_toks["input_ids"]).to(device),
|
||||
torch.LongTensor(a_toks["attention_mask"]).to(device),
|
||||
)
|
||||
return (q_ids, q_mask, a_ids, a_mask)
|
||||
|
||||
|
||||
def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0):
|
||||
model.train()
|
||||
# make iterator
|
||||
train_sampler = RandomSampler(dataset)
|
||||
model_collate_fn = functools.partial(
|
||||
make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
|
||||
)
|
||||
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
|
||||
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
|
||||
# accumulate loss since last print
|
||||
loc_steps = 0
|
||||
loc_loss = 0.0
|
||||
st_time = time()
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
q_ids, q_mask, a_ids, a_mask = batch
|
||||
pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
|
||||
loss = pre_loss.sum()
|
||||
# optimizer
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
model.zero_grad()
|
||||
# some printing within the epoch
|
||||
loc_loss += loss.item()
|
||||
loc_steps += 1
|
||||
if step % args.print_freq == 0 or step == 1:
|
||||
print(
|
||||
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
e,
|
||||
step,
|
||||
len(dataset) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
loc_loss = 0
|
||||
loc_steps = 0
|
||||
|
||||
|
||||
def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=0):
|
||||
model.train()
|
||||
model_collate_fn = functools.partial(
|
||||
make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
|
||||
)
|
||||
# make iterator
|
||||
train_samplers = [RandomSampler(dataset) for dataset in dataset_list]
|
||||
data_loaders = [
|
||||
DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
|
||||
for dataset, train_sampler in zip(dataset_list, train_samplers)
|
||||
]
|
||||
iterators = [iter(dloader) for dloader in data_loaders]
|
||||
joint_iter = zip(*iterators)
|
||||
# accumulate loss since last print
|
||||
loc_steps = 0
|
||||
loc_loss = 0.0
|
||||
st_time = time()
|
||||
for step, (batches,) in enumerate(zip(joint_iter)):
|
||||
for batch in batches:
|
||||
q_ids, q_mask, a_ids, a_mask = batch
|
||||
loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
|
||||
# optimizer
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
model.zero_grad()
|
||||
# some printing within the epoch
|
||||
loc_loss += loss.item()
|
||||
loc_steps += 1
|
||||
if step % args.print_freq == 0:
|
||||
print(
|
||||
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
e,
|
||||
step,
|
||||
len(dataset_list[0]) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
loc_loss = 0
|
||||
loc_steps = 0
|
||||
|
||||
|
||||
def evaluate_qa_retriever(model, dataset, tokenizer, args):
|
||||
model.eval()
|
||||
# make iterator
|
||||
eval_sampler = SequentialSampler(dataset)
|
||||
model_collate_fn = functools.partial(
|
||||
make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
|
||||
)
|
||||
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)
|
||||
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
|
||||
tot_loss = 0.0
|
||||
with torch.no_grad():
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
q_ids, q_mask, a_ids, a_mask = batch
|
||||
loss = model(q_ids, q_mask, a_ids, a_mask)
|
||||
tot_loss += loss.item()
|
||||
return tot_loss / (step + 1)
|
||||
|
||||
|
||||
def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args):
|
||||
qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)
|
||||
qar_scheduler = get_linear_schedule_with_warmup(
|
||||
qar_optimizer,
|
||||
num_warmup_steps=100,
|
||||
num_training_steps=(qar_args.num_epochs + 1) * math.ceil(len(qar_train_dset) / qar_args.batch_size),
|
||||
)
|
||||
for e in range(qar_args.num_epochs):
|
||||
train_qa_retriever_epoch(qar_model, qar_train_dset, qar_tokenizer, qar_optimizer, qar_scheduler, qar_args, e)
|
||||
m_save_dict = {
|
||||
"model": qar_model.state_dict(),
|
||||
"optimizer": qar_optimizer.state_dict(),
|
||||
"scheduler": qar_scheduler.state_dict(),
|
||||
}
|
||||
print("Saving model {}".format(qar_args.model_save_name))
|
||||
torch.save(m_save_dict, "{}_{}.pth".format(qar_args.model_save_name, e))
|
||||
eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
|
||||
print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))
|
||||
|
||||
|
||||
###############
|
||||
# ELI5 seq2seq model training
|
||||
###############
|
||||
class ELI5DatasetS2S(Dataset):
|
||||
def __init__(
|
||||
self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True
|
||||
):
|
||||
self.training = training
|
||||
self.data = examples_array
|
||||
self.make_doc_function = make_doc_fun
|
||||
self.document_cache = {} if document_cache is None else document_cache
|
||||
assert not (make_doc_fun is None and document_cache is None)
|
||||
# make index of specific question-answer pairs from multi-answers
|
||||
if self.training:
|
||||
self.qa_id_list = [
|
||||
(i, j)
|
||||
for i, qa in enumerate(self.data)
|
||||
for j, (a, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"]))
|
||||
if j == 0 or sc >= extra_answer_threshold
|
||||
]
|
||||
else:
|
||||
self.qa_id_list = [(i, 0) for i in range(self.data.num_rows)]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.qa_id_list)
|
||||
|
||||
def make_example(self, idx):
|
||||
i, j = self.qa_id_list[idx]
|
||||
example = self.data[i]
|
||||
question = example["title"] + " " + example["selftext"]
|
||||
answer = example["answers"]["text"][j]
|
||||
q_id = example["q_id"]
|
||||
if self.make_doc_function is not None:
|
||||
self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["title"]))
|
||||
document = self.document_cache[q_id]
|
||||
in_st = "question: {} context: {}".format(
|
||||
question.lower().replace(" --t--", "").strip(),
|
||||
document.lower().strip(),
|
||||
)
|
||||
out_st = answer
|
||||
return (in_st, out_st)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.make_example(idx)
|
||||
|
||||
|
||||
def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
|
||||
if from_file is not None:
|
||||
param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states
|
||||
model.load_state_dict(param_dict["model"])
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
|
||||
q_ls = [q for q, a in qa_list]
|
||||
a_ls = [a for q, a in qa_list]
|
||||
q_toks = tokenizer(q_ls, max_length=max_len, padding="max_length", truncation=True)
|
||||
q_ids, q_mask = (
|
||||
torch.LongTensor(q_toks["input_ids"]).to(device),
|
||||
torch.LongTensor(q_toks["attention_mask"]).to(device),
|
||||
)
|
||||
a_toks = tokenizer(a_ls, max_length=min(max_len, max_a_len), padding="max_length", truncation=True)
|
||||
a_ids, a_mask = (
|
||||
torch.LongTensor(a_toks["input_ids"]).to(device),
|
||||
torch.LongTensor(a_toks["attention_mask"]).to(device),
|
||||
)
|
||||
lm_labels = a_ids[:, 1:].contiguous().clone()
|
||||
lm_labels[a_mask[:, 1:].contiguous() == 0] = -100
|
||||
model_inputs = {
|
||||
"input_ids": q_ids,
|
||||
"attention_mask": q_mask,
|
||||
"decoder_input_ids": a_ids[:, :-1].contiguous(),
|
||||
"lm_labels": lm_labels,
|
||||
}
|
||||
return model_inputs
|
||||
|
||||
|
||||
def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
|
||||
model.train()
|
||||
# make iterator
|
||||
if curriculum:
|
||||
train_sampler = SequentialSampler(dataset)
|
||||
else:
|
||||
train_sampler = RandomSampler(dataset)
|
||||
model_collate_fn = functools.partial(
|
||||
make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
|
||||
)
|
||||
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
|
||||
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
|
||||
# accumulate loss since last print
|
||||
loc_steps = 0
|
||||
loc_loss = 0.0
|
||||
st_time = time()
|
||||
for step, batch_inputs in enumerate(epoch_iterator):
|
||||
pre_loss = model(**batch_inputs)[0]
|
||||
loss = pre_loss.sum() / pre_loss.shape[0]
|
||||
loss.backward()
|
||||
# optimizer
|
||||
if step % args.backward_freq == 0:
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
model.zero_grad()
|
||||
# some printing within the epoch
|
||||
loc_loss += loss.item()
|
||||
loc_steps += 1
|
||||
if step % args.print_freq == 0 or step == 1:
|
||||
print(
|
||||
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
e,
|
||||
step,
|
||||
len(dataset) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
loc_loss = 0
|
||||
loc_steps = 0
|
||||
|
||||
|
||||
def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
|
||||
model.eval()
|
||||
# make iterator
|
||||
train_sampler = SequentialSampler(dataset)
|
||||
model_collate_fn = functools.partial(
|
||||
make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
|
||||
)
|
||||
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
|
||||
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
|
||||
# accumulate loss since last print
|
||||
loc_steps = 0
|
||||
loc_loss = 0.0
|
||||
st_time = time()
|
||||
with torch.no_grad():
|
||||
for step, batch_inputs in enumerate(epoch_iterator):
|
||||
pre_loss = model(**batch_inputs)[0]
|
||||
loss = pre_loss.sum() / pre_loss.shape[0]
|
||||
loc_loss += loss.item()
|
||||
loc_steps += 1
|
||||
if step % args.print_freq == 0:
|
||||
print(
|
||||
"{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
step,
|
||||
len(dataset) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Total \t L: {:.3f} \t -- {:.3f}".format(
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
|
||||
s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
|
||||
s2s_scheduler = get_linear_schedule_with_warmup(
|
||||
s2s_optimizer,
|
||||
num_warmup_steps=400,
|
||||
num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
|
||||
)
|
||||
for e in range(s2s_args.num_epochs):
|
||||
train_qa_s2s_epoch(
|
||||
qa_s2s_model,
|
||||
s2s_train_dset,
|
||||
qa_s2s_tokenizer,
|
||||
s2s_optimizer,
|
||||
s2s_scheduler,
|
||||
s2s_args,
|
||||
e,
|
||||
curriculum=(e == 0),
|
||||
)
|
||||
m_save_dict = {
|
||||
"model": qa_s2s_model.state_dict(),
|
||||
"optimizer": s2s_optimizer.state_dict(),
|
||||
"scheduler": s2s_scheduler.state_dict(),
|
||||
}
|
||||
print("Saving model {}".format(s2s_args.model_save_name))
|
||||
eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
|
||||
torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))
|
||||
|
||||
|
||||
# generate answer from input "question: ... context: <p> ..."
|
||||
def qa_s2s_generate(
|
||||
question_doc,
|
||||
qa_s2s_model,
|
||||
qa_s2s_tokenizer,
|
||||
num_answers=1,
|
||||
num_beams=None,
|
||||
min_len=64,
|
||||
max_len=256,
|
||||
do_sample=False,
|
||||
temp=1.0,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
max_input_length=512,
|
||||
device="cuda:0",
|
||||
):
|
||||
model_inputs = make_qa_s2s_batch(
|
||||
[(question_doc, "A")],
|
||||
qa_s2s_tokenizer,
|
||||
max_input_length,
|
||||
device=device,
|
||||
)
|
||||
n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
|
||||
generated_ids = qa_s2s_model.generate(
|
||||
input_ids=model_inputs["input_ids"],
|
||||
attention_mask=model_inputs["attention_mask"],
|
||||
min_length=min_len,
|
||||
max_length=max_len,
|
||||
do_sample=do_sample,
|
||||
early_stopping=True,
|
||||
num_beams=1 if do_sample else n_beams,
|
||||
temperature=temp,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
eos_token_id=qa_s2s_tokenizer.eos_token_id,
|
||||
no_repeat_ngram_size=3,
|
||||
num_return_sequences=num_answers,
|
||||
decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,
|
||||
)
|
||||
return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]
|
||||
|
||||
|
||||
###############
|
||||
# ELI5-trained retrieval model usage
|
||||
###############
|
||||
def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"):
|
||||
a_toks = tokenizer(passages, max_length=max_length, padding="max_length", truncation=True)
|
||||
a_ids, a_mask = (
|
||||
torch.LongTensor(a_toks["input_ids"]).to(device),
|
||||
torch.LongTensor(a_toks["attention_mask"]).to(device),
|
||||
)
|
||||
with torch.no_grad():
|
||||
a_reps = qa_embedder.embed_answers(a_ids, a_mask).cpu().type(torch.float)
|
||||
return a_reps.numpy()
|
||||
|
||||
|
||||
def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"):
|
||||
q_toks = tokenizer(q_ls, max_length=128, padding="max_length", truncation=True)
|
||||
q_ids, q_mask = (
|
||||
torch.LongTensor(q_toks["input_ids"]).to(device),
|
||||
torch.LongTensor(q_toks["attention_mask"]).to(device),
|
||||
)
|
||||
with torch.no_grad():
|
||||
q_reps = qa_embedder.embed_questions(q_ids, q_mask).cpu().type(torch.float)
|
||||
return q_reps.numpy()
|
||||
|
||||
|
||||
def make_qa_dense_index(
|
||||
qa_embedder,
|
||||
tokenizer,
|
||||
passages_dset,
|
||||
batch_size=512,
|
||||
max_length=128,
|
||||
index_name="kilt_passages_reps.dat",
|
||||
dtype="float32",
|
||||
device="cuda:0",
|
||||
):
|
||||
st_time = time()
|
||||
fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
|
||||
n_batches = math.ceil(passages_dset.num_rows / batch_size)
|
||||
for i in range(n_batches):
|
||||
passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
|
||||
reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
|
||||
fp[i * batch_size : (i + 1) * batch_size] = reps
|
||||
if i % 50 == 0:
|
||||
print(i, time() - st_time)
|
||||
|
||||
|
||||
def evaluate_retriever(qa_list, retriever_func, scoring_func, n_ret=10, verbose=False):
|
||||
total_retriever_time = 0.0
|
||||
total_retriever_score = 0.0
|
||||
st_time = time()
|
||||
for i, (question, answer) in enumerate(qa_list):
|
||||
r_time = time()
|
||||
retrieved_passages = retriever_func(question, n_ret)
|
||||
total_retriever_time += time() - r_time
|
||||
total_retriever_score += scoring_func(retrieved_passages, answer)
|
||||
if verbose and ((i + 1) % 500 == 0 or i <= 1):
|
||||
print(
|
||||
"{:03d}: S-{:.4f} T-{:.4f} | {:.2f}".format(
|
||||
i + 1, total_retriever_score / (i + 1), total_retriever_time / (i + 1), time() - st_time
|
||||
)
|
||||
)
|
||||
return {"idf_recall": total_retriever_score / (i + 1), "retrieval_time": total_retriever_time / (i + 1)}
|
||||
|
||||
|
||||
# build a support document for the question out of Wikipedia snippets
|
||||
def query_qa_dense_index(
|
||||
question, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20, device="cuda:0"
|
||||
):
|
||||
q_rep = embed_questions_for_retrieval([question], tokenizer, qa_embedder, device=device)
|
||||
D, I = wiki_index.search(q_rep, 2 * n_results)
|
||||
res_passages = [wiki_passages[int(i)] for i in I[0]]
|
||||
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
|
||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
|
||||
for r, sc in zip(res_list, D[0]):
|
||||
r["score"] = float(sc)
|
||||
return support_doc, res_list
|
||||
|
||||
|
||||
def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
|
||||
q_rep = embed_questions_for_retrieval(questions, tokenizer, qa_embedder)
|
||||
D, I = wiki_index.search(q_rep, n_results)
|
||||
res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
|
||||
support_doc_lst = [
|
||||
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
|
||||
]
|
||||
all_res_lists = []
|
||||
for (res_passages, dl) in zip(res_passages_lst, D):
|
||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||
for r, sc in zip(res_list, dl):
|
||||
r["score"] = float(sc)
|
||||
all_res_lists += [res_list[:]]
|
||||
return support_doc_lst, all_res_lists
|
||||
|
||||
|
||||
# find nearest neighbors of an answer or declarative text in Wikipedia snippets
|
||||
def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20):
|
||||
a_rep = embed_passages_for_retrieval([passage], tokenizer, qa_embedder)
|
||||
D, I = wiki_index.search(a_rep, 2 * n_results)
|
||||
res_passages = [wiki_passages[int(i)] for i in I[0]]
|
||||
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
|
||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
|
||||
for r, sc, i in zip(res_list, D[0], I[0]):
|
||||
r["passage_id"] = int(i)
|
||||
r["score"] = float(sc)
|
||||
return support_doc, res_list
|
||||
|
||||
|
||||
def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
|
||||
a_reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder)
|
||||
D, I = wiki_index.search(a_reps, n_results)
|
||||
res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
|
||||
support_doc_lst = [
|
||||
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
|
||||
]
|
||||
all_res_lists = []
|
||||
for (res_passages, dl, il) in zip(res_passages_lst, D, I):
|
||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||
for r, sc, i in zip(res_list, dl, il):
|
||||
r["passage_id"] = int(i)
|
||||
r["score"] = float(sc)
|
||||
all_res_lists += [res_list[:]]
|
||||
return support_doc_lst, all_res_lists
|
||||
4
examples/research_projects/longform-qa/requirements.txt
Normal file
4
examples/research_projects/longform-qa/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
datasets >= 1.1.3
|
||||
faiss-cpu
|
||||
streamlit
|
||||
elasticsearch
|
||||
23
examples/research_projects/mm-imdb/README.md
Normal file
23
examples/research_projects/mm-imdb/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
## MM-IMDb
|
||||
|
||||
Based on the script [`run_mmimdb.py`](https://github.com/huggingface/transformers/blob/master/examples/contrib/mm-imdb/run_mmimdb.py).
|
||||
|
||||
[MM-IMDb](http://lisi1.unal.edu.co/mmimdb/) is a Multimodal dataset with around 26,000 movies including images, plots and other metadata.
|
||||
|
||||
### Training on MM-IMDb
|
||||
|
||||
```
|
||||
python run_mmimdb.py \
|
||||
--data_dir /path/to/mmimdb/dataset/ \
|
||||
--model_type bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--output_dir /path/to/save/dir/ \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--max_seq_len 512 \
|
||||
--gradient_accumulation_steps 20 \
|
||||
--num_image_embeds 3 \
|
||||
--num_train_epochs 100 \
|
||||
--patience 5
|
||||
```
|
||||
|
||||
572
examples/research_projects/mm-imdb/run_mmimdb.py
Normal file
572
examples/research_projects/mm-imdb/run_mmimdb.py
Normal file
@@ -0,0 +1,572 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Copyright (c) HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning the library models for multimodal multiclass prediction on MM-IMDB dataset."""
|
||||
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from sklearn.metrics import f1_score
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoTokenizer,
|
||||
MMBTConfig,
|
||||
MMBTForClassification,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, criterion):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
sampler=train_sampler,
|
||||
batch_size=args.train_batch_size,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=args.num_workers,
|
||||
)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
best_f1, n_no_improve = 0, 0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
set_seed(args) # Added here for reproductibility
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
labels = batch[5]
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"input_modal": batch[2],
|
||||
"attention_mask": batch[1],
|
||||
"modal_start_tokens": batch[3],
|
||||
"modal_end_tokens": batch[4],
|
||||
}
|
||||
outputs = model(**inputs)
|
||||
logits = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
loss = criterion(logits, labels)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
logs = {}
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer, criterion)
|
||||
for key, value in results.items():
|
||||
eval_key = "eval_{}".format(key)
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||
learning_rate_scalar = scheduler.get_lr()[0]
|
||||
logs["learning_rate"] = learning_rate_scalar
|
||||
logs["loss"] = loss_scalar
|
||||
logging_loss = tr_loss
|
||||
|
||||
for key, value in logs.items():
|
||||
tb_writer.add_scalar(key, value, global_step)
|
||||
print(json.dumps({**logs, **{"step": global_step}}))
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank == -1:
|
||||
results = evaluate(args, model, tokenizer, criterion)
|
||||
if results["micro_f1"] > best_f1:
|
||||
best_f1 = results["micro_f1"]
|
||||
n_no_improve = 0
|
||||
else:
|
||||
n_no_improve += 1
|
||||
|
||||
if n_no_improve > args.patience:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, criterion, prefix=""):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_output_dir = args.output_dir
|
||||
eval_dataset = load_examples(args, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
labels = batch[5]
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"input_modal": batch[2],
|
||||
"attention_mask": batch[1],
|
||||
"modal_start_tokens": batch[3],
|
||||
"modal_end_tokens": batch[4],
|
||||
}
|
||||
outputs = model(**inputs)
|
||||
logits = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
tmp_eval_loss = criterion(logits, labels)
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = torch.sigmoid(logits).detach().cpu().numpy() > 0.5
|
||||
out_label_ids = labels.detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, torch.sigmoid(logits).detach().cpu().numpy() > 0.5, axis=0)
|
||||
out_label_ids = np.append(out_label_ids, labels.detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
result = {
|
||||
"loss": eval_loss,
|
||||
"macro_f1": f1_score(out_label_ids, preds, average="macro"),
|
||||
"micro_f1": f1_score(out_label_ids, preds, average="micro"),
|
||||
}
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def load_examples(args, tokenizer, evaluate=False):
|
||||
path = os.path.join(args.data_dir, "dev.jsonl" if evaluate else "train.jsonl")
|
||||
transforms = get_image_transforms()
|
||||
labels = get_mmimdb_labels()
|
||||
dataset = JsonlDataset(path, tokenizer, transforms, labels, args.max_seq_length - args.num_image_embeds - 2)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .jsonl files for MMIMDB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument("--patience", default=5, type=int, help="Patience for Early Stopping.")
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument("--num_workers", type=int, default=8, help="number of worker threads for dataloading")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
# Setup model
|
||||
labels = get_mmimdb_labels()
|
||||
num_labels = len(labels)
|
||||
transformer_config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
transformer = AutoModel.from_pretrained(
|
||||
args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir
|
||||
)
|
||||
img_encoder = ImageEncoder(args)
|
||||
config = MMBTConfig(transformer_config, num_labels=num_labels)
|
||||
model = MMBTForClassification(config, transformer, img_encoder)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_examples(args, tokenizer, evaluate=False)
|
||||
label_frequences = train_dataset.get_label_frequencies()
|
||||
label_frequences = [label_frequences[l] for l in labels]
|
||||
label_weights = (
|
||||
torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)
|
||||
) ** -1
|
||||
criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = MMBTForClassification(config, transformer, img_encoder)
|
||||
model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME)))
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
model = MMBTForClassification(config, transformer, img_encoder)
|
||||
model.load_state_dict(torch.load(checkpoint))
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
146
examples/research_projects/mm-imdb/utils_mmimdb.py
Normal file
146
examples/research_projects/mm-imdb/utils_mmimdb.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Copyright (c) HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
|
||||
|
||||
|
||||
class ImageEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
model = torchvision.models.resnet152(pretrained=True)
|
||||
modules = list(model.children())[:-2]
|
||||
self.model = nn.Sequential(*modules)
|
||||
self.pool = nn.AdaptiveAvgPool2d(POOLING_BREAKDOWN[args.num_image_embeds])
|
||||
|
||||
def forward(self, x):
|
||||
# Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048
|
||||
out = self.pool(self.model(x))
|
||||
out = torch.flatten(out, start_dim=2)
|
||||
out = out.transpose(1, 2).contiguous()
|
||||
return out # BxNx2048
|
||||
|
||||
|
||||
class JsonlDataset(Dataset):
|
||||
def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
|
||||
self.data = [json.loads(l) for l in open(data_path)]
|
||||
self.data_dir = os.path.dirname(data_path)
|
||||
self.tokenizer = tokenizer
|
||||
self.labels = labels
|
||||
self.n_classes = len(labels)
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
self.transforms = transforms
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
|
||||
start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
|
||||
sentence = sentence[: self.max_seq_length]
|
||||
|
||||
label = torch.zeros(self.n_classes)
|
||||
label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1
|
||||
|
||||
image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
|
||||
image = self.transforms(image)
|
||||
|
||||
return {
|
||||
"image_start_token": start_token,
|
||||
"image_end_token": end_token,
|
||||
"sentence": sentence,
|
||||
"image": image,
|
||||
"label": label,
|
||||
}
|
||||
|
||||
def get_label_frequencies(self):
|
||||
label_freqs = Counter()
|
||||
for row in self.data:
|
||||
label_freqs.update(row["label"])
|
||||
return label_freqs
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
lens = [len(row["sentence"]) for row in batch]
|
||||
bsz, max_seq_len = len(batch), max(lens)
|
||||
|
||||
mask_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)
|
||||
text_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)
|
||||
|
||||
for i_batch, (input_row, length) in enumerate(zip(batch, lens)):
|
||||
text_tensor[i_batch, :length] = input_row["sentence"]
|
||||
mask_tensor[i_batch, :length] = 1
|
||||
|
||||
img_tensor = torch.stack([row["image"] for row in batch])
|
||||
tgt_tensor = torch.stack([row["label"] for row in batch])
|
||||
img_start_token = torch.stack([row["image_start_token"] for row in batch])
|
||||
img_end_token = torch.stack([row["image_end_token"] for row in batch])
|
||||
|
||||
return text_tensor, mask_tensor, img_tensor, img_start_token, img_end_token, tgt_tensor
|
||||
|
||||
|
||||
def get_mmimdb_labels():
|
||||
return [
|
||||
"Crime",
|
||||
"Drama",
|
||||
"Thriller",
|
||||
"Action",
|
||||
"Comedy",
|
||||
"Romance",
|
||||
"Documentary",
|
||||
"Short",
|
||||
"Mystery",
|
||||
"History",
|
||||
"Family",
|
||||
"Adventure",
|
||||
"Fantasy",
|
||||
"Sci-Fi",
|
||||
"Western",
|
||||
"Horror",
|
||||
"Sport",
|
||||
"War",
|
||||
"Music",
|
||||
"Musical",
|
||||
"Animation",
|
||||
"Biography",
|
||||
"Film-Noir",
|
||||
]
|
||||
|
||||
|
||||
def get_image_transforms():
|
||||
return transforms.Compose(
|
||||
[
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.46777044, 0.44531429, 0.40661017],
|
||||
std=[0.12221994, 0.12145835, 0.14380469],
|
||||
),
|
||||
]
|
||||
)
|
||||
185
examples/research_projects/movement-pruning/README.md
Normal file
185
examples/research_projects/movement-pruning/README.md
Normal file
@@ -0,0 +1,185 @@
|
||||
# Movement Pruning: Adaptive Sparsity by Fine-Tuning
|
||||
|
||||
Author: @VictorSanh
|
||||
|
||||
*Magnitude pruning is a widely used strategy for reducing model size in pure supervised learning; however, it is less effective in the transfer learning regime that has become standard for state-of-the-art natural language processing applications. We propose the use of *movement pruning*, a simple, deterministic first-order weight pruning method that is more adaptive to pretrained model fine-tuning. Experiments show that when pruning large pretrained language models, movement pruning shows significant improvements in high-sparsity regimes. When combined with distillation, the approach achieves minimal accuracy loss with down to only 3% of the model parameters:*
|
||||
|
||||
| Fine-pruning+Distillation<br>(Teacher=BERT-base fine-tuned) | BERT base<br>fine-tuned | Remaining<br>Weights (%) | Magnitude Pruning | L0 Regularization | Movement Pruning | Soft Movement Pruning |
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
||||
| SQuAD - Dev<br>EM/F1 | 80.4/88.1 | 10%<br>3% | 70.2/80.1<br>45.5/59.6 | 72.4/81.9<br>64.3/75.8 | 75.6/84.3<br>67.5/78.0 | **76.6/84.9**<br>**72.7/82.3** |
|
||||
| MNLI - Dev<br>acc/MM acc | 84.5/84.9 | 10%<br>3% | 78.3/79.3<br>69.4/70.6 | 78.7/79.7<br>76.0/76.2 | 80.1/80.4<br>76.5/77.4 | **81.2/81.8**<br>**79.5/80.1** |
|
||||
| QQP - Dev<br>acc/F1 | 91.4/88.4 | 10%<br>3% | 79.8/65.0<br>72.4/57.8 | 88.1/82.8<br>87.0/81.9 | 89.7/86.2<br>86.1/81.5 | **90.2/86.8**<br>**89.1/85.5** |
|
||||
|
||||
This page contains information on how to fine-prune pre-trained models such as `BERT` to obtain extremely sparse models with movement pruning. In contrast to magnitude pruning which selects weights that are far from 0, movement pruning retains weights that are moving away from 0.
|
||||
|
||||
For more information, we invite you to check out [our paper](https://arxiv.org/abs/2005.07683).
|
||||
You can also have a look at this fun *Explain Like I'm Five* introductory [slide deck](https://www.slideshare.net/VictorSanh/movement-pruning-explain-like-im-five-234205241).
|
||||
|
||||
<div align="center">
|
||||
<img src="https://www.seekpng.com/png/detail/166-1669328_how-to-make-emmental-cheese-at-home-icooker.png" width="400">
|
||||
</div>
|
||||
|
||||
## Extreme sparsity and efficient storage
|
||||
|
||||
One promise of extreme pruning is to obtain extremely small models that can be easily sent (and stored) on edge devices. By setting weights to 0., we reduce the amount of information we need to store, and thus decreasing the memory size. We are able to obtain extremely sparse fine-pruned models with movement pruning: ~95% of the dense performance with ~5% of total remaining weights in the BERT encoder.
|
||||
|
||||
In [this notebook](https://github.com/huggingface/transformers/blob/master/examples/movement-pruning/Saving_PruneBERT.ipynb), we showcase how we can leverage standard tools that exist out-of-the-box to efficiently store an extremely sparse question answering model (only 6% of total remaining weights in the encoder). We are able to reduce the memory size of the encoder **from the 340MB (the original dense BERT) to 11MB**, without any additional training of the model (every operation is performed *post fine-pruning*). It is sufficiently small to store it on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical) 📎!
|
||||
|
||||
While movement pruning does not directly optimize for memory footprint (but rather the number of non-null weights), we hypothetize that further memory compression ratios can be achieved with specific quantization aware trainings (see for instance [Q8BERT](https://arxiv.org/abs/1910.06188), [And the Bit Goes Down](https://arxiv.org/abs/1907.05686) or [Quant-Noise](https://arxiv.org/abs/2004.07320)).
|
||||
|
||||
## Fine-pruned models
|
||||
|
||||
As examples, we release two English PruneBERT checkpoints (models fine-pruned from a pre-trained `BERT` checkpoint), one on SQuAD and the other on MNLI.
|
||||
|
||||
- **`prunebert-base-uncased-6-finepruned-w-distil-squad`**<br/>
|
||||
Pre-trained `BERT-base-uncased` fine-pruned with soft movement pruning on SQuAD v1.1. We use an additional distillation signal from `BERT-base-uncased` finetuned on SQuAD. The encoder counts 6% of total non-null weights and reaches 83.8 F1 score. The model can be accessed with: `pruned_bert = BertForQuestionAnswering.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad")`
|
||||
- **`prunebert-base-uncased-6-finepruned-w-distil-mnli`**<br/>
|
||||
Pre-trained `BERT-base-uncased` fine-pruned with soft movement pruning on MNLI. We use an additional distillation signal from `BERT-base-uncased` finetuned on MNLI. The encoder counts 6% of total non-null weights and reaches 80.7 (matched) accuracy. The model can be accessed with: `pruned_bert = BertForSequenceClassification.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-mnli")`
|
||||
|
||||
## How to fine-prune?
|
||||
|
||||
### Setup
|
||||
|
||||
The code relies on the 🤗 Transformers library. In addition to the dependencies listed in the [`examples`](https://github.com/huggingface/transformers/tree/master/examples) folder, you should install a few additional dependencies listed in the `requirements.txt` file: `pip install -r requirements.txt`.
|
||||
|
||||
Note that we built our experiments on top of a stabilized version of the library (commit https://github.com/huggingface/transformers/commit/352d5472b0c1dec0f420d606d16747d851b4bda8): we do not guarantee that everything is still compatible with the latest version of the master branch.
|
||||
|
||||
### Fine-pruning with movement pruning
|
||||
|
||||
Below, we detail how to reproduce the results reported in the paper. We use SQuAD as a running example. Commands (and scripts) can be easily adapted for other tasks.
|
||||
|
||||
The following command fine-prunes a pre-trained `BERT-base` on SQuAD using movement pruning towards 15% of remaining weights (85% sparsity). Note that we freeze all the embeddings modules (from their pre-trained value) and only prune the Fully Connected layers in the encoder (12 layers of Transformer Block).
|
||||
|
||||
```bash
|
||||
SERIALIZATION_DIR=<OUTPUT_DIR>
|
||||
SQUAD_DATA=<SQUAD_DATA>
|
||||
|
||||
python examples/movement-pruning/masked_run_squad.py \
|
||||
--output_dir $SERIALIZATION_DIR \
|
||||
--data_dir $SQUAD_DATA \
|
||||
--train_file train-v1.1.json \
|
||||
--predict_file dev-v1.1.json \
|
||||
--do_train --do_eval --do_lower_case \
|
||||
--model_type masked_bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--per_gpu_train_batch_size 16 \
|
||||
--warmup_steps 5400 \
|
||||
--num_train_epochs 10 \
|
||||
--learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
|
||||
--initial_threshold 1 --final_threshold 0.15 \
|
||||
--initial_warmup 1 --final_warmup 2 \
|
||||
--pruning_method topK --mask_init constant --mask_scale 0.
|
||||
```
|
||||
|
||||
### Fine-pruning with other methods
|
||||
|
||||
We can also explore other fine-pruning methods by changing the `pruning_method` parameter:
|
||||
|
||||
Soft movement pruning
|
||||
```bash
|
||||
python examples/movement-pruning/masked_run_squad.py \
|
||||
--output_dir $SERIALIZATION_DIR \
|
||||
--data_dir $SQUAD_DATA \
|
||||
--train_file train-v1.1.json \
|
||||
--predict_file dev-v1.1.json \
|
||||
--do_train --do_eval --do_lower_case \
|
||||
--model_type masked_bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--per_gpu_train_batch_size 16 \
|
||||
--warmup_steps 5400 \
|
||||
--num_train_epochs 10 \
|
||||
--learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
|
||||
--initial_threshold 0 --final_threshold 0.1 \
|
||||
--initial_warmup 1 --final_warmup 2 \
|
||||
--pruning_method sigmoied_threshold --mask_init constant --mask_scale 0. \
|
||||
--regularization l1 --final_lambda 400.
|
||||
```
|
||||
|
||||
L0 regularization
|
||||
```bash
|
||||
python examples/movement-pruning/masked_run_squad.py \
|
||||
--output_dir $SERIALIZATION_DIR \
|
||||
--data_dir $SQUAD_DATA \
|
||||
--train_file train-v1.1.json \
|
||||
--predict_file dev-v1.1.json \
|
||||
--do_train --do_eval --do_lower_case \
|
||||
--model_type masked_bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--per_gpu_train_batch_size 16 \
|
||||
--warmup_steps 5400 \
|
||||
--num_train_epochs 10 \
|
||||
--learning_rate 3e-5 --mask_scores_learning_rate 1e-1 \
|
||||
--initial_threshold 1. --final_threshold 1. \
|
||||
--initial_warmup 1 --final_warmup 1 \
|
||||
--pruning_method l0 --mask_init constant --mask_scale 2.197 \
|
||||
--regularization l0 --final_lambda 125.
|
||||
```
|
||||
|
||||
Iterative Magnitude Pruning
|
||||
```bash
|
||||
python examples/movement-pruning/masked_run_squad.py \
|
||||
--output_dir ./dbg \
|
||||
--data_dir examples/distillation/data/squad_data \
|
||||
--train_file train-v1.1.json \
|
||||
--predict_file dev-v1.1.json \
|
||||
--do_train --do_eval --do_lower_case \
|
||||
--model_type masked_bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--per_gpu_train_batch_size 16 \
|
||||
--warmup_steps 5400 \
|
||||
--num_train_epochs 10 \
|
||||
--learning_rate 3e-5 \
|
||||
--initial_threshold 1 --final_threshold 0.15 \
|
||||
--initial_warmup 1 --final_warmup 2 \
|
||||
--pruning_method magnitude
|
||||
```
|
||||
|
||||
### After fine-pruning
|
||||
|
||||
**Counting parameters**
|
||||
|
||||
Regularization based pruning methods (soft movement pruning and L0 regularization) rely on the penalty to induce sparsity. The multiplicative coefficient controls the sparsity level.
|
||||
To obtain the effective sparsity level in the encoder, we simply count the number of activated (non-null) weights:
|
||||
|
||||
```bash
|
||||
python examples/movement-pruning/counts_parameters.py \
|
||||
--pruning_method sigmoied_threshold \
|
||||
--threshold 0.1 \
|
||||
--serialization_dir $SERIALIZATION_DIR
|
||||
```
|
||||
|
||||
**Pruning once for all**
|
||||
|
||||
Once the model has been fine-pruned, the pruned weights can be set to 0. once for all (reducing the amount of information to store). In our running experiments, we can convert a `MaskedBertForQuestionAnswering` (a BERT model augmented to enable on-the-fly pruning capabilities) to a standard `BertForQuestionAnswering`:
|
||||
|
||||
```bash
|
||||
python examples/movement-pruning/bertarize.py \
|
||||
--pruning_method sigmoied_threshold \
|
||||
--threshold 0.1 \
|
||||
--model_name_or_path $SERIALIZATION_DIR
|
||||
```
|
||||
|
||||
## Hyper-parameters
|
||||
|
||||
For reproducibility purposes, we share the detailed results presented in the paper. These [tables](https://docs.google.com/spreadsheets/d/17JgRq_OFFTniUrz6BZWW_87DjFkKXpI1kYDSsseT_7g/edit?usp=sharing) exhaustively describe the individual hyper-parameters used for each data point.
|
||||
|
||||
## Inference speed
|
||||
|
||||
Early experiments show that even though models fine-pruned with (soft) movement pruning are extremely sparse, they do not benefit from significant improvement in terms of inference speed when using the standard PyTorch inference.
|
||||
We are currently benchmarking and exploring inference setups specifically for sparse architectures.
|
||||
In particular, hardware manufacturers are announcing devices that will speedup inference for sparse networks considerably.
|
||||
|
||||
## Citation
|
||||
|
||||
If you find this resource useful, please consider citing the following paper:
|
||||
|
||||
```
|
||||
@article{sanh2020movement,
|
||||
title={Movement Pruning: Adaptive Sparsity by Fine-Tuning},
|
||||
author={Victor Sanh and Thomas Wolf and Alexander M. Rush},
|
||||
year={2020},
|
||||
eprint={2005.07683},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,634 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Saving PruneBERT\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"This notebook aims at showcasing how we can leverage standard tools to save (and load) an extremely sparse model fine-pruned with [movement pruning](https://arxiv.org/abs/2005.07683) (or any other unstructured pruning mehtod).\n",
|
||||
"\n",
|
||||
"In this example, we used BERT (base-uncased, but the procedure described here is not specific to BERT and can be applied to a large variety of models.\n",
|
||||
"\n",
|
||||
"We first obtain an extremely sparse model by fine-pruning with movement pruning on SQuAD v1.1. We then used the following combination of standard tools:\n",
|
||||
"- We reduce the precision of the model with Int8 dynamic quantization using [PyTorch implementation](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html). We only quantized the Fully Connected Layers.\n",
|
||||
"- Sparse quantized matrices are converted into the [Compressed Sparse Row format](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html).\n",
|
||||
"- We use HDF5 with `gzip` compression to store the weights.\n",
|
||||
"\n",
|
||||
"We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical)!\n",
|
||||
"\n",
|
||||
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Floptical_disk_21MB.jpg/440px-Floptical_disk_21MB.jpg\" width=\"200\">\n",
|
||||
"\n",
|
||||
"*Note: this notebook is compatible with `torch>=1.5.0` If you are using, `torch==1.4.0`, please refer to [this previous version of the notebook](https://github.com/huggingface/transformers/commit/b11386e158e86e62d4041eabd86d044cd1695737).*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Includes\n",
|
||||
"\n",
|
||||
"import h5py\n",
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"\n",
|
||||
"from scipy import sparse\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"\n",
|
||||
"from transformers import *\n",
|
||||
"\n",
|
||||
"os.chdir('../../')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Saving"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Dynamic quantization induces little or no loss of performance while significantly reducing the memory footprint."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load fine-pruned model and quantize the model\n",
|
||||
"\n",
|
||||
"model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
|
||||
"model.to('cpu')\n",
|
||||
"\n",
|
||||
"quantized_model = torch.quantization.quantize_dynamic(\n",
|
||||
" model=model,\n",
|
||||
" qconfig_spec = {\n",
|
||||
" torch.nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
|
||||
" },\n",
|
||||
" dtype=torch.qint8,\n",
|
||||
" )\n",
|
||||
"# print(quantized_model)\n",
|
||||
"\n",
|
||||
"qtz_st = quantized_model.state_dict()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Saving the original (encoder + classifier) in the standard torch.save format\n",
|
||||
"\n",
|
||||
"dense_st = {name: param for name, param in model.state_dict().items() \n",
|
||||
" if \"embedding\" not in name and \"pooler\" not in name}\n",
|
||||
"torch.save(dense_st, 'dbg/dense_squad.pt',)\n",
|
||||
"dense_mb_size = os.path.getsize(\"dbg/dense_squad.pt\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Decompose quantization for bert.encoder.layer.0.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.pooler.dense._packed_params.weight\n",
|
||||
"Decompose quantization for qa_outputs._packed_params.weight\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Elementary representation: we decompose the quantized tensors into (scale, zero_point, int_repr).\n",
|
||||
"# See https://pytorch.org/docs/stable/quantization.html\n",
|
||||
"\n",
|
||||
"# We further leverage the fact that int_repr is sparse matrix to optimize the storage: we decompose int_repr into\n",
|
||||
"# its CSR representation (data, indptr, indices).\n",
|
||||
"\n",
|
||||
"elementary_qtz_st = {}\n",
|
||||
"for name, param in qtz_st.items():\n",
|
||||
" if \"dtype\" not in name and param.is_quantized:\n",
|
||||
" print(\"Decompose quantization for\", name)\n",
|
||||
" # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
|
||||
" scale = param.q_scale() # torch.tensor(1,) - float32\n",
|
||||
" zero_point = param.q_zero_point() # torch.tensor(1,) - int32\n",
|
||||
" elementary_qtz_st[f\"{name}.scale\"] = scale\n",
|
||||
" elementary_qtz_st[f\"{name}.zero_point\"] = zero_point\n",
|
||||
"\n",
|
||||
" # We assume the int_repr is sparse and compute its CSR representation\n",
|
||||
" # Only the FCs in the encoder are actually sparse\n",
|
||||
" int_repr = param.int_repr() # torch.tensor(nb_rows, nb_columns) - int8\n",
|
||||
" int_repr_cs = sparse.csr_matrix(int_repr) # scipy.sparse.csr.csr_matrix\n",
|
||||
"\n",
|
||||
" elementary_qtz_st[f\"{name}.int_repr.data\"] = int_repr_cs.data # np.array int8\n",
|
||||
" elementary_qtz_st[f\"{name}.int_repr.indptr\"] = int_repr_cs.indptr # np.array int32\n",
|
||||
" assert max(int_repr_cs.indices) < 65535 # If not, we shall fall back to int32\n",
|
||||
" elementary_qtz_st[f\"{name}.int_repr.indices\"] = np.uint16(int_repr_cs.indices) # np.array uint16\n",
|
||||
" elementary_qtz_st[f\"{name}.int_repr.shape\"] = int_repr_cs.shape # tuple(int, int)\n",
|
||||
" else:\n",
|
||||
" elementary_qtz_st[name] = param\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create mapping from torch.dtype to string description (we could also used an int8 instead of string)\n",
|
||||
"str_2_dtype = {\"qint8\": torch.qint8}\n",
|
||||
"dtype_2_str = {torch.qint8: \"qint8\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Encoder Size (MB) - Sparse & Quantized - `torch.save`: 21.29\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Saving the pruned (encoder + classifier) in the standard torch.save format\n",
|
||||
"\n",
|
||||
"dense_optimized_st = {name: param for name, param in elementary_qtz_st.items() \n",
|
||||
" if \"embedding\" not in name and \"pooler\" not in name}\n",
|
||||
"torch.save(dense_optimized_st, 'dbg/dense_squad_optimized.pt',)\n",
|
||||
"print(\"Encoder Size (MB) - Sparse & Quantized - `torch.save`:\",\n",
|
||||
" round(os.path.getsize(\"dbg/dense_squad_optimized.pt\")/1e6, 2))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Skip bert.embeddings.word_embeddings.weight\n",
|
||||
"Skip bert.embeddings.position_embeddings.weight\n",
|
||||
"Skip bert.embeddings.token_type_embeddings.weight\n",
|
||||
"Skip bert.embeddings.LayerNorm.weight\n",
|
||||
"Skip bert.embeddings.LayerNorm.bias\n",
|
||||
"Skip bert.pooler.dense.scale\n",
|
||||
"Skip bert.pooler.dense.zero_point\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.scale\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.zero_point\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.int_repr.data\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.int_repr.indptr\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n",
|
||||
"Skip bert.pooler.dense._packed_params.bias\n",
|
||||
"Skip bert.pooler.dense._packed_params.dtype\n",
|
||||
"\n",
|
||||
"Encoder Size (MB) - Dense: 340.26\n",
|
||||
"Encoder Size (MB) - Sparse & Quantized: 11.28\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Save the decomposed state_dict with an HDF5 file\n",
|
||||
"# Saving only the encoder + QA Head\n",
|
||||
"\n",
|
||||
"with h5py.File('dbg/squad_sparse.h5','w') as hf:\n",
|
||||
" for name, param in elementary_qtz_st.items():\n",
|
||||
" if \"embedding\" in name:\n",
|
||||
" print(f\"Skip {name}\")\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" if \"pooler\" in name:\n",
|
||||
" print(f\"Skip {name}\")\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" if type(param) == torch.Tensor:\n",
|
||||
" if param.numel() == 1:\n",
|
||||
" # module scale\n",
|
||||
" # module zero_point\n",
|
||||
" hf.attrs[name] = param\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" if param.requires_grad:\n",
|
||||
" # LayerNorm\n",
|
||||
" param = param.detach().numpy()\n",
|
||||
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
|
||||
"\n",
|
||||
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
|
||||
" # float - tensor _packed_params.weight.scale\n",
|
||||
" # int - tensor _packed_params.weight.zero_point\n",
|
||||
" # tuple - tensor _packed_params.weight.shape\n",
|
||||
" hf.attrs[name] = param\n",
|
||||
"\n",
|
||||
" elif type(param) == torch.dtype:\n",
|
||||
" # dtype - tensor _packed_params.dtype\n",
|
||||
" hf.attrs[name] = dtype_2_str[param]\n",
|
||||
" \n",
|
||||
" else:\n",
|
||||
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"with open('dbg/metadata.json', 'w') as f:\n",
|
||||
" f.write(json.dumps(qtz_st._metadata)) \n",
|
||||
"\n",
|
||||
"size = os.path.getsize(\"dbg/squad_sparse.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
|
||||
"print(\"\")\n",
|
||||
"print(\"Encoder Size (MB) - Dense: \", round(dense_mb_size/1e6, 2))\n",
|
||||
"print(\"Encoder Size (MB) - Sparse & Quantized:\", round(size/1e6, 2))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Size (MB): 99.41\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Save the decomposed state_dict to HDF5 storage\n",
|
||||
"# Save everything in the architecutre (embedding + encoder + QA Head)\n",
|
||||
"\n",
|
||||
"with h5py.File('dbg/squad_sparse_with_embs.h5','w') as hf:\n",
|
||||
" for name, param in elementary_qtz_st.items():\n",
|
||||
"# if \"embedding\" in name:\n",
|
||||
"# print(f\"Skip {name}\")\n",
|
||||
"# continue\n",
|
||||
"\n",
|
||||
"# if \"pooler\" in name:\n",
|
||||
"# print(f\"Skip {name}\")\n",
|
||||
"# continue\n",
|
||||
"\n",
|
||||
" if type(param) == torch.Tensor:\n",
|
||||
" if param.numel() == 1:\n",
|
||||
" # module scale\n",
|
||||
" # module zero_point\n",
|
||||
" hf.attrs[name] = param\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" if param.requires_grad:\n",
|
||||
" # LayerNorm\n",
|
||||
" param = param.detach().numpy()\n",
|
||||
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
|
||||
"\n",
|
||||
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
|
||||
" # float - tensor _packed_params.weight.scale\n",
|
||||
" # int - tensor _packed_params.weight.zero_point\n",
|
||||
" # tuple - tensor _packed_params.weight.shape\n",
|
||||
" hf.attrs[name] = param\n",
|
||||
"\n",
|
||||
" elif type(param) == torch.dtype:\n",
|
||||
" # dtype - tensor _packed_params.dtype\n",
|
||||
" hf.attrs[name] = dtype_2_str[param]\n",
|
||||
" \n",
|
||||
" else:\n",
|
||||
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"with open('dbg/metadata.json', 'w') as f:\n",
|
||||
" f.write(json.dumps(qtz_st._metadata)) \n",
|
||||
"\n",
|
||||
"size = os.path.getsize(\"dbg/squad_sparse_with_embs.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
|
||||
"print('\\nSize (MB):', round(size/1e6, 2))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Loading"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Reconstruct the elementary state dict\n",
|
||||
"\n",
|
||||
"reconstructed_elementary_qtz_st = {}\n",
|
||||
"\n",
|
||||
"hf = h5py.File('dbg/squad_sparse_with_embs.h5','r')\n",
|
||||
"\n",
|
||||
"for attr_name, attr_param in hf.attrs.items():\n",
|
||||
" if 'shape' in attr_name:\n",
|
||||
" attr_param = tuple(attr_param)\n",
|
||||
" elif \".scale\" in attr_name:\n",
|
||||
" if \"_packed_params\" in attr_name:\n",
|
||||
" attr_param = float(attr_param)\n",
|
||||
" else:\n",
|
||||
" attr_param = torch.tensor(attr_param)\n",
|
||||
" elif \".zero_point\" in attr_name:\n",
|
||||
" if \"_packed_params\" in attr_name:\n",
|
||||
" attr_param = int(attr_param)\n",
|
||||
" else:\n",
|
||||
" attr_param = torch.tensor(attr_param)\n",
|
||||
" elif \".dtype\" in attr_name:\n",
|
||||
" attr_param = str_2_dtype[attr_param]\n",
|
||||
" reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
|
||||
" # print(f\"Unpack {attr_name}\")\n",
|
||||
" \n",
|
||||
"# Get the tensors/arrays\n",
|
||||
"for data_name, data_param in hf.items():\n",
|
||||
" if \"LayerNorm\" in data_name or \"_packed_params.bias\" in data_name:\n",
|
||||
" reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
|
||||
" elif \"embedding\" in data_name:\n",
|
||||
" reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
|
||||
" else: # _packed_params.weight.int_repr.data, _packed_params.weight.int_repr.indices and _packed_params.weight.int_repr.indptr\n",
|
||||
" data_param = np.array(data_param)\n",
|
||||
" if \"indices\" in data_name:\n",
|
||||
" data_param = np.array(data_param, dtype=np.int32)\n",
|
||||
" reconstructed_elementary_qtz_st[data_name] = data_param\n",
|
||||
" # print(f\"Unpack {data_name}\")\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"hf.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Sanity checks\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_elementary_qtz_st.items():\n",
|
||||
" assert name in elementary_qtz_st\n",
|
||||
"for name, param in elementary_qtz_st.items():\n",
|
||||
" assert name in reconstructed_elementary_qtz_st, name\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_elementary_qtz_st.items():\n",
|
||||
" assert type(param) == type(elementary_qtz_st[name]), name\n",
|
||||
" if type(param) == torch.Tensor:\n",
|
||||
" assert torch.all(torch.eq(param, elementary_qtz_st[name])), name\n",
|
||||
" elif type(param) == np.ndarray:\n",
|
||||
" assert (param == elementary_qtz_st[name]).all(), name\n",
|
||||
" else:\n",
|
||||
" assert param == elementary_qtz_st[name], name"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Re-assemble the sparse int_repr from the CSR format\n",
|
||||
"\n",
|
||||
"reconstructed_qtz_st = {}\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_elementary_qtz_st.items():\n",
|
||||
" if \"weight.int_repr.indptr\" in name:\n",
|
||||
" prefix_ = name[:-16]\n",
|
||||
" data = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.data\"]\n",
|
||||
" indptr = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indptr\"]\n",
|
||||
" indices = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indices\"]\n",
|
||||
" shape = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.shape\"]\n",
|
||||
"\n",
|
||||
" int_repr = sparse.csr_matrix(arg1=(data, indices, indptr),\n",
|
||||
" shape=shape)\n",
|
||||
" int_repr = torch.tensor(int_repr.todense())\n",
|
||||
"\n",
|
||||
" scale = reconstructed_elementary_qtz_st[f\"{prefix_}.scale\"]\n",
|
||||
" zero_point = reconstructed_elementary_qtz_st[f\"{prefix_}.zero_point\"]\n",
|
||||
" weight = torch._make_per_tensor_quantized_tensor(int_repr,\n",
|
||||
" scale,\n",
|
||||
" zero_point)\n",
|
||||
"\n",
|
||||
" reconstructed_qtz_st[f\"{prefix_}\"] = weight\n",
|
||||
" elif \"int_repr.data\" in name or \"int_repr.shape\" in name or \"int_repr.indices\" in name or \\\n",
|
||||
" \"weight.scale\" in name or \"weight.zero_point\" in name:\n",
|
||||
" continue\n",
|
||||
" else:\n",
|
||||
" reconstructed_qtz_st[name] = param\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Sanity checks\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_qtz_st.items():\n",
|
||||
" assert name in qtz_st\n",
|
||||
"for name, param in qtz_st.items():\n",
|
||||
" assert name in reconstructed_qtz_st, name\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_qtz_st.items():\n",
|
||||
" assert type(param) == type(qtz_st[name]), name\n",
|
||||
" if type(param) == torch.Tensor:\n",
|
||||
" assert torch.all(torch.eq(param, qtz_st[name])), name\n",
|
||||
" elif type(param) == np.ndarray:\n",
|
||||
" assert (param == qtz_st[name]).all(), name\n",
|
||||
" else:\n",
|
||||
" assert param == qtz_st[name], name"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Sanity checks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<All keys matched successfully>"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Load the re-constructed state dict into a model\n",
|
||||
"\n",
|
||||
"dummy_model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')\n",
|
||||
"dummy_model.to('cpu')\n",
|
||||
"\n",
|
||||
"reconstructed_qtz_model = torch.quantization.quantize_dynamic(\n",
|
||||
" model=dummy_model,\n",
|
||||
" qconfig_spec = None,\n",
|
||||
" dtype=torch.qint8,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"reconstructed_qtz_st = OrderedDict(reconstructed_qtz_st)\n",
|
||||
"with open('dbg/metadata.json', 'r') as read_file:\n",
|
||||
" metadata = json.loads(read_file.read())\n",
|
||||
"reconstructed_qtz_st._metadata = metadata\n",
|
||||
"\n",
|
||||
"reconstructed_qtz_model.load_state_dict(reconstructed_qtz_st)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sanity check passed\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Sanity checks on the infernce\n",
|
||||
"\n",
|
||||
"N = 32\n",
|
||||
"\n",
|
||||
"for _ in range(25):\n",
|
||||
" inputs = torch.randint(low=0, high=30000, size=(N, 128))\n",
|
||||
" mask = torch.ones(size=(N, 128))\n",
|
||||
"\n",
|
||||
" y_reconstructed = reconstructed_qtz_model(input_ids=inputs, attention_mask=mask)[0]\n",
|
||||
" y = quantized_model(input_ids=inputs, attention_mask=mask)[0]\n",
|
||||
" \n",
|
||||
" assert torch.all(torch.eq(y, y_reconstructed))\n",
|
||||
"print(\"Sanity check passed\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
132
examples/research_projects/movement-pruning/bertarize.py
Normal file
132
examples/research_projects/movement-pruning/bertarize.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright 2020-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once for all.
|
||||
For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceClassification` is trained, it can be saved (and then loaded)
|
||||
as a standard :class:`~transformers.BertForSequenceClassification`.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
def main(args):
|
||||
pruning_method = args.pruning_method
|
||||
threshold = args.threshold
|
||||
|
||||
model_name_or_path = args.model_name_or_path.rstrip("/")
|
||||
target_model_path = args.target_model_path
|
||||
|
||||
print(f"Load fine-pruned model from {model_name_or_path}")
|
||||
model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
|
||||
pruned_model = {}
|
||||
|
||||
for name, tensor in model.items():
|
||||
if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
|
||||
pruned_model[name] = tensor
|
||||
print(f"Copied layer {name}")
|
||||
elif "classifier" in name or "qa_output" in name:
|
||||
pruned_model[name] = tensor
|
||||
print(f"Copied layer {name}")
|
||||
elif "bias" in name:
|
||||
pruned_model[name] = tensor
|
||||
print(f"Copied layer {name}")
|
||||
else:
|
||||
if pruning_method == "magnitude":
|
||||
mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
|
||||
pruned_model[name] = tensor * mask
|
||||
print(f"Pruned layer {name}")
|
||||
elif pruning_method == "topK":
|
||||
if "mask_scores" in name:
|
||||
continue
|
||||
prefix_ = name[:-6]
|
||||
scores = model[f"{prefix_}mask_scores"]
|
||||
mask = TopKBinarizer.apply(scores, threshold)
|
||||
pruned_model[name] = tensor * mask
|
||||
print(f"Pruned layer {name}")
|
||||
elif pruning_method == "sigmoied_threshold":
|
||||
if "mask_scores" in name:
|
||||
continue
|
||||
prefix_ = name[:-6]
|
||||
scores = model[f"{prefix_}mask_scores"]
|
||||
mask = ThresholdBinarizer.apply(scores, threshold, True)
|
||||
pruned_model[name] = tensor * mask
|
||||
print(f"Pruned layer {name}")
|
||||
elif pruning_method == "l0":
|
||||
if "mask_scores" in name:
|
||||
continue
|
||||
prefix_ = name[:-6]
|
||||
scores = model[f"{prefix_}mask_scores"]
|
||||
l, r = -0.1, 1.1
|
||||
s = torch.sigmoid(scores)
|
||||
s_bar = s * (r - l) + l
|
||||
mask = s_bar.clamp(min=0.0, max=1.0)
|
||||
pruned_model[name] = tensor * mask
|
||||
print(f"Pruned layer {name}")
|
||||
else:
|
||||
raise ValueError("Unknown pruning method")
|
||||
|
||||
if target_model_path is None:
|
||||
target_model_path = os.path.join(
|
||||
os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}"
|
||||
)
|
||||
|
||||
if not os.path.isdir(target_model_path):
|
||||
shutil.copytree(model_name_or_path, target_model_path)
|
||||
print(f"\nCreated folder {target_model_path}")
|
||||
|
||||
torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
|
||||
print("\nPruned model saved! See you later!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--pruning_method",
|
||||
choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
|
||||
type=str,
|
||||
required=True,
|
||||
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
required=False,
|
||||
help="For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
|
||||
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
|
||||
"Not needed for `l0`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Folder containing the model that was previously fine-pruned",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_model_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Folder containing the model that was previously fine-pruned",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -0,0 +1,92 @@
|
||||
# Copyright 2020-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Count remaining (non-zero) weights in the encoder (i.e. the transformer layers).
|
||||
Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from emmental.modules import ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
def main(args):
|
||||
serialization_dir = args.serialization_dir
|
||||
pruning_method = args.pruning_method
|
||||
threshold = args.threshold
|
||||
|
||||
st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cpu")
|
||||
|
||||
remaining_count = 0 # Number of remaining (not pruned) params in the encoder
|
||||
encoder_count = 0 # Number of params in the encoder
|
||||
|
||||
print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
|
||||
for name, param in st.items():
|
||||
if "encoder" not in name:
|
||||
continue
|
||||
|
||||
if "mask_scores" in name:
|
||||
if pruning_method == "topK":
|
||||
mask_ones = TopKBinarizer.apply(param, threshold).sum().item()
|
||||
elif pruning_method == "sigmoied_threshold":
|
||||
mask_ones = ThresholdBinarizer.apply(param, threshold, True).sum().item()
|
||||
elif pruning_method == "l0":
|
||||
l, r = -0.1, 1.1
|
||||
s = torch.sigmoid(param)
|
||||
s_bar = s * (r - l) + l
|
||||
mask = s_bar.clamp(min=0.0, max=1.0)
|
||||
mask_ones = (mask > 0.0).sum().item()
|
||||
else:
|
||||
raise ValueError("Unknown pruning method")
|
||||
remaining_count += mask_ones
|
||||
print(name.ljust(60, " "), str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "), str(mask_ones))
|
||||
else:
|
||||
encoder_count += param.numel()
|
||||
if "bias" in name or "LayerNorm" in name:
|
||||
remaining_count += param.numel()
|
||||
|
||||
print("")
|
||||
print("Remaining Weights (global) %: ", 100 * remaining_count / encoder_count)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--pruning_method",
|
||||
choices=["l0", "topK", "sigmoied_threshold"],
|
||||
type=str,
|
||||
required=True,
|
||||
help="Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
required=False,
|
||||
help="For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
|
||||
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
|
||||
"Not needed for `l0`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--serialization_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Folder containing the model that was previously fine-pruned",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -0,0 +1,10 @@
|
||||
# flake8: noqa
|
||||
from .configuration_bert_masked import MaskedBertConfig
|
||||
from .modeling_bert_masked import (
|
||||
MaskedBertForMultipleChoice,
|
||||
MaskedBertForQuestionAnswering,
|
||||
MaskedBertForSequenceClassification,
|
||||
MaskedBertForTokenClassification,
|
||||
MaskedBertModel,
|
||||
)
|
||||
from .modules import *
|
||||
@@ -0,0 +1,71 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Masked BERT model configuration. It replicates the class `~transformers.BertConfig`
|
||||
and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init` and `mask_scale`."""
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MaskedBertConfig(PretrainedConfig):
|
||||
"""
|
||||
A class replicating the `~transformers.BertConfig` with additional parameters for pruning/masking configuration.
|
||||
"""
|
||||
|
||||
model_type = "masked_bert"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
pruning_method="topK",
|
||||
mask_init="constant",
|
||||
mask_scale=0.0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.pruning_method = pruning_method
|
||||
self.mask_init = mask_init
|
||||
self.mask_scale = mask_scale
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,3 @@
|
||||
# flake8: noqa
|
||||
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
from .masked_nn import MaskedLinear
|
||||
@@ -0,0 +1,144 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present, AllenAI Authors, University of Illinois Urbana-Champaign,
|
||||
# Intel Nervana Systems and the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Binarizers take a (real value) matrix as input and produce a binary (values in {0,1}) mask of the same shape.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import autograd
|
||||
|
||||
|
||||
class ThresholdBinarizer(autograd.Function):
|
||||
"""
|
||||
Thresholdd binarizer.
|
||||
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j} > \tau`
|
||||
where `\tau` is a real value threshold.
|
||||
|
||||
Implementation is inspired from:
|
||||
https://github.com/arunmallya/piggyback
|
||||
Piggyback: Adapting a Single Network to Multiple Tasks by Learning to Mask Weights
|
||||
Arun Mallya, Dillon Davis, Svetlana Lazebnik
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs: torch.tensor, threshold: float, sigmoid: bool):
|
||||
"""
|
||||
Args:
|
||||
inputs (`torch.FloatTensor`)
|
||||
The input matrix from which the binarizer computes the binary mask.
|
||||
threshold (`float`)
|
||||
The threshold value (in R).
|
||||
sigmoid (`bool`)
|
||||
If set to ``True``, we apply the sigmoid function to the `inputs` matrix before comparing to `threshold`.
|
||||
In this case, `threshold` should be a value between 0 and 1.
|
||||
Returns:
|
||||
mask (`torch.FloatTensor`)
|
||||
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
|
||||
retained, 0 - the associated weight is pruned).
|
||||
"""
|
||||
nb_elems = inputs.numel()
|
||||
nb_min = int(0.005 * nb_elems) + 1
|
||||
if sigmoid:
|
||||
mask = (torch.sigmoid(inputs) > threshold).type(inputs.type())
|
||||
else:
|
||||
mask = (inputs > threshold).type(inputs.type())
|
||||
if mask.sum() < nb_min:
|
||||
# We limit the pruning so that at least 0.5% (half a percent) of the weights are remaining
|
||||
k_threshold = inputs.flatten().kthvalue(max(nb_elems - nb_min, 1)).values
|
||||
mask = (inputs > k_threshold).type(inputs.type())
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gradOutput):
|
||||
return gradOutput, None, None
|
||||
|
||||
|
||||
class TopKBinarizer(autograd.Function):
|
||||
"""
|
||||
Top-k Binarizer.
|
||||
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
|
||||
is among the k% highest values of S.
|
||||
|
||||
Implementation is inspired from:
|
||||
https://github.com/allenai/hidden-networks
|
||||
What's hidden in a randomly weighted neural network?
|
||||
Vivek Ramanujan*, Mitchell Wortsman*, Aniruddha Kembhavi, Ali Farhadi, Mohammad Rastegari
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs: torch.tensor, threshold: float):
|
||||
"""
|
||||
Args:
|
||||
inputs (`torch.FloatTensor`)
|
||||
The input matrix from which the binarizer computes the binary mask.
|
||||
threshold (`float`)
|
||||
The percentage of weights to keep (the rest is pruned).
|
||||
`threshold` is a float between 0 and 1.
|
||||
Returns:
|
||||
mask (`torch.FloatTensor`)
|
||||
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
|
||||
retained, 0 - the associated weight is pruned).
|
||||
"""
|
||||
# Get the subnetwork by sorting the inputs and using the top threshold %
|
||||
mask = inputs.clone()
|
||||
_, idx = inputs.flatten().sort(descending=True)
|
||||
j = int(threshold * inputs.numel())
|
||||
|
||||
# flat_out and mask access the same memory.
|
||||
flat_out = mask.flatten()
|
||||
flat_out[idx[j:]] = 0
|
||||
flat_out[idx[:j]] = 1
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gradOutput):
|
||||
return gradOutput, None
|
||||
|
||||
|
||||
class MagnitudeBinarizer(object):
|
||||
"""
|
||||
Magnitude Binarizer.
|
||||
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
|
||||
is among the k% highest values of |S| (absolute value).
|
||||
|
||||
Implementation is inspired from https://github.com/NervanaSystems/distiller/blob/2291fdcc2ea642a98d4e20629acb5a9e2e04b4e6/distiller/pruning/automated_gradual_pruner.py#L24
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def apply(inputs: torch.tensor, threshold: float):
|
||||
"""
|
||||
Args:
|
||||
inputs (`torch.FloatTensor`)
|
||||
The input matrix from which the binarizer computes the binary mask.
|
||||
This input marix is typically the weight matrix.
|
||||
threshold (`float`)
|
||||
The percentage of weights to keep (the rest is pruned).
|
||||
`threshold` is a float between 0 and 1.
|
||||
Returns:
|
||||
mask (`torch.FloatTensor`)
|
||||
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
|
||||
retained, 0 - the associated weight is pruned).
|
||||
"""
|
||||
# Get the subnetwork by sorting the inputs and using the top threshold %
|
||||
mask = inputs.clone()
|
||||
_, idx = inputs.abs().flatten().sort(descending=True)
|
||||
j = int(threshold * inputs.numel())
|
||||
|
||||
# flat_out and mask access the same memory.
|
||||
flat_out = mask.flatten()
|
||||
flat_out[idx[j:]] = 0
|
||||
flat_out[idx[:j]] = 1
|
||||
return mask
|
||||
@@ -0,0 +1,107 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Masked Linear module: A fully connected layer that computes an adaptive binary mask on the fly.
|
||||
The mask (binary or not) is computed at each forward pass and multiplied against
|
||||
the weight matrix to prune a portion of the weights.
|
||||
The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added).
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import init
|
||||
|
||||
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
class MaskedLinear(nn.Linear):
|
||||
"""
|
||||
Fully Connected layer with on the fly adaptive mask.
|
||||
If needed, a score matrix is created to store the importance of each associated weight.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
mask_init: str = "constant",
|
||||
mask_scale: float = 0.0,
|
||||
pruning_method: str = "topK",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
in_features (`int`)
|
||||
Size of each input sample
|
||||
out_features (`int`)
|
||||
Size of each output sample
|
||||
bias (`bool`)
|
||||
If set to ``False``, the layer will not learn an additive bias.
|
||||
Default: ``True``
|
||||
mask_init (`str`)
|
||||
The initialization method for the score matrix if a score matrix is needed.
|
||||
Choices: ["constant", "uniform", "kaiming"]
|
||||
Default: ``constant``
|
||||
mask_scale (`float`)
|
||||
The initialization parameter for the chosen initialization method `mask_init`.
|
||||
Default: ``0.``
|
||||
pruning_method (`str`)
|
||||
Method to compute the mask.
|
||||
Choices: ["topK", "threshold", "sigmoied_threshold", "magnitude", "l0"]
|
||||
Default: ``topK``
|
||||
"""
|
||||
super(MaskedLinear, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
|
||||
assert pruning_method in ["topK", "threshold", "sigmoied_threshold", "magnitude", "l0"]
|
||||
self.pruning_method = pruning_method
|
||||
|
||||
if self.pruning_method in ["topK", "threshold", "sigmoied_threshold", "l0"]:
|
||||
self.mask_scale = mask_scale
|
||||
self.mask_init = mask_init
|
||||
self.mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
|
||||
self.init_mask()
|
||||
|
||||
def init_mask(self):
|
||||
if self.mask_init == "constant":
|
||||
init.constant_(self.mask_scores, val=self.mask_scale)
|
||||
elif self.mask_init == "uniform":
|
||||
init.uniform_(self.mask_scores, a=-self.mask_scale, b=self.mask_scale)
|
||||
elif self.mask_init == "kaiming":
|
||||
init.kaiming_uniform_(self.mask_scores, a=math.sqrt(5))
|
||||
|
||||
def forward(self, input: torch.tensor, threshold: float):
|
||||
# Get the mask
|
||||
if self.pruning_method == "topK":
|
||||
mask = TopKBinarizer.apply(self.mask_scores, threshold)
|
||||
elif self.pruning_method in ["threshold", "sigmoied_threshold"]:
|
||||
sig = "sigmoied" in self.pruning_method
|
||||
mask = ThresholdBinarizer.apply(self.mask_scores, threshold, sig)
|
||||
elif self.pruning_method == "magnitude":
|
||||
mask = MagnitudeBinarizer.apply(self.weight, threshold)
|
||||
elif self.pruning_method == "l0":
|
||||
l, r, b = -0.1, 1.1, 2 / 3
|
||||
if self.training:
|
||||
u = torch.zeros_like(self.mask_scores).uniform_().clamp(0.0001, 0.9999)
|
||||
s = torch.sigmoid((u.log() - (1 - u).log() + self.mask_scores) / b)
|
||||
else:
|
||||
s = torch.sigmoid(self.mask_scores)
|
||||
s_bar = s * (r - l) + l
|
||||
mask = s_bar.clamp(min=0.0, max=1.0)
|
||||
# Mask weights with computed mask
|
||||
weight_thresholded = mask * self.weight
|
||||
# Compute output (linear layer) with masked weights
|
||||
return F.linear(input, weight_thresholded, self.bias)
|
||||
@@ -0,0 +1,5 @@
|
||||
# LXMERT DEMO
|
||||
|
||||
1. make a virtualenv: ``virtualenv venv`` and activate ``source venv/bin/activate``
|
||||
2. install reqs: ``pip install -r ./requirements.txt``
|
||||
3. usage is as shown in demo.ipynb
|
||||
267
examples/research_projects/movement-pruning/lxmert/demo.ipynb
Normal file
267
examples/research_projects/movement-pruning/lxmert/demo.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,149 @@
|
||||
import getopt
|
||||
import json
|
||||
import os
|
||||
|
||||
# import numpy as np
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modeling_frcnn import GeneralizedRCNN
|
||||
from processing_image import Preprocess
|
||||
from utils import Config
|
||||
|
||||
|
||||
"""
|
||||
USAGE:
|
||||
``python extracting_data.py -i <img_dir> -o <dataset_file>.datasets <batch_size>``
|
||||
"""
|
||||
|
||||
|
||||
TEST = False
|
||||
CONFIG = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
|
||||
DEFAULT_SCHEMA = datasets.Features(
|
||||
OrderedDict(
|
||||
{
|
||||
"attr_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
||||
"attr_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
||||
"boxes": datasets.Array2D((CONFIG.MAX_DETECTIONS, 4), dtype="float32"),
|
||||
"img_id": datasets.Value("int32"),
|
||||
"obj_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
||||
"obj_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
||||
"roi_features": datasets.Array2D((CONFIG.MAX_DETECTIONS, 2048), dtype="float32"),
|
||||
"sizes": datasets.Sequence(length=2, feature=datasets.Value("float32")),
|
||||
"preds_per_image": datasets.Value(dtype="int32"),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Extract:
|
||||
def __init__(self, argv=sys.argv[1:]):
|
||||
inputdir = None
|
||||
outputfile = None
|
||||
subset_list = None
|
||||
batch_size = 1
|
||||
opts, args = getopt.getopt(argv, "i:o:b:s", ["inputdir=", "outfile=", "batch_size=", "subset_list="])
|
||||
for opt, arg in opts:
|
||||
if opt in ("-i", "--inputdir"):
|
||||
inputdir = arg
|
||||
elif opt in ("-o", "--outfile"):
|
||||
outputfile = arg
|
||||
elif opt in ("-b", "--batch_size"):
|
||||
batch_size = int(arg)
|
||||
elif opt in ("-s", "--subset_list"):
|
||||
subset_list = arg
|
||||
|
||||
assert inputdir is not None # and os.path.isdir(inputdir), f"{inputdir}"
|
||||
assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
|
||||
if subset_list is not None:
|
||||
with open(os.path.realpath(subset_list)) as f:
|
||||
self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
|
||||
else:
|
||||
self.subset_list = None
|
||||
|
||||
self.config = CONFIG
|
||||
if torch.cuda.is_available():
|
||||
self.config.model.device = "cuda"
|
||||
self.inputdir = os.path.realpath(inputdir)
|
||||
self.outputfile = os.path.realpath(outputfile)
|
||||
self.preprocess = Preprocess(self.config)
|
||||
self.model = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.config)
|
||||
self.batch = batch_size if batch_size != 0 else 1
|
||||
self.schema = DEFAULT_SCHEMA
|
||||
|
||||
def _vqa_file_split(self, file):
|
||||
img_id = int(file.split(".")[0].split("_")[-1])
|
||||
filepath = os.path.join(self.inputdir, file)
|
||||
return (img_id, filepath)
|
||||
|
||||
@property
|
||||
def file_generator(self):
|
||||
batch = []
|
||||
for i, file in enumerate(os.listdir(self.inputdir)):
|
||||
if self.subset_list is not None and i not in self.subset_list:
|
||||
continue
|
||||
batch.append(self._vqa_file_split(file))
|
||||
if len(batch) == self.batch:
|
||||
temp = batch
|
||||
batch = []
|
||||
yield list(map(list, zip(*temp)))
|
||||
|
||||
for i in range(1):
|
||||
yield list(map(list, zip(*batch)))
|
||||
|
||||
def __call__(self):
|
||||
# make writer
|
||||
if not TEST:
|
||||
writer = datasets.ArrowWriter(features=self.schema, path=self.outputfile)
|
||||
# do file generator
|
||||
for i, (img_ids, filepaths) in enumerate(self.file_generator):
|
||||
images, sizes, scales_yx = self.preprocess(filepaths)
|
||||
output_dict = self.model(
|
||||
images,
|
||||
sizes,
|
||||
scales_yx=scales_yx,
|
||||
padding="max_detections",
|
||||
max_detections=self.config.MAX_DETECTIONS,
|
||||
pad_value=0,
|
||||
return_tensors="np",
|
||||
location="cpu",
|
||||
)
|
||||
output_dict["boxes"] = output_dict.pop("normalized_boxes")
|
||||
if not TEST:
|
||||
output_dict["img_id"] = np.array(img_ids)
|
||||
batch = self.schema.encode_batch(output_dict)
|
||||
writer.write_batch(batch)
|
||||
if TEST:
|
||||
break
|
||||
# finalizer the writer
|
||||
if not TEST:
|
||||
num_examples, num_bytes = writer.finalize()
|
||||
print(f"Success! You wrote {num_examples} entry(s) and {num_bytes >> 20} mb")
|
||||
|
||||
|
||||
def tryload(stream):
|
||||
try:
|
||||
data = json.load(stream)
|
||||
try:
|
||||
data = list(data.keys())
|
||||
except Exception:
|
||||
data = [d["img_id"] for d in data]
|
||||
except Exception:
|
||||
try:
|
||||
data = eval(stream.read())
|
||||
except Exception:
|
||||
data = stream.read().split("\n")
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
extract = Extract(sys.argv[1:])
|
||||
extract()
|
||||
if not TEST:
|
||||
dataset = datasets.Dataset.from_file(extract.outputfile)
|
||||
# wala!
|
||||
# print(np.array(dataset[0:2]["roi_features"]).shape)
|
||||
1922
examples/research_projects/movement-pruning/lxmert/modeling_frcnn.py
Normal file
1922
examples/research_projects/movement-pruning/lxmert/modeling_frcnn.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
coding=utf-8
|
||||
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
|
||||
Adapted From Facebook Inc, Detectron2
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.import copy
|
||||
"""
|
||||
import sys
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from utils import img_tensorize
|
||||
|
||||
|
||||
class ResizeShortestEdge:
|
||||
def __init__(self, short_edge_length, max_size=sys.maxsize):
|
||||
"""
|
||||
Args:
|
||||
short_edge_length (list[min, max])
|
||||
max_size (int): maximum allowed longest edge length.
|
||||
"""
|
||||
self.interp_method = "bilinear"
|
||||
self.max_size = max_size
|
||||
self.short_edge_length = short_edge_length
|
||||
|
||||
def __call__(self, imgs):
|
||||
img_augs = []
|
||||
for img in imgs:
|
||||
h, w = img.shape[:2]
|
||||
# later: provide list and randomly choose index for resize
|
||||
size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
|
||||
if size == 0:
|
||||
return img
|
||||
scale = size * 1.0 / min(h, w)
|
||||
if h < w:
|
||||
newh, neww = size, scale * w
|
||||
else:
|
||||
newh, neww = scale * h, size
|
||||
if max(newh, neww) > self.max_size:
|
||||
scale = self.max_size * 1.0 / max(newh, neww)
|
||||
newh = newh * scale
|
||||
neww = neww * scale
|
||||
neww = int(neww + 0.5)
|
||||
newh = int(newh + 0.5)
|
||||
|
||||
if img.dtype == np.uint8:
|
||||
pil_image = Image.fromarray(img)
|
||||
pil_image = pil_image.resize((neww, newh), Image.BILINEAR)
|
||||
img = np.asarray(pil_image)
|
||||
else:
|
||||
img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw
|
||||
img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0)
|
||||
img_augs.append(img)
|
||||
|
||||
return img_augs
|
||||
|
||||
|
||||
class Preprocess:
|
||||
def __init__(self, cfg):
|
||||
self.aug = ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)
|
||||
self.input_format = cfg.INPUT.FORMAT
|
||||
self.size_divisibility = cfg.SIZE_DIVISIBILITY
|
||||
self.pad_value = cfg.PAD_VALUE
|
||||
self.max_image_size = cfg.INPUT.MAX_SIZE_TEST
|
||||
self.device = cfg.MODEL.DEVICE
|
||||
self.pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
|
||||
self.pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
|
||||
self.normalizer = lambda x: (x - self.pixel_mean) / self.pixel_std
|
||||
|
||||
def pad(self, images):
|
||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||
image_sizes = [im.shape[-2:] for im in images]
|
||||
images = [
|
||||
F.pad(
|
||||
im,
|
||||
[0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
|
||||
value=self.pad_value,
|
||||
)
|
||||
for size, im in zip(image_sizes, images)
|
||||
]
|
||||
|
||||
return torch.stack(images), torch.tensor(image_sizes)
|
||||
|
||||
def __call__(self, images, single_image=False):
|
||||
with torch.no_grad():
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
if single_image:
|
||||
assert len(images) == 1
|
||||
for i in range(len(images)):
|
||||
if isinstance(images[i], torch.Tensor):
|
||||
images.insert(i, images.pop(i).to(self.device).float())
|
||||
elif not isinstance(images[i], torch.Tensor):
|
||||
images.insert(
|
||||
i,
|
||||
torch.as_tensor(img_tensorize(images.pop(i), input_format=self.input_format))
|
||||
.to(self.device)
|
||||
.float(),
|
||||
)
|
||||
# resize smallest edge
|
||||
raw_sizes = torch.tensor([im.shape[:2] for im in images])
|
||||
images = self.aug(images)
|
||||
# transpose images and convert to torch tensors
|
||||
# images = [torch.as_tensor(i.astype("float32")).permute(2, 0, 1).to(self.device) for i in images]
|
||||
# now normalize before pad to avoid useless arithmetic
|
||||
images = [self.normalizer(x) for x in images]
|
||||
# now pad them to do the following operations
|
||||
images, sizes = self.pad(images)
|
||||
# Normalize
|
||||
|
||||
if self.size_divisibility > 0:
|
||||
raise NotImplementedError()
|
||||
# pad
|
||||
scales_yx = torch.true_divide(raw_sizes, sizes)
|
||||
if single_image:
|
||||
return images[0], sizes[0], scales_yx[0]
|
||||
else:
|
||||
return images, sizes, scales_yx
|
||||
|
||||
|
||||
def _scale_box(boxes, scale_yx):
|
||||
boxes[:, 0::2] *= scale_yx[:, 1]
|
||||
boxes[:, 1::2] *= scale_yx[:, 0]
|
||||
return boxes
|
||||
|
||||
|
||||
def _clip_box(tensor, box_size: Tuple[int, int]):
|
||||
assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!"
|
||||
h, w = box_size
|
||||
tensor[:, 0].clamp_(min=0, max=w)
|
||||
tensor[:, 1].clamp_(min=0, max=h)
|
||||
tensor[:, 2].clamp_(min=0, max=w)
|
||||
tensor[:, 3].clamp_(min=0, max=h)
|
||||
@@ -0,0 +1,99 @@
|
||||
appdirs==1.4.3
|
||||
argon2-cffi==20.1.0
|
||||
async-generator==1.10
|
||||
attrs==20.2.0
|
||||
backcall==0.2.0
|
||||
bleach==3.1.5
|
||||
CacheControl==0.12.6
|
||||
certifi==2020.6.20
|
||||
cffi==1.14.2
|
||||
chardet==3.0.4
|
||||
click==7.1.2
|
||||
colorama==0.4.3
|
||||
contextlib2==0.6.0
|
||||
cycler==0.10.0
|
||||
datasets==1.0.0
|
||||
decorator==4.4.2
|
||||
defusedxml==0.6.0
|
||||
dill==0.3.2
|
||||
distlib==0.3.0
|
||||
distro==1.4.0
|
||||
entrypoints==0.3
|
||||
filelock==3.0.12
|
||||
future==0.18.2
|
||||
html5lib==1.0.1
|
||||
idna==2.8
|
||||
ipaddr==2.2.0
|
||||
ipykernel==5.3.4
|
||||
ipython
|
||||
ipython-genutils==0.2.0
|
||||
ipywidgets==7.5.1
|
||||
jedi==0.17.2
|
||||
Jinja2==2.11.2
|
||||
joblib==0.16.0
|
||||
jsonschema==3.2.0
|
||||
jupyter==1.0.0
|
||||
jupyter-client==6.1.7
|
||||
jupyter-console==6.2.0
|
||||
jupyter-core==4.6.3
|
||||
jupyterlab-pygments==0.1.1
|
||||
kiwisolver==1.2.0
|
||||
lockfile==0.12.2
|
||||
MarkupSafe==1.1.1
|
||||
matplotlib==3.3.1
|
||||
mistune==0.8.4
|
||||
msgpack==0.6.2
|
||||
nbclient==0.5.0
|
||||
nbconvert==6.0.1
|
||||
nbformat==5.0.7
|
||||
nest-asyncio==1.4.0
|
||||
notebook==6.1.4
|
||||
numpy==1.19.2
|
||||
opencv-python==4.4.0.42
|
||||
packaging==20.3
|
||||
pandas==1.1.2
|
||||
pandocfilters==1.4.2
|
||||
parso==0.7.1
|
||||
pep517==0.8.2
|
||||
pexpect==4.8.0
|
||||
pickleshare==0.7.5
|
||||
Pillow==7.2.0
|
||||
progress==1.5
|
||||
prometheus-client==0.8.0
|
||||
prompt-toolkit==3.0.7
|
||||
ptyprocess==0.6.0
|
||||
pyaml==20.4.0
|
||||
pyarrow==1.0.1
|
||||
pycparser==2.20
|
||||
Pygments==2.6.1
|
||||
pyparsing==2.4.6
|
||||
pyrsistent==0.16.0
|
||||
python-dateutil==2.8.1
|
||||
pytoml==0.1.21
|
||||
pytz==2020.1
|
||||
PyYAML==5.3.1
|
||||
pyzmq==19.0.2
|
||||
qtconsole==4.7.7
|
||||
QtPy==1.9.0
|
||||
regex==2020.7.14
|
||||
requests==2.22.0
|
||||
retrying==1.3.3
|
||||
sacremoses==0.0.43
|
||||
Send2Trash==1.5.0
|
||||
sentencepiece==0.1.91
|
||||
six==1.14.0
|
||||
terminado==0.8.3
|
||||
testpath==0.4.4
|
||||
tokenizers==0.8.1rc2
|
||||
torch==1.6.0
|
||||
torchvision==0.7.0
|
||||
tornado==6.0.4
|
||||
tqdm==4.48.2
|
||||
traitlets
|
||||
transformers==3.5.1
|
||||
urllib3==1.25.8
|
||||
wcwidth==0.2.5
|
||||
webencodings==0.5.1
|
||||
wget==3.2
|
||||
widgetsnbextension==3.5.1
|
||||
xxhash==2.0.0
|
||||
559
examples/research_projects/movement-pruning/lxmert/utils.py
Normal file
559
examples/research_projects/movement-pruning/lxmert/utils.py
Normal file
@@ -0,0 +1,559 @@
|
||||
"""
|
||||
coding=utf-8
|
||||
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal, Huggingface team :)
|
||||
Adapted From Facebook Inc, Detectron2
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.import copy
|
||||
"""
|
||||
|
||||
import copy
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
import pickle as pkl
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from hashlib import sha256
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import cv2
|
||||
import requests
|
||||
import wget
|
||||
from filelock import FileLock
|
||||
from yaml import Loader, dump, load
|
||||
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
_torch_available = True
|
||||
except ImportError:
|
||||
_torch_available = False
|
||||
|
||||
|
||||
try:
|
||||
from torch.hub import _get_torch_home
|
||||
|
||||
torch_cache_home = _get_torch_home()
|
||||
except ImportError:
|
||||
torch_cache_home = os.path.expanduser(
|
||||
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
||||
)
|
||||
|
||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
|
||||
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
|
||||
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
|
||||
PATH = "/".join(str(Path(__file__).resolve()).split("/")[:-1])
|
||||
CONFIG = os.path.join(PATH, "config.yaml")
|
||||
ATTRIBUTES = os.path.join(PATH, "attributes.txt")
|
||||
OBJECTS = os.path.join(PATH, "objects.txt")
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
||||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
CONFIG_NAME = "config.yaml"
|
||||
|
||||
|
||||
def load_labels(objs=OBJECTS, attrs=ATTRIBUTES):
|
||||
vg_classes = []
|
||||
with open(objs) as f:
|
||||
for object in f.readlines():
|
||||
vg_classes.append(object.split(",")[0].lower().strip())
|
||||
|
||||
vg_attrs = []
|
||||
with open(attrs) as f:
|
||||
for object in f.readlines():
|
||||
vg_attrs.append(object.split(",")[0].lower().strip())
|
||||
return vg_classes, vg_attrs
|
||||
|
||||
|
||||
def load_checkpoint(ckp):
|
||||
r = OrderedDict()
|
||||
with open(ckp, "rb") as f:
|
||||
ckp = pkl.load(f)["model"]
|
||||
for k in copy.deepcopy(list(ckp.keys())):
|
||||
v = ckp.pop(k)
|
||||
if isinstance(v, np.ndarray):
|
||||
v = torch.tensor(v)
|
||||
else:
|
||||
assert isinstance(v, torch.tensor), type(v)
|
||||
r[k] = v
|
||||
return r
|
||||
|
||||
|
||||
class Config:
|
||||
_pointer = {}
|
||||
|
||||
def __init__(self, dictionary: dict, name: str = "root", level=0):
|
||||
self._name = name
|
||||
self._level = level
|
||||
d = {}
|
||||
for k, v in dictionary.items():
|
||||
if v is None:
|
||||
raise ValueError()
|
||||
k = copy.deepcopy(k)
|
||||
v = copy.deepcopy(v)
|
||||
if isinstance(v, dict):
|
||||
v = Config(v, name=k, level=level + 1)
|
||||
d[k] = v
|
||||
setattr(self, k, v)
|
||||
|
||||
self._pointer = d
|
||||
|
||||
def __repr__(self):
|
||||
return str(list((self._pointer.keys())))
|
||||
|
||||
def __setattr__(self, key, val):
|
||||
self.__dict__[key] = val
|
||||
self.__dict__[key.upper()] = val
|
||||
levels = key.split(".")
|
||||
last_level = len(levels) - 1
|
||||
pointer = self._pointer
|
||||
if len(levels) > 1:
|
||||
for i, l in enumerate(levels):
|
||||
if hasattr(self, l) and isinstance(getattr(self, l), Config):
|
||||
setattr(getattr(self, l), ".".join(levels[i:]), val)
|
||||
if l == last_level:
|
||||
pointer[l] = val
|
||||
else:
|
||||
pointer = pointer[l]
|
||||
|
||||
def to_dict(self):
|
||||
return self._pointer
|
||||
|
||||
def dump_yaml(self, data, file_name):
|
||||
with open(f"{file_name}", "w") as stream:
|
||||
dump(data, stream)
|
||||
|
||||
def dump_json(self, data, file_name):
|
||||
with open(f"{file_name}", "w") as stream:
|
||||
json.dump(data, stream)
|
||||
|
||||
@staticmethod
|
||||
def load_yaml(config):
|
||||
with open(config) as stream:
|
||||
data = load(stream, Loader=Loader)
|
||||
return data
|
||||
|
||||
def __str__(self):
|
||||
t = " "
|
||||
if self._name != "root":
|
||||
r = f"{t * (self._level-1)}{self._name}:\n"
|
||||
else:
|
||||
r = ""
|
||||
level = self._level
|
||||
for i, (k, v) in enumerate(self._pointer.items()):
|
||||
if isinstance(v, Config):
|
||||
r += f"{t * (self._level)}{v}\n"
|
||||
self._level += 1
|
||||
else:
|
||||
r += f"{t * (self._level)}{k}: {v} ({type(v).__name__})\n"
|
||||
self._level = level
|
||||
return r[:-1]
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
return cls(config_dict)
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_config_file = cached_path(
|
||||
config_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
# Load config dict
|
||||
if resolved_config_file is None:
|
||||
raise EnvironmentError
|
||||
|
||||
config_file = Config.load_yaml(resolved_config_file)
|
||||
|
||||
except EnvironmentError:
|
||||
msg = "Can't load config for"
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_config_file == config_file:
|
||||
print("loading configuration file from path")
|
||||
else:
|
||||
print("loading configuration file cache")
|
||||
|
||||
return Config.load_yaml(resolved_config_file), kwargs
|
||||
|
||||
|
||||
# quick compare tensors
|
||||
def compare(in_tensor):
|
||||
|
||||
out_tensor = torch.load("dump.pt", map_location=in_tensor.device)
|
||||
n1 = in_tensor.numpy()
|
||||
n2 = out_tensor.numpy()[0]
|
||||
print(n1.shape, n1[0, 0, :5])
|
||||
print(n2.shape, n2[0, 0, :5])
|
||||
assert np.allclose(
|
||||
n1, n2, rtol=0.01, atol=0.1
|
||||
), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
|
||||
raise Exception("tensors are all good")
|
||||
|
||||
# Hugging face functions below
|
||||
|
||||
|
||||
def is_remote_url(url_or_filename):
|
||||
parsed = urlparse(url_or_filename)
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
|
||||
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
|
||||
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
|
||||
legacy_format = "/" not in model_id
|
||||
if legacy_format:
|
||||
return f"{endpoint}/{model_id}-{filename}"
|
||||
else:
|
||||
return f"{endpoint}/{model_id}/{filename}"
|
||||
|
||||
|
||||
def http_get(
|
||||
url,
|
||||
temp_file,
|
||||
proxies=None,
|
||||
resume_size=0,
|
||||
user_agent=None,
|
||||
):
|
||||
ua = "python/{}".format(sys.version.split()[0])
|
||||
if _torch_available:
|
||||
ua += "; torch/{}".format(torch.__version__)
|
||||
if isinstance(user_agent, dict):
|
||||
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += "; " + user_agent
|
||||
headers = {"user-agent": ua}
|
||||
if resume_size > 0:
|
||||
headers["Range"] = "bytes=%d-" % (resume_size,)
|
||||
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
if response.status_code == 416: # Range not satisfiable
|
||||
return
|
||||
content_length = response.headers.get("Content-Length")
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
progress = tqdm(
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
total=total,
|
||||
initial=resume_size,
|
||||
desc="Downloading",
|
||||
)
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
progress.close()
|
||||
|
||||
|
||||
def get_from_cache(
|
||||
url,
|
||||
cache_dir=None,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
etag_timeout=10,
|
||||
resume_download=False,
|
||||
user_agent=None,
|
||||
local_files_only=False,
|
||||
):
|
||||
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
etag = None
|
||||
if not local_files_only:
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
||||
if response.status_code == 200:
|
||||
etag = response.headers.get("ETag")
|
||||
except (EnvironmentError, requests.exceptions.Timeout):
|
||||
# etag is already None
|
||||
pass
|
||||
|
||||
filename = url_to_filename(url, etag)
|
||||
|
||||
# get cache path to put the file
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
|
||||
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
|
||||
# try to get the last downloaded one
|
||||
if etag is None:
|
||||
if os.path.exists(cache_path):
|
||||
return cache_path
|
||||
else:
|
||||
matching_files = [
|
||||
file
|
||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
||||
if not file.endswith(".json") and not file.endswith(".lock")
|
||||
]
|
||||
if len(matching_files) > 0:
|
||||
return os.path.join(cache_dir, matching_files[-1])
|
||||
else:
|
||||
# If files cannot be found and local_files_only=True,
|
||||
# the models might've been found if local_files_only=False
|
||||
# Notify the user about that
|
||||
if local_files_only:
|
||||
raise ValueError(
|
||||
"Cannot find the requested files in the cached path and outgoing traffic has been"
|
||||
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
|
||||
" to False."
|
||||
)
|
||||
return None
|
||||
|
||||
# From now on, etag is not None.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
return cache_path
|
||||
|
||||
# Prevent parallel downloads of the same file with a lock.
|
||||
lock_path = cache_path + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
# If the download just completed while the lock was activated.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
# Even if returning early like here, the lock will be released.
|
||||
return cache_path
|
||||
|
||||
if resume_download:
|
||||
incomplete_path = cache_path + ".incomplete"
|
||||
|
||||
@contextmanager
|
||||
def _resumable_file_manager():
|
||||
with open(incomplete_path, "a+b") as f:
|
||||
yield f
|
||||
|
||||
temp_file_manager = _resumable_file_manager
|
||||
if os.path.exists(incomplete_path):
|
||||
resume_size = os.stat(incomplete_path).st_size
|
||||
else:
|
||||
resume_size = 0
|
||||
else:
|
||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
||||
resume_size = 0
|
||||
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
print(
|
||||
"%s not found in cache or force_download set to True, downloading to %s",
|
||||
url,
|
||||
temp_file.name,
|
||||
)
|
||||
|
||||
http_get(
|
||||
url,
|
||||
temp_file,
|
||||
proxies=proxies,
|
||||
resume_size=resume_size,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
os.replace(temp_file.name, cache_path)
|
||||
|
||||
meta = {"url": url, "etag": etag}
|
||||
meta_path = cache_path + ".json"
|
||||
with open(meta_path, "w") as meta_file:
|
||||
json.dump(meta, meta_file)
|
||||
|
||||
return cache_path
|
||||
|
||||
|
||||
def url_to_filename(url, etag=None):
|
||||
|
||||
url_bytes = url.encode("utf-8")
|
||||
url_hash = sha256(url_bytes)
|
||||
filename = url_hash.hexdigest()
|
||||
|
||||
if etag:
|
||||
etag_bytes = etag.encode("utf-8")
|
||||
etag_hash = sha256(etag_bytes)
|
||||
filename += "." + etag_hash.hexdigest()
|
||||
|
||||
if url.endswith(".h5"):
|
||||
filename += ".h5"
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def cached_path(
|
||||
url_or_filename,
|
||||
cache_dir=None,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
resume_download=False,
|
||||
user_agent=None,
|
||||
extract_compressed_file=False,
|
||||
force_extract=False,
|
||||
local_files_only=False,
|
||||
):
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if isinstance(url_or_filename, Path):
|
||||
url_or_filename = str(url_or_filename)
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
if is_remote_url(url_or_filename):
|
||||
# URL, so get it from the cache (downloading if necessary)
|
||||
output_path = get_from_cache(
|
||||
url_or_filename,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
user_agent=user_agent,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
elif os.path.exists(url_or_filename):
|
||||
# File, and it exists.
|
||||
output_path = url_or_filename
|
||||
elif urlparse(url_or_filename).scheme == "":
|
||||
# File, but it doesn't exist.
|
||||
raise EnvironmentError("file {} not found".format(url_or_filename))
|
||||
else:
|
||||
# Something unknown
|
||||
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
|
||||
|
||||
if extract_compressed_file:
|
||||
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
|
||||
return output_path
|
||||
|
||||
# Path where we extract compressed archives
|
||||
# We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
|
||||
output_dir, output_file = os.path.split(output_path)
|
||||
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
|
||||
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
|
||||
|
||||
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
|
||||
return output_path_extracted
|
||||
|
||||
# Prevent parallel extractions
|
||||
lock_path = output_path + ".lock"
|
||||
with FileLock(lock_path):
|
||||
shutil.rmtree(output_path_extracted, ignore_errors=True)
|
||||
os.makedirs(output_path_extracted)
|
||||
if is_zipfile(output_path):
|
||||
with ZipFile(output_path, "r") as zip_file:
|
||||
zip_file.extractall(output_path_extracted)
|
||||
zip_file.close()
|
||||
elif tarfile.is_tarfile(output_path):
|
||||
tar_file = tarfile.open(output_path)
|
||||
tar_file.extractall(output_path_extracted)
|
||||
tar_file.close()
|
||||
else:
|
||||
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
|
||||
|
||||
return output_path_extracted
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def get_data(query, delim=","):
|
||||
assert isinstance(query, str)
|
||||
if os.path.isfile(query):
|
||||
with open(query) as f:
|
||||
data = eval(f.read())
|
||||
else:
|
||||
req = requests.get(query)
|
||||
try:
|
||||
data = requests.json()
|
||||
except Exception:
|
||||
data = req.content.decode()
|
||||
assert data is not None, "could not connect"
|
||||
try:
|
||||
data = eval(data)
|
||||
except Exception:
|
||||
data = data.split("\n")
|
||||
req.close()
|
||||
return data
|
||||
|
||||
|
||||
def get_image_from_url(url):
|
||||
response = requests.get(url)
|
||||
img = np.array(Image.open(BytesIO(response.content)))
|
||||
return img
|
||||
|
||||
|
||||
# to load legacy frcnn checkpoint from detectron
|
||||
def load_frcnn_pkl_from_url(url):
|
||||
fn = url.split("/")[-1]
|
||||
if fn not in os.listdir(os.getcwd()):
|
||||
wget.download(url)
|
||||
with open(fn, "rb") as stream:
|
||||
weights = pkl.load(stream)
|
||||
model = weights.pop("model")
|
||||
new = {}
|
||||
for k, v in model.items():
|
||||
new[k] = torch.from_numpy(v)
|
||||
if "running_var" in k:
|
||||
zero = torch.Tensor([0])
|
||||
k2 = k.replace("running_var", "num_batches_tracked")
|
||||
new[k2] = zero
|
||||
return new
|
||||
|
||||
|
||||
def get_demo_path():
|
||||
print(f"{os.path.abspath(os.path.join(PATH, os.pardir))}/demo.ipynb")
|
||||
|
||||
|
||||
def img_tensorize(im, input_format="RGB"):
|
||||
assert isinstance(im, str)
|
||||
if os.path.isfile(im):
|
||||
img = cv2.imread(im)
|
||||
else:
|
||||
img = get_image_from_url(im)
|
||||
assert img is not None, f"could not connect to: {im}"
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
if input_format == "RGB":
|
||||
img = img[:, :, ::-1]
|
||||
return img
|
||||
|
||||
|
||||
def chunk(images, batch=1):
|
||||
return (images[i : i + batch] for i in range(0, len(images), batch))
|
||||
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
coding=utf-8
|
||||
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
|
||||
Adapted From Facebook Inc, Detectron2
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.import copy
|
||||
"""
|
||||
import colorsys
|
||||
import io
|
||||
|
||||
import matplotlib as mpl
|
||||
import matplotlib.colors as mplc
|
||||
import matplotlib.figure as mplfigure
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
import cv2
|
||||
from utils import img_tensorize
|
||||
|
||||
|
||||
_SMALL_OBJ = 1000
|
||||
|
||||
|
||||
class SingleImageViz:
|
||||
def __init__(
|
||||
self,
|
||||
img,
|
||||
scale=1.2,
|
||||
edgecolor="g",
|
||||
alpha=0.5,
|
||||
linestyle="-",
|
||||
saveas="test_out.jpg",
|
||||
rgb=True,
|
||||
pynb=False,
|
||||
id2obj=None,
|
||||
id2attr=None,
|
||||
pad=0.7,
|
||||
):
|
||||
"""
|
||||
img: an RGB image of shape (H, W, 3).
|
||||
"""
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.numpy().astype("np.uint8")
|
||||
if isinstance(img, str):
|
||||
img = img_tensorize(img)
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
width, height = img.shape[1], img.shape[0]
|
||||
fig = mplfigure.Figure(frameon=False)
|
||||
dpi = fig.get_dpi()
|
||||
width_in = (width * scale + 1e-2) / dpi
|
||||
height_in = (height * scale + 1e-2) / dpi
|
||||
fig.set_size_inches(width_in, height_in)
|
||||
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
|
||||
ax.axis("off")
|
||||
ax.set_xlim(0.0, width)
|
||||
ax.set_ylim(height)
|
||||
|
||||
self.saveas = saveas
|
||||
self.rgb = rgb
|
||||
self.pynb = pynb
|
||||
self.img = img
|
||||
self.edgecolor = edgecolor
|
||||
self.alpha = 0.5
|
||||
self.linestyle = linestyle
|
||||
self.font_size = int(np.sqrt(min(height, width)) * scale // 3)
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.scale = scale
|
||||
self.fig = fig
|
||||
self.ax = ax
|
||||
self.pad = pad
|
||||
self.id2obj = id2obj
|
||||
self.id2attr = id2attr
|
||||
self.canvas = FigureCanvasAgg(fig)
|
||||
|
||||
def add_box(self, box, color=None):
|
||||
if color is None:
|
||||
color = self.edgecolor
|
||||
(x0, y0, x1, y1) = box
|
||||
width = x1 - x0
|
||||
height = y1 - y0
|
||||
self.ax.add_patch(
|
||||
mpl.patches.Rectangle(
|
||||
(x0, y0),
|
||||
width,
|
||||
height,
|
||||
fill=False,
|
||||
edgecolor=color,
|
||||
linewidth=self.font_size // 3,
|
||||
alpha=self.alpha,
|
||||
linestyle=self.linestyle,
|
||||
)
|
||||
)
|
||||
|
||||
def draw_boxes(self, boxes, obj_ids=None, obj_scores=None, attr_ids=None, attr_scores=None):
|
||||
if len(boxes.shape) > 2:
|
||||
boxes = boxes[0]
|
||||
if len(obj_ids.shape) > 1:
|
||||
obj_ids = obj_ids[0]
|
||||
if len(obj_scores.shape) > 1:
|
||||
obj_scores = obj_scores[0]
|
||||
if len(attr_ids.shape) > 1:
|
||||
attr_ids = attr_ids[0]
|
||||
if len(attr_scores.shape) > 1:
|
||||
attr_scores = attr_scores[0]
|
||||
if isinstance(boxes, torch.Tensor):
|
||||
boxes = boxes.numpy()
|
||||
if isinstance(boxes, list):
|
||||
boxes = np.array(boxes)
|
||||
assert isinstance(boxes, np.ndarray)
|
||||
areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
|
||||
sorted_idxs = np.argsort(-areas).tolist()
|
||||
boxes = boxes[sorted_idxs] if boxes is not None else None
|
||||
obj_ids = obj_ids[sorted_idxs] if obj_ids is not None else None
|
||||
obj_scores = obj_scores[sorted_idxs] if obj_scores is not None else None
|
||||
attr_ids = attr_ids[sorted_idxs] if attr_ids is not None else None
|
||||
attr_scores = attr_scores[sorted_idxs] if attr_scores is not None else None
|
||||
|
||||
assigned_colors = [self._random_color(maximum=1) for _ in range(len(boxes))]
|
||||
assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
|
||||
if obj_ids is not None:
|
||||
labels = self._create_text_labels_attr(obj_ids, obj_scores, attr_ids, attr_scores)
|
||||
for i in range(len(boxes)):
|
||||
color = assigned_colors[i]
|
||||
self.add_box(boxes[i], color)
|
||||
self.draw_labels(labels[i], boxes[i], color)
|
||||
|
||||
def draw_labels(self, label, box, color):
|
||||
x0, y0, x1, y1 = box
|
||||
text_pos = (x0, y0)
|
||||
instance_area = (y1 - y0) * (x1 - x0)
|
||||
small = _SMALL_OBJ * self.scale
|
||||
if instance_area < small or y1 - y0 < 40 * self.scale:
|
||||
if y1 >= self.height - 5:
|
||||
text_pos = (x1, y0)
|
||||
else:
|
||||
text_pos = (x0, y1)
|
||||
|
||||
height_ratio = (y1 - y0) / np.sqrt(self.height * self.width)
|
||||
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
||||
font_size = np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
|
||||
font_size *= 0.75 * self.font_size
|
||||
|
||||
self.draw_text(
|
||||
text=label,
|
||||
position=text_pos,
|
||||
color=lighter_color,
|
||||
)
|
||||
|
||||
def draw_text(
|
||||
self,
|
||||
text,
|
||||
position,
|
||||
color="g",
|
||||
ha="left",
|
||||
):
|
||||
rotation = 0
|
||||
font_size = self.font_size
|
||||
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
|
||||
color[np.argmax(color)] = max(0.8, np.max(color))
|
||||
bbox = {
|
||||
"facecolor": "black",
|
||||
"alpha": self.alpha,
|
||||
"pad": self.pad,
|
||||
"edgecolor": "none",
|
||||
}
|
||||
x, y = position
|
||||
self.ax.text(
|
||||
x,
|
||||
y,
|
||||
text,
|
||||
size=font_size * self.scale,
|
||||
family="sans-serif",
|
||||
bbox=bbox,
|
||||
verticalalignment="top",
|
||||
horizontalalignment=ha,
|
||||
color=color,
|
||||
zorder=10,
|
||||
rotation=rotation,
|
||||
)
|
||||
|
||||
def save(self, saveas=None):
|
||||
if saveas is None:
|
||||
saveas = self.saveas
|
||||
if saveas.lower().endswith(".jpg") or saveas.lower().endswith(".png"):
|
||||
cv2.imwrite(
|
||||
saveas,
|
||||
self._get_buffer()[:, :, ::-1],
|
||||
)
|
||||
else:
|
||||
self.fig.savefig(saveas)
|
||||
|
||||
def _create_text_labels_attr(self, classes, scores, attr_classes, attr_scores):
|
||||
labels = [self.id2obj[i] for i in classes]
|
||||
attr_labels = [self.id2attr[i] for i in attr_classes]
|
||||
labels = [
|
||||
f"{label} {score:.2f} {attr} {attr_score:.2f}"
|
||||
for label, score, attr, attr_score in zip(labels, scores, attr_labels, attr_scores)
|
||||
]
|
||||
return labels
|
||||
|
||||
def _create_text_labels(self, classes, scores):
|
||||
labels = [self.id2obj[i] for i in classes]
|
||||
if scores is not None:
|
||||
if labels is None:
|
||||
labels = ["{:.0f}%".format(s * 100) for s in scores]
|
||||
else:
|
||||
labels = ["{} {:.0f}%".format(li, s * 100) for li, s in zip(labels, scores)]
|
||||
return labels
|
||||
|
||||
def _random_color(self, maximum=255):
|
||||
idx = np.random.randint(0, len(_COLORS))
|
||||
ret = _COLORS[idx] * maximum
|
||||
if not self.rgb:
|
||||
ret = ret[::-1]
|
||||
return ret
|
||||
|
||||
def _get_buffer(self):
|
||||
if not self.pynb:
|
||||
s, (width, height) = self.canvas.print_to_buffer()
|
||||
if (width, height) != (self.width, self.height):
|
||||
img = cv2.resize(self.img, (width, height))
|
||||
else:
|
||||
img = self.img
|
||||
else:
|
||||
buf = io.BytesIO() # works for cairo backend
|
||||
self.canvas.print_rgba(buf)
|
||||
width, height = self.width, self.height
|
||||
s = buf.getvalue()
|
||||
img = self.img
|
||||
|
||||
buffer = np.frombuffer(s, dtype="uint8")
|
||||
img_rgba = buffer.reshape(height, width, 4)
|
||||
rgb, alpha = np.split(img_rgba, [3], axis=2)
|
||||
|
||||
try:
|
||||
import numexpr as ne # fuse them with numexpr
|
||||
|
||||
visualized_image = ne.evaluate("img * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
|
||||
except ImportError:
|
||||
alpha = alpha.astype("float32") / 255.0
|
||||
visualized_image = img * (1 - alpha) + rgb * alpha
|
||||
|
||||
return visualized_image.astype("uint8")
|
||||
|
||||
def _change_color_brightness(self, color, brightness_factor):
|
||||
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
||||
color = mplc.to_rgb(color)
|
||||
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
||||
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
|
||||
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
||||
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
||||
modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
|
||||
return modified_color
|
||||
|
||||
|
||||
# Color map
|
||||
_COLORS = (
|
||||
np.array(
|
||||
[
|
||||
0.000,
|
||||
0.447,
|
||||
0.741,
|
||||
0.850,
|
||||
0.325,
|
||||
0.098,
|
||||
0.929,
|
||||
0.694,
|
||||
0.125,
|
||||
0.494,
|
||||
0.184,
|
||||
0.556,
|
||||
0.466,
|
||||
0.674,
|
||||
0.188,
|
||||
0.301,
|
||||
0.745,
|
||||
0.933,
|
||||
0.635,
|
||||
0.078,
|
||||
0.184,
|
||||
0.300,
|
||||
0.300,
|
||||
0.300,
|
||||
0.600,
|
||||
0.600,
|
||||
0.600,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.749,
|
||||
0.749,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.667,
|
||||
0.000,
|
||||
1.000,
|
||||
0.333,
|
||||
0.333,
|
||||
0.000,
|
||||
0.333,
|
||||
0.667,
|
||||
0.000,
|
||||
0.333,
|
||||
1.000,
|
||||
0.000,
|
||||
0.667,
|
||||
0.333,
|
||||
0.000,
|
||||
0.667,
|
||||
0.667,
|
||||
0.000,
|
||||
0.667,
|
||||
1.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.333,
|
||||
0.000,
|
||||
1.000,
|
||||
0.667,
|
||||
0.000,
|
||||
1.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.333,
|
||||
0.500,
|
||||
0.000,
|
||||
0.667,
|
||||
0.500,
|
||||
0.000,
|
||||
1.000,
|
||||
0.500,
|
||||
0.333,
|
||||
0.000,
|
||||
0.500,
|
||||
0.333,
|
||||
0.333,
|
||||
0.500,
|
||||
0.333,
|
||||
0.667,
|
||||
0.500,
|
||||
0.333,
|
||||
1.000,
|
||||
0.500,
|
||||
0.667,
|
||||
0.000,
|
||||
0.500,
|
||||
0.667,
|
||||
0.333,
|
||||
0.500,
|
||||
0.667,
|
||||
0.667,
|
||||
0.500,
|
||||
0.667,
|
||||
1.000,
|
||||
0.500,
|
||||
1.000,
|
||||
0.000,
|
||||
0.500,
|
||||
1.000,
|
||||
0.333,
|
||||
0.500,
|
||||
1.000,
|
||||
0.667,
|
||||
0.500,
|
||||
1.000,
|
||||
1.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.333,
|
||||
1.000,
|
||||
0.000,
|
||||
0.667,
|
||||
1.000,
|
||||
0.000,
|
||||
1.000,
|
||||
1.000,
|
||||
0.333,
|
||||
0.000,
|
||||
1.000,
|
||||
0.333,
|
||||
0.333,
|
||||
1.000,
|
||||
0.333,
|
||||
0.667,
|
||||
1.000,
|
||||
0.333,
|
||||
1.000,
|
||||
1.000,
|
||||
0.667,
|
||||
0.000,
|
||||
1.000,
|
||||
0.667,
|
||||
0.333,
|
||||
1.000,
|
||||
0.667,
|
||||
0.667,
|
||||
1.000,
|
||||
0.667,
|
||||
1.000,
|
||||
1.000,
|
||||
1.000,
|
||||
0.000,
|
||||
1.000,
|
||||
1.000,
|
||||
0.333,
|
||||
1.000,
|
||||
1.000,
|
||||
0.667,
|
||||
1.000,
|
||||
0.333,
|
||||
0.000,
|
||||
0.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.000,
|
||||
0.667,
|
||||
0.000,
|
||||
0.000,
|
||||
0.833,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.167,
|
||||
0.000,
|
||||
0.000,
|
||||
0.333,
|
||||
0.000,
|
||||
0.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.000,
|
||||
0.667,
|
||||
0.000,
|
||||
0.000,
|
||||
0.833,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.167,
|
||||
0.000,
|
||||
0.000,
|
||||
0.333,
|
||||
0.000,
|
||||
0.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.000,
|
||||
0.667,
|
||||
0.000,
|
||||
0.000,
|
||||
0.833,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.143,
|
||||
0.143,
|
||||
0.143,
|
||||
0.857,
|
||||
0.857,
|
||||
0.857,
|
||||
1.000,
|
||||
1.000,
|
||||
1.000,
|
||||
]
|
||||
)
|
||||
.astype(np.float32)
|
||||
.reshape(-1, 3)
|
||||
)
|
||||
953
examples/research_projects/movement-pruning/masked_run_glue.py
Normal file
953
examples/research_projects/movement-pruning/masked_run_glue.py
Normal file
@@ -0,0 +1,953 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Fine-pruning Masked BERT on sequence classification on GLUE."""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import glue_compute_metrics as compute_metrics
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
from transformers import glue_output_modes as output_modes
|
||||
from transformers import glue_processors as processors
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
"masked_bert": (MaskedBertConfig, MaskedBertForSequenceClassification, BertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def schedule_threshold(
|
||||
step: int,
|
||||
total_step: int,
|
||||
warmup_steps: int,
|
||||
initial_threshold: float,
|
||||
final_threshold: float,
|
||||
initial_warmup: int,
|
||||
final_warmup: int,
|
||||
final_lambda: float,
|
||||
):
|
||||
if step <= initial_warmup * warmup_steps:
|
||||
threshold = initial_threshold
|
||||
elif step > (total_step - final_warmup * warmup_steps):
|
||||
threshold = final_threshold
|
||||
else:
|
||||
spars_warmup_steps = initial_warmup * warmup_steps
|
||||
spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
|
||||
mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
|
||||
regu_lambda = final_lambda * threshold / final_threshold
|
||||
return threshold, regu_lambda
|
||||
|
||||
|
||||
def regularization(model: nn.Module, mode: str):
|
||||
regu, counter = 0, 0
|
||||
for name, param in model.named_parameters():
|
||||
if "mask_scores" in name:
|
||||
if mode == "l1":
|
||||
regu += torch.norm(torch.sigmoid(param), p=1) / param.numel()
|
||||
elif mode == "l0":
|
||||
regu += torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1)).sum() / param.numel()
|
||||
else:
|
||||
ValueError("Don't know this mode.")
|
||||
counter += 1
|
||||
return regu / counter
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter(log_dir=args.output_dir)
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "mask_score" in n and p.requires_grad],
|
||||
"lr": args.mask_scores_learning_rate,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "mask_score" not in n and p.requires_grad and not any(nd in n for nd in no_decay)
|
||||
],
|
||||
"lr": args.learning_rate,
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "mask_score" not in n and p.requires_grad and any(nd in n for nd in no_decay)
|
||||
],
|
||||
"lr": args.learning_rate,
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
# Distillation
|
||||
if teacher is not None:
|
||||
logger.info(" Training with distillation")
|
||||
|
||||
global_step = 0
|
||||
# Global TopK
|
||||
if args.global_topk:
|
||||
threshold_mem = None
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to global_step of last saved checkpoint from model path
|
||||
try:
|
||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||
except ValueError:
|
||||
global_step = 0
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", global_step)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained,
|
||||
int(args.num_train_epochs),
|
||||
desc="Epoch",
|
||||
disable=args.local_rank not in [-1, 0],
|
||||
)
|
||||
set_seed(args) # Added here for reproducibility
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
threshold, regu_lambda = schedule_threshold(
|
||||
step=global_step,
|
||||
total_step=t_total,
|
||||
warmup_steps=args.warmup_steps,
|
||||
final_threshold=args.final_threshold,
|
||||
initial_threshold=args.initial_threshold,
|
||||
final_warmup=args.final_warmup,
|
||||
initial_warmup=args.initial_warmup,
|
||||
final_lambda=args.final_lambda,
|
||||
)
|
||||
# Global TopK
|
||||
if args.global_topk:
|
||||
if threshold == 1.0:
|
||||
threshold = -1e2 # Or an indefinitely low quantity
|
||||
else:
|
||||
if (threshold_mem is None) or (global_step % args.global_topk_frequency_compute == 0):
|
||||
# Sort all the values to get the global topK
|
||||
concat = torch.cat(
|
||||
[param.view(-1) for name, param in model.named_parameters() if "mask_scores" in name]
|
||||
)
|
||||
n = concat.numel()
|
||||
kth = max(n - (int(n * threshold) + 1), 1)
|
||||
threshold_mem = concat.kthvalue(kth).values.item()
|
||||
threshold = threshold_mem
|
||||
else:
|
||||
threshold = threshold_mem
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "masked_bert", "xlnet", "albert"] else None
|
||||
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
|
||||
|
||||
if "masked" in args.model_type:
|
||||
inputs["threshold"] = threshold
|
||||
|
||||
outputs = model(**inputs)
|
||||
loss, logits_stu = outputs # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
# Distillation loss
|
||||
if teacher is not None:
|
||||
if "token_type_ids" not in inputs:
|
||||
inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
|
||||
with torch.no_grad():
|
||||
(logits_tea,) = teacher(
|
||||
input_ids=inputs["input_ids"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
loss_logits = (
|
||||
F.kl_div(
|
||||
input=F.log_softmax(logits_stu / args.temperature, dim=-1),
|
||||
target=F.softmax(logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
|
||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||
|
||||
# Regularization
|
||||
if args.regularization is not None:
|
||||
regu_ = regularization(model=model, mode=args.regularization)
|
||||
loss = loss + regu_lambda * regu_
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0 or (
|
||||
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
||||
len(epoch_iterator) <= args.gradient_accumulation_steps
|
||||
and (step + 1) == len(epoch_iterator)
|
||||
):
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
tb_writer.add_scalar("threshold", threshold, global_step)
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
tb_writer.add_scalar("parameter_mean/" + name, param.data.mean(), global_step)
|
||||
tb_writer.add_scalar("parameter_std/" + name, param.data.std(), global_step)
|
||||
tb_writer.add_scalar("parameter_min/" + name, param.data.min(), global_step)
|
||||
tb_writer.add_scalar("parameter_max/" + name, param.data.max(), global_step)
|
||||
tb_writer.add_scalar("grad_mean/" + name, param.grad.data.mean(), global_step)
|
||||
tb_writer.add_scalar("grad_std/" + name, param.grad.data.std(), global_step)
|
||||
if args.regularization is not None and "mask_scores" in name:
|
||||
if args.regularization == "l1":
|
||||
perc = (torch.sigmoid(param) > threshold).sum().item() / param.numel()
|
||||
elif args.regularization == "l0":
|
||||
perc = (torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1))).sum().item() / param.numel()
|
||||
tb_writer.add_scalar("retained_weights_perc/" + name, perc, global_step)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
logs = {}
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
eval_key = "eval_{}".format(key)
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||
learning_rate_scalar = scheduler.get_lr()
|
||||
logs["learning_rate"] = learning_rate_scalar[0]
|
||||
if len(learning_rate_scalar) > 1:
|
||||
for idx, lr in enumerate(learning_rate_scalar[1:]):
|
||||
logs[f"learning_rate/{idx+1}"] = lr
|
||||
logs["loss"] = loss_scalar
|
||||
if teacher is not None:
|
||||
logs["loss/distil"] = loss_logits.item()
|
||||
if args.regularization is not None:
|
||||
logs["loss/regularization"] = regu_.item()
|
||||
if (teacher is not None) or (args.regularization is not None):
|
||||
if (teacher is not None) and (args.regularization is not None):
|
||||
logs["loss/instant_ce"] = (
|
||||
loss.item()
|
||||
- regu_lambda * logs["loss/regularization"]
|
||||
- args.alpha_distil * logs["loss/distil"]
|
||||
) / args.alpha_ce
|
||||
elif teacher is not None:
|
||||
logs["loss/instant_ce"] = (
|
||||
loss.item() - args.alpha_distil * logs["loss/distil"]
|
||||
) / args.alpha_ce
|
||||
else:
|
||||
logs["loss/instant_ce"] = loss.item() - regu_lambda * logs["loss/regularization"]
|
||||
logging_loss = tr_loss
|
||||
|
||||
for key, value in logs.items():
|
||||
tb_writer.add_scalar(key, value, global_step)
|
||||
print(json.dumps({**logs, **{"step": global_step}}))
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + "/MM") if args.task_name == "mnli" else (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
|
||||
# Global TopK
|
||||
if args.global_topk:
|
||||
threshold_mem = None
|
||||
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "masked_bert", "xlnet", "albert"] else None
|
||||
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
|
||||
if "masked" in args.model_type:
|
||||
inputs["threshold"] = args.final_threshold
|
||||
if args.global_topk:
|
||||
if threshold_mem is None:
|
||||
concat = torch.cat(
|
||||
[param.view(-1) for name, param in model.named_parameters() if "mask_scores" in name]
|
||||
)
|
||||
n = concat.numel()
|
||||
kth = max(n - (int(n * args.final_threshold) + 1), 1)
|
||||
threshold_mem = concat.kthvalue(kth).values.item()
|
||||
inputs["threshold"] = threshold_mem
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
from scipy.special import softmax
|
||||
|
||||
probs = softmax(preds, axis=-1)
|
||||
entropy = np.exp((-probs * np.log(probs)).sum(axis=-1).mean())
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif args.output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
if entropy is not None:
|
||||
result["eval_avg_entropy"] = entropy
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
processor = processors[task]()
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = (
|
||||
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
max_length=args.max_seq_length,
|
||||
label_list=label_list,
|
||||
output_mode=output_mode,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
if output_mode == "classification":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training",
|
||||
action="store_true",
|
||||
help="Run evaluation during training at each logging step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case",
|
||||
action="store_true",
|
||||
help="Set this flag if you are using an uncased model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--per_gpu_train_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
|
||||
# Pruning parameters
|
||||
parser.add_argument(
|
||||
"--mask_scores_learning_rate",
|
||||
default=1e-2,
|
||||
type=float,
|
||||
help="The Adam initial learning rate of the mask scores.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_threshold", default=1.0, type=float, help="Initial value of the threshold (for scheduling)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--final_threshold", default=0.7, type=float, help="Final value of the threshold (for scheduling)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_warmup",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
|
||||
"at its `initial_threshold` value (sparsity schedule).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--final_warmup",
|
||||
default=2,
|
||||
type=int,
|
||||
help="Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
|
||||
"at its final_threshold value (sparsity schedule).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pruning_method",
|
||||
default="topK",
|
||||
type=str,
|
||||
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mask_init",
|
||||
default="constant",
|
||||
type=str,
|
||||
help="Initialization method for the mask scores. Choices: constant, uniform, kaiming.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mask_scale", default=0.0, type=float, help="Initialization parameter for the chosen initialization method."
|
||||
)
|
||||
|
||||
parser.add_argument("--regularization", default=None, help="Add L0 or L1 regularization to the mask scores.")
|
||||
parser.add_argument(
|
||||
"--final_lambda",
|
||||
default=0.0,
|
||||
type=float,
|
||||
help="Regularization intensity (used in conjunction with `regularization`.",
|
||||
)
|
||||
|
||||
parser.add_argument("--global_topk", action="store_true", help="Global TopK on the Scores.")
|
||||
parser.add_argument(
|
||||
"--global_topk_frequency_compute",
|
||||
default=25,
|
||||
type=int,
|
||||
help="Frequency at which we compute the TopK global threshold.",
|
||||
)
|
||||
|
||||
# Distillation parameters (optional)
|
||||
parser.add_argument(
|
||||
"--teacher_type",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--teacher_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the already fine-tuned teacher model. Only for distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_ce", default=0.5, type=float, help="Cross entropy loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_distil", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs",
|
||||
default=3.0,
|
||||
type=float,
|
||||
help="Total number of training epochs to perform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir",
|
||||
action="store_true",
|
||||
help="Overwrite the content of the output directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache",
|
||||
action="store_true",
|
||||
help="Overwrite the cached training and evaluation sets",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Regularization
|
||||
if args.regularization == "null":
|
||||
args.regularization = None
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name]()
|
||||
args.output_mode = output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
pruning_method=args.pruning_method,
|
||||
mask_init=args.mask_init,
|
||||
mask_scale=args.mask_scale,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
do_lower_case=args.do_lower_case,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.teacher_type is not None:
|
||||
assert args.teacher_name_or_path is not None
|
||||
assert args.alpha_distil > 0.0
|
||||
assert args.alpha_distil + args.alpha_ce > 0.0
|
||||
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
|
||||
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path)
|
||||
teacher = teacher_model_class.from_pretrained(
|
||||
args.teacher_name_or_path,
|
||||
from_tf=False,
|
||||
config=teacher_config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
teacher.to(args.device)
|
||||
else:
|
||||
teacher = None
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1133
examples/research_projects/movement-pruning/masked_run_squad.py
Normal file
1133
examples/research_projects/movement-pruning/masked_run_squad.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,6 @@
|
||||
torch>=1.4.0
|
||||
-e git+https://github.com/huggingface/transformers.git@352d5472b0c1dec0f420d606d16747d851b4bda8#egg=transformers
|
||||
knockknock>=0.1.8.1
|
||||
h5py>=2.10.0
|
||||
numpy>=1.18.2
|
||||
scipy>=1.4.1
|
||||
54
examples/research_projects/pplm/README.md
Normal file
54
examples/research_projects/pplm/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Plug and Play Language Models: a Simple Approach to Controlled Text Generation
|
||||
|
||||
Authors: [Sumanth Dathathri](https://dathath.github.io/), [Andrea Madotto](https://andreamad8.github.io/), Janice Lan, Jane Hung, Eric Frank, [Piero Molino](https://w4nderlu.st/), [Jason Yosinski](http://yosinski.com/), and [Rosanne Liu](http://www.rosanneliu.com/)
|
||||
|
||||
This folder contains the original code used to run the Plug and Play Language Model (PPLM).
|
||||
|
||||
Paper link: https://arxiv.org/abs/1912.02164
|
||||
|
||||
Blog link: https://eng.uber.com/pplm
|
||||
|
||||
Please check out the repo under uber-research for more information: https://github.com/uber-research/PPLM
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/transformers && cd transformers
|
||||
pip install .
|
||||
pip install nltk torchtext # additional requirements.
|
||||
cd examples/text-generation/pplm
|
||||
```
|
||||
|
||||
## PPLM-BoW
|
||||
|
||||
### Example command for bag-of-words control
|
||||
|
||||
```bash
|
||||
python run_pplm.py -B military --cond_text "The potato" --length 50 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.03 --window_length 5 --kl_scale 0.01 --gm_scale 0.99 --colorama --sample
|
||||
```
|
||||
|
||||
### Tuning hyperparameters for bag-of-words control
|
||||
|
||||
1. Increase `--stepsize` to intensify topic control, and decrease its value to soften the control. `--stepsize 0` recovers the original uncontrolled GPT-2 model.
|
||||
|
||||
2. If the language being generated is repetitive (For e.g. "science science experiment experiment"), there are several options to consider: </br>
|
||||
a) Reduce the `--stepsize` </br>
|
||||
b) Increase `--kl_scale` (the KL-loss coefficient) or decrease `--gm_scale` (the gm-scaling term) </br>
|
||||
c) Add `--grad-length xx` where xx is an (integer <= length, e.g. `--grad-length 30`).</br>
|
||||
|
||||
|
||||
## PPLM-Discrim
|
||||
|
||||
### Example command for discriminator based sentiment control
|
||||
|
||||
```bash
|
||||
python run_pplm.py -D sentiment --class_label 2 --cond_text "My dog died" --length 50 --gamma 1.0 --num_iterations 10 --num_samples 10 --stepsize 0.04 --kl_scale 0.01 --gm_scale 0.95 --sample
|
||||
```
|
||||
|
||||
### Tuning hyperparameters for discriminator control
|
||||
|
||||
1. Increase `--stepsize` to intensify topic control, and decrease its value to soften the control. `--stepsize 0` recovers the original uncontrolled GPT-2 model.
|
||||
|
||||
2. Use `--class_label 3` for negative, and `--class_label 2` for positive
|
||||
|
||||
BIN
examples/research_projects/pplm/imgs/headfigure.png
Normal file
BIN
examples/research_projects/pplm/imgs/headfigure.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 653 KiB |
BIN
examples/research_projects/pplm/imgs/wooly.png
Normal file
BIN
examples/research_projects/pplm/imgs/wooly.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 664 KiB |
19
examples/research_projects/pplm/pplm_classification_head.py
Normal file
19
examples/research_projects/pplm/pplm_classification_head.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import torch
|
||||
|
||||
|
||||
class ClassificationHead(torch.nn.Module):
|
||||
"""Classification Head for transformer encoders"""
|
||||
|
||||
def __init__(self, class_size, embed_size):
|
||||
super().__init__()
|
||||
self.class_size = class_size
|
||||
self.embed_size = embed_size
|
||||
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
|
||||
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
|
||||
self.mlp = torch.nn.Linear(embed_size, class_size)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
# hidden_state = F.relu(self.mlp1(hidden_state))
|
||||
# hidden_state = self.mlp2(hidden_state)
|
||||
logits = self.mlp(hidden_state)
|
||||
return logits
|
||||
22
examples/research_projects/pplm/requirements.txt
Normal file
22
examples/research_projects/pplm/requirements.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
tensorboard
|
||||
scikit-learn
|
||||
seqeval
|
||||
psutil
|
||||
sacrebleu
|
||||
rouge-score
|
||||
tensorflow_datasets
|
||||
pytorch-lightning==1.0.4
|
||||
matplotlib
|
||||
git-python==1.0.3
|
||||
faiss-cpu
|
||||
streamlit
|
||||
elasticsearch
|
||||
nltk
|
||||
pandas
|
||||
datasets >= 1.1.3
|
||||
fire
|
||||
pytest
|
||||
conllu
|
||||
sentencepiece != 0.1.92
|
||||
protobuf
|
||||
transformers==3.5.1
|
||||
820
examples/research_projects/pplm/run_pplm.py
Normal file
820
examples/research_projects/pplm/run_pplm.py
Normal file
@@ -0,0 +1,820 @@
|
||||
#! /usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright (c) 2019 Uber Technologies, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example command with bag of words:
|
||||
python run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
|
||||
|
||||
Example command with discriminator:
|
||||
python run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from operator import add
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tqdm import trange
|
||||
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
from transformers.file_utils import cached_path
|
||||
|
||||
|
||||
PPLM_BOW = 1
|
||||
PPLM_DISCRIM = 2
|
||||
PPLM_BOW_DISCRIM = 3
|
||||
SMALL_CONST = 1e-15
|
||||
BIG_CONST = 1e10
|
||||
|
||||
BAG_OF_WORDS_ARCHIVE_MAP = {
|
||||
"legal": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
|
||||
"military": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
|
||||
"politics": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
|
||||
"religion": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
|
||||
"science": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
|
||||
"space": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
|
||||
"technology": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
|
||||
}
|
||||
|
||||
DISCRIMINATOR_MODELS_PARAMS = {
|
||||
"clickbait": {
|
||||
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
|
||||
"class_size": 2,
|
||||
"embed_size": 1024,
|
||||
"class_vocab": {"non_clickbait": 0, "clickbait": 1},
|
||||
"default_class": 1,
|
||||
"pretrained_model": "gpt2-medium",
|
||||
},
|
||||
"sentiment": {
|
||||
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
|
||||
"class_size": 5,
|
||||
"embed_size": 1024,
|
||||
"class_vocab": {"very_positive": 2, "very_negative": 3},
|
||||
"default_class": 3,
|
||||
"pretrained_model": "gpt2-medium",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def top_k_filter(logits, k, probs=False):
|
||||
"""
|
||||
Masks everything but the k top entries as -infinity (1e10).
|
||||
Used to mask logits such that e^-infinity -> 0 won't contribute to the
|
||||
sum of the denominator.
|
||||
"""
|
||||
if k == 0:
|
||||
return logits
|
||||
else:
|
||||
values = torch.topk(logits, k)[0]
|
||||
batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
|
||||
if probs:
|
||||
return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
|
||||
return torch.where(logits < batch_mins, torch.ones_like(logits) * -BIG_CONST, logits)
|
||||
|
||||
|
||||
def perturb_past(
|
||||
past,
|
||||
model,
|
||||
last,
|
||||
unpert_past=None,
|
||||
unpert_logits=None,
|
||||
accumulated_hidden=None,
|
||||
grad_norms=None,
|
||||
stepsize=0.01,
|
||||
one_hot_bows_vectors=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
loss_type=0,
|
||||
num_iterations=3,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
kl_scale=0.01,
|
||||
device="cuda",
|
||||
):
|
||||
# Generate inital perturbed past
|
||||
grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]
|
||||
|
||||
if accumulated_hidden is None:
|
||||
accumulated_hidden = 0
|
||||
|
||||
if decay:
|
||||
decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
|
||||
else:
|
||||
decay_mask = 1.0
|
||||
|
||||
# TODO fix this comment (SUMANTH)
|
||||
# Generate a mask is gradient perturbated is based on a past window
|
||||
_, _, _, curr_length, _ = past[0].shape
|
||||
|
||||
if curr_length > window_length and window_length > 0:
|
||||
ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
|
||||
|
||||
zeros_key_val_shape = (
|
||||
tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
|
||||
)
|
||||
|
||||
ones_mask = torch.ones(ones_key_val_shape)
|
||||
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
||||
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
|
||||
|
||||
window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).to(device)
|
||||
else:
|
||||
window_mask = torch.ones_like(past[0]).to(device)
|
||||
|
||||
# accumulate perturbations for num_iterations
|
||||
loss_per_iter = []
|
||||
new_accumulated_hidden = None
|
||||
for i in range(num_iterations):
|
||||
print("Iteration ", i + 1)
|
||||
curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
|
||||
# make sure p_.grad is not None
|
||||
for p_ in curr_perturbation:
|
||||
p_.retain_grad()
|
||||
|
||||
# Compute hidden using perturbed past
|
||||
perturbed_past = list(map(add, past, curr_perturbation))
|
||||
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
||||
lm_output = model(last, past_key_values=perturbed_past)
|
||||
all_logits, all_hidden = lm_output["logits"], lm_output["hidden_states"]
|
||||
hidden = all_hidden[-1]
|
||||
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||
logits = all_logits[:, -1, :]
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
||||
loss = 0.0
|
||||
loss_list = []
|
||||
if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
|
||||
for one_hot_bow in one_hot_bows_vectors:
|
||||
bow_logits = torch.mm(probs, torch.t(one_hot_bow))
|
||||
bow_loss = -torch.log(torch.sum(bow_logits))
|
||||
loss += bow_loss
|
||||
loss_list.append(bow_loss)
|
||||
print(" pplm_bow_loss:", loss.data.cpu().numpy())
|
||||
|
||||
if loss_type == 2 or loss_type == 3:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
|
||||
curr_unpert_past = unpert_past
|
||||
curr_probs = torch.unsqueeze(probs, dim=1)
|
||||
wte = model.resize_token_embeddings()
|
||||
for _ in range(horizon_length):
|
||||
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
||||
lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
|
||||
curr_unpert_past, curr_all_hidden = lm_output["past_key_values"], lm_output["hidden_states"]
|
||||
curr_hidden = curr_all_hidden[-1]
|
||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
|
||||
|
||||
prediction = classifier(new_accumulated_hidden / (curr_length + 1 + horizon_length))
|
||||
|
||||
label = torch.tensor(prediction.shape[0] * [class_label], device=device, dtype=torch.long)
|
||||
discrim_loss = ce_loss(prediction, label)
|
||||
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
||||
loss += discrim_loss
|
||||
loss_list.append(discrim_loss)
|
||||
|
||||
kl_loss = 0.0
|
||||
if kl_scale > 0.0:
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
||||
corrected_probs = probs + correction.detach()
|
||||
kl_loss = kl_scale * ((corrected_probs * (corrected_probs / unpert_probs).log()).sum())
|
||||
print(" kl_loss", kl_loss.data.cpu().numpy())
|
||||
loss += kl_loss
|
||||
|
||||
loss_per_iter.append(loss.data.cpu().numpy())
|
||||
print(" pplm_loss", (loss - kl_loss).data.cpu().numpy())
|
||||
|
||||
# compute gradients
|
||||
loss.backward()
|
||||
|
||||
# calculate gradient norms
|
||||
if grad_norms is not None and loss_type == PPLM_BOW:
|
||||
grad_norms = [
|
||||
torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
else:
|
||||
grad_norms = [
|
||||
(torch.norm(p_.grad * window_mask) + SMALL_CONST) for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
|
||||
# normalize gradients
|
||||
grad = [
|
||||
-stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
|
||||
# accumulate gradient
|
||||
grad_accumulator = list(map(add, grad, grad_accumulator))
|
||||
|
||||
# reset gradients, just to make sure
|
||||
for p_ in curr_perturbation:
|
||||
p_.grad.data.zero_()
|
||||
|
||||
# removing past from the graph
|
||||
new_past = []
|
||||
for p_ in past:
|
||||
new_past.append(p_.detach())
|
||||
past = new_past
|
||||
|
||||
# apply the accumulated perturbations to the past
|
||||
grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
|
||||
pert_past = list(map(add, past, grad_accumulator))
|
||||
|
||||
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
||||
|
||||
|
||||
def get_classifier(
|
||||
name: Optional[str], class_label: Union[str, int], device: str
|
||||
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
||||
if name is None:
|
||||
return None, None
|
||||
|
||||
params = DISCRIMINATOR_MODELS_PARAMS[name]
|
||||
classifier = ClassificationHead(class_size=params["class_size"], embed_size=params["embed_size"]).to(device)
|
||||
if "url" in params:
|
||||
resolved_archive_file = cached_path(params["url"])
|
||||
elif "path" in params:
|
||||
resolved_archive_file = params["path"]
|
||||
else:
|
||||
raise ValueError("Either url or path have to be specified in the discriminator model parameters")
|
||||
classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
|
||||
classifier.eval()
|
||||
|
||||
if isinstance(class_label, str):
|
||||
if class_label in params["class_vocab"]:
|
||||
label_id = params["class_vocab"][class_label]
|
||||
else:
|
||||
label_id = params["default_class"]
|
||||
print("class_label {} not in class_vocab".format(class_label))
|
||||
print("available values are: {}".format(params["class_vocab"]))
|
||||
print("using default class {}".format(label_id))
|
||||
|
||||
elif isinstance(class_label, int):
|
||||
if class_label in set(params["class_vocab"].values()):
|
||||
label_id = class_label
|
||||
else:
|
||||
label_id = params["default_class"]
|
||||
print("class_label {} not in class_vocab".format(class_label))
|
||||
print("available values are: {}".format(params["class_vocab"]))
|
||||
print("using default class {}".format(label_id))
|
||||
|
||||
else:
|
||||
label_id = params["default_class"]
|
||||
|
||||
return classifier, label_id
|
||||
|
||||
|
||||
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> List[List[List[int]]]:
|
||||
bow_indices = []
|
||||
for id_or_path in bag_of_words_ids_or_paths:
|
||||
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
|
||||
filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
|
||||
else:
|
||||
filepath = id_or_path
|
||||
with open(filepath, "r") as f:
|
||||
words = f.read().strip().split("\n")
|
||||
bow_indices.append([tokenizer.encode(word.strip(), add_prefix_space=True) for word in words])
|
||||
return bow_indices
|
||||
|
||||
|
||||
def build_bows_one_hot_vectors(bow_indices, tokenizer, device="cuda"):
|
||||
if bow_indices is None:
|
||||
return None
|
||||
|
||||
one_hot_bows_vectors = []
|
||||
for single_bow in bow_indices:
|
||||
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
|
||||
single_bow = torch.tensor(single_bow).to(device)
|
||||
num_words = single_bow.shape[0]
|
||||
one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
|
||||
one_hot_bow.scatter_(1, single_bow, 1)
|
||||
one_hot_bows_vectors.append(one_hot_bow)
|
||||
return one_hot_bows_vectors
|
||||
|
||||
|
||||
def full_text_generation(
|
||||
model,
|
||||
tokenizer,
|
||||
context=None,
|
||||
num_samples=1,
|
||||
device="cuda",
|
||||
bag_of_words=None,
|
||||
discrim=None,
|
||||
class_label=None,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
repetition_penalty=1.0,
|
||||
**kwargs
|
||||
):
|
||||
classifier, class_id = get_classifier(discrim, class_label, device)
|
||||
|
||||
bow_indices = []
|
||||
if bag_of_words:
|
||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
|
||||
|
||||
if bag_of_words and classifier:
|
||||
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
|
||||
loss_type = PPLM_BOW_DISCRIM
|
||||
|
||||
elif bag_of_words:
|
||||
loss_type = PPLM_BOW
|
||||
print("Using PPLM-BoW")
|
||||
|
||||
elif classifier is not None:
|
||||
loss_type = PPLM_DISCRIM
|
||||
print("Using PPLM-Discrim")
|
||||
|
||||
else:
|
||||
raise Exception("Specify either a bag of words or a discriminator")
|
||||
|
||||
unpert_gen_tok_text, _, _ = generate_text_pplm(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
context=context,
|
||||
device=device,
|
||||
length=length,
|
||||
sample=sample,
|
||||
perturb=False,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
if device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
pert_gen_tok_texts = []
|
||||
discrim_losses = []
|
||||
losses_in_time = []
|
||||
|
||||
for i in range(num_samples):
|
||||
pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
context=context,
|
||||
device=device,
|
||||
perturb=True,
|
||||
bow_indices=bow_indices,
|
||||
classifier=classifier,
|
||||
class_label=class_id,
|
||||
loss_type=loss_type,
|
||||
length=length,
|
||||
stepsize=stepsize,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
sample=sample,
|
||||
num_iterations=num_iterations,
|
||||
grad_length=grad_length,
|
||||
horizon_length=horizon_length,
|
||||
window_length=window_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
gm_scale=gm_scale,
|
||||
kl_scale=kl_scale,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
pert_gen_tok_texts.append(pert_gen_tok_text)
|
||||
if classifier is not None:
|
||||
discrim_losses.append(discrim_loss.data.cpu().numpy())
|
||||
losses_in_time.append(loss_in_time)
|
||||
|
||||
if device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||
|
||||
|
||||
def generate_text_pplm(
|
||||
model,
|
||||
tokenizer,
|
||||
context=None,
|
||||
past=None,
|
||||
device="cuda",
|
||||
perturb=True,
|
||||
bow_indices=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
loss_type=0,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
repetition_penalty=1.0,
|
||||
):
|
||||
output_so_far = None
|
||||
if context:
|
||||
context_t = torch.tensor(context, device=device, dtype=torch.long)
|
||||
while len(context_t.shape) < 2:
|
||||
context_t = context_t.unsqueeze(0)
|
||||
output_so_far = context_t
|
||||
|
||||
# collect one hot vectors for bags of words
|
||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)
|
||||
|
||||
grad_norms = None
|
||||
last = None
|
||||
unpert_discrim_loss = 0
|
||||
loss_in_time = []
|
||||
for i in trange(length, ascii=True):
|
||||
|
||||
# Get past/probs for current output, except for last word
|
||||
# Note that GPT takes 2 inputs: past + current_token
|
||||
|
||||
# run model forward to obtain unperturbed
|
||||
if past is None and output_so_far is not None:
|
||||
last = output_so_far[:, -1:]
|
||||
if output_so_far.shape[1] > 1:
|
||||
past = model(output_so_far[:, :-1])["past_key_values"]
|
||||
|
||||
lm_output = model(output_so_far)
|
||||
unpert_logits, unpert_past, unpert_all_hidden = (
|
||||
lm_output["logits"],
|
||||
lm_output["past_key_values"],
|
||||
lm_output["hidden_states"],
|
||||
)
|
||||
unpert_last_hidden = unpert_all_hidden[-1]
|
||||
|
||||
# check if we are abowe grad max length
|
||||
if i >= grad_length:
|
||||
current_stepsize = stepsize * 0
|
||||
else:
|
||||
current_stepsize = stepsize
|
||||
|
||||
# modify the past if necessary
|
||||
if not perturb or num_iterations == 0:
|
||||
pert_past = past
|
||||
|
||||
else:
|
||||
accumulated_hidden = unpert_last_hidden[:, :-1, :]
|
||||
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
|
||||
|
||||
if past is not None:
|
||||
pert_past, _, grad_norms, loss_this_iter = perturb_past(
|
||||
past,
|
||||
model,
|
||||
last,
|
||||
unpert_past=unpert_past,
|
||||
unpert_logits=unpert_logits,
|
||||
accumulated_hidden=accumulated_hidden,
|
||||
grad_norms=grad_norms,
|
||||
stepsize=current_stepsize,
|
||||
one_hot_bows_vectors=one_hot_bows_vectors,
|
||||
classifier=classifier,
|
||||
class_label=class_label,
|
||||
loss_type=loss_type,
|
||||
num_iterations=num_iterations,
|
||||
horizon_length=horizon_length,
|
||||
window_length=window_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
kl_scale=kl_scale,
|
||||
device=device,
|
||||
)
|
||||
loss_in_time.append(loss_this_iter)
|
||||
else:
|
||||
pert_past = past
|
||||
|
||||
lm_output = model(last, past_key_values=pert_past)
|
||||
pert_logits, past = (
|
||||
lm_output["logits"],
|
||||
lm_output["past_key_values"],
|
||||
)
|
||||
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
||||
|
||||
for token_idx in set(output_so_far[0].tolist()):
|
||||
if pert_logits[0, token_idx] < 0:
|
||||
pert_logits[0, token_idx] *= repetition_penalty
|
||||
else:
|
||||
pert_logits[0, token_idx] /= repetition_penalty
|
||||
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
|
||||
if classifier is not None:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||
label = torch.tensor([class_label], device=device, dtype=torch.long)
|
||||
unpert_discrim_loss = ce_loss(prediction, label)
|
||||
print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
|
||||
else:
|
||||
unpert_discrim_loss = 0
|
||||
|
||||
# Fuse the modified model and original model
|
||||
if perturb:
|
||||
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
|
||||
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
|
||||
|
||||
# rescale
|
||||
if torch.sum(pert_probs) <= 1:
|
||||
pert_probs = pert_probs / torch.sum(pert_probs)
|
||||
|
||||
else:
|
||||
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
|
||||
# sample or greedy
|
||||
if sample:
|
||||
last = torch.multinomial(pert_probs, num_samples=1)
|
||||
|
||||
else:
|
||||
_, last = torch.topk(pert_probs, k=1, dim=-1)
|
||||
|
||||
# update context/output_so_far appending the new token
|
||||
output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)
|
||||
|
||||
print(tokenizer.decode(output_so_far.tolist()[0]))
|
||||
|
||||
return output_so_far, unpert_discrim_loss, loss_in_time
|
||||
|
||||
|
||||
def set_generic_model_params(discrim_weights, discrim_meta):
|
||||
if discrim_weights is None:
|
||||
raise ValueError("When using a generic discriminator, discrim_weights need to be specified")
|
||||
if discrim_meta is None:
|
||||
raise ValueError("When using a generic discriminator, discrim_meta need to be specified")
|
||||
|
||||
with open(discrim_meta, "r") as discrim_meta_file:
|
||||
meta = json.load(discrim_meta_file)
|
||||
meta["path"] = discrim_weights
|
||||
DISCRIMINATOR_MODELS_PARAMS["generic"] = meta
|
||||
|
||||
|
||||
def run_pplm_example(
|
||||
pretrained_model="gpt2-medium",
|
||||
cond_text="",
|
||||
uncond=False,
|
||||
num_samples=1,
|
||||
bag_of_words=None,
|
||||
discrim=None,
|
||||
discrim_weights=None,
|
||||
discrim_meta=None,
|
||||
class_label=-1,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
seed=0,
|
||||
no_cuda=False,
|
||||
colorama=False,
|
||||
repetition_penalty=1.0,
|
||||
):
|
||||
# set Random seed
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
# set the device
|
||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||
|
||||
if discrim == "generic":
|
||||
set_generic_model_params(discrim_weights, discrim_meta)
|
||||
|
||||
if discrim is not None:
|
||||
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
|
||||
print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))
|
||||
|
||||
# load pretrained model
|
||||
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# load tokenizer
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||
|
||||
# Freeze GPT-2 weights
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# figure out conditioning text
|
||||
if uncond:
|
||||
tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
|
||||
else:
|
||||
raw_text = cond_text
|
||||
while not raw_text:
|
||||
print("Did you forget to add `--cond_text`? ")
|
||||
raw_text = input("Model prompt >>> ")
|
||||
tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)
|
||||
|
||||
print("= Prefix of sentence =")
|
||||
print(tokenizer.decode(tokenized_cond_text))
|
||||
print()
|
||||
|
||||
# generate unperturbed and perturbed texts
|
||||
|
||||
# full_text_generation returns:
|
||||
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
context=tokenized_cond_text,
|
||||
device=device,
|
||||
num_samples=num_samples,
|
||||
bag_of_words=bag_of_words,
|
||||
discrim=discrim,
|
||||
class_label=class_label,
|
||||
length=length,
|
||||
stepsize=stepsize,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
sample=sample,
|
||||
num_iterations=num_iterations,
|
||||
grad_length=grad_length,
|
||||
horizon_length=horizon_length,
|
||||
window_length=window_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
gm_scale=gm_scale,
|
||||
kl_scale=kl_scale,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
|
||||
# untokenize unperturbed text
|
||||
unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])
|
||||
|
||||
print("=" * 80)
|
||||
print("= Unperturbed generated text =")
|
||||
print(unpert_gen_text)
|
||||
print()
|
||||
|
||||
generated_texts = []
|
||||
|
||||
bow_word_ids = set()
|
||||
if bag_of_words and colorama:
|
||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
|
||||
for single_bow_list in bow_indices:
|
||||
# filtering all words in the list composed of more than 1 token
|
||||
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
|
||||
# w[0] because we are sure w has only 1 item because previous fitler
|
||||
bow_word_ids.update(w[0] for w in filtered)
|
||||
|
||||
# iterate through the perturbed texts
|
||||
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
|
||||
try:
|
||||
# untokenize unperturbed text
|
||||
if colorama:
|
||||
import colorama
|
||||
|
||||
pert_gen_text = ""
|
||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||
if word_id in bow_word_ids:
|
||||
pert_gen_text += "{}{}{}".format(
|
||||
colorama.Fore.RED,
|
||||
tokenizer.decode([word_id]),
|
||||
colorama.Style.RESET_ALL,
|
||||
)
|
||||
else:
|
||||
pert_gen_text += tokenizer.decode([word_id])
|
||||
else:
|
||||
pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
|
||||
|
||||
print("= Perturbed generated text {} =".format(i + 1))
|
||||
print(pert_gen_text)
|
||||
print()
|
||||
except Exception as exc:
|
||||
print("Ignoring error while generating perturbed text:", exc)
|
||||
|
||||
# keep the prefix, perturbed seq, original seq for each index
|
||||
generated_texts.append((tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text))
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pretrained_model",
|
||||
"-M",
|
||||
type=str,
|
||||
default="gpt2-medium",
|
||||
help="pretrained model name or path to local checkpoint",
|
||||
)
|
||||
parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
|
||||
parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
|
||||
parser.add_argument(
|
||||
"--num_samples",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of samples to generate from the modified latents",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bag_of_words",
|
||||
"-B",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Bags of words used for PPLM-BoW. "
|
||||
"Either a BOW id (see list in code) or a filepath. "
|
||||
"Multiple BoWs separated by ;"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim",
|
||||
"-D",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=("clickbait", "sentiment", "toxicity", "generic"),
|
||||
help="Discriminator to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim_weights",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Weights for the generic discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim_meta",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Meta information for the generic discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_label",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Class label used for the discriminator",
|
||||
)
|
||||
parser.add_argument("--length", type=int, default=100)
|
||||
parser.add_argument("--stepsize", type=float, default=0.02)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--top_k", type=int, default=10)
|
||||
parser.add_argument("--sample", action="store_true", help="Generate from end-of-text as prefix")
|
||||
parser.add_argument("--num_iterations", type=int, default=3)
|
||||
parser.add_argument("--grad_length", type=int, default=10000)
|
||||
parser.add_argument(
|
||||
"--window_length",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Length of past which is being optimized; 0 corresponds to infinite window length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon_length",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Length of future to optimize over",
|
||||
)
|
||||
parser.add_argument("--decay", action="store_true", help="whether to decay or not")
|
||||
parser.add_argument("--gamma", type=float, default=1.5)
|
||||
parser.add_argument("--gm_scale", type=float, default=0.9)
|
||||
parser.add_argument("--kl_scale", type=float, default=0.01)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||
parser.add_argument("--colorama", action="store_true", help="colors keywords")
|
||||
parser.add_argument(
|
||||
"--repetition_penalty",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Penalize repetition. More than 1.0 -> less repetition",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
run_pplm_example(**vars(args))
|
||||
522
examples/research_projects/pplm/run_pplm_discrim_train.py
Normal file
522
examples/research_projects/pplm/run_pplm_discrim_train.py
Normal file
@@ -0,0 +1,522 @@
|
||||
#! /usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright (c) 2019 Uber Technologies, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.utils.data as data
|
||||
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
||||
from torchtext import data as torchtext_data
|
||||
from torchtext import datasets
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
EPSILON = 1e-10
|
||||
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
|
||||
max_length_seq = 100
|
||||
|
||||
|
||||
class Discriminator(torch.nn.Module):
|
||||
"""Transformer encoder followed by a Classification Head"""
|
||||
|
||||
def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
|
||||
super().__init__()
|
||||
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||
self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
|
||||
self.embed_size = self.encoder.transformer.config.hidden_size
|
||||
self.classifier_head = ClassificationHead(class_size=class_size, embed_size=self.embed_size)
|
||||
self.cached_mode = cached_mode
|
||||
self.device = device
|
||||
|
||||
def get_classifier(self):
|
||||
return self.classifier_head
|
||||
|
||||
def train_custom(self):
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
self.classifier_head.train()
|
||||
|
||||
def avg_representation(self, x):
|
||||
mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
|
||||
hidden = self.encoder.transformer(x)["last_hidden_state"]
|
||||
masked_hidden = hidden * mask
|
||||
avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
|
||||
return avg_hidden
|
||||
|
||||
def forward(self, x):
|
||||
if self.cached_mode:
|
||||
avg_hidden = x.to(self.device)
|
||||
else:
|
||||
avg_hidden = self.avg_representation(x.to(self.device))
|
||||
|
||||
logits = self.classifier_head(avg_hidden)
|
||||
probs = F.log_softmax(logits, dim=-1)
|
||||
|
||||
return probs
|
||||
|
||||
|
||||
class Dataset(data.Dataset):
|
||||
def __init__(self, X, y):
|
||||
"""Reads source and target sequences from txt files."""
|
||||
self.X = X
|
||||
self.y = y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.X)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Returns one data pair (source and target)."""
|
||||
data = {}
|
||||
data["X"] = self.X[index]
|
||||
data["y"] = self.y[index]
|
||||
return data
|
||||
|
||||
|
||||
def collate_fn(data):
|
||||
def pad_sequences(sequences):
|
||||
lengths = [len(seq) for seq in sequences]
|
||||
|
||||
padded_sequences = torch.zeros(len(sequences), max(lengths)).long() # padding value = 0
|
||||
|
||||
for i, seq in enumerate(sequences):
|
||||
end = lengths[i]
|
||||
padded_sequences[i, :end] = seq[:end]
|
||||
|
||||
return padded_sequences, lengths
|
||||
|
||||
item_info = {}
|
||||
for key in data[0].keys():
|
||||
item_info[key] = [d[key] for d in data]
|
||||
|
||||
x_batch, _ = pad_sequences(item_info["X"])
|
||||
y_batch = torch.tensor(item_info["y"], dtype=torch.long)
|
||||
|
||||
return x_batch, y_batch
|
||||
|
||||
|
||||
def cached_collate_fn(data):
|
||||
item_info = {}
|
||||
for key in data[0].keys():
|
||||
item_info[key] = [d[key] for d in data]
|
||||
|
||||
x_batch = torch.cat(item_info["X"], 0)
|
||||
y_batch = torch.tensor(item_info["y"], dtype=torch.long)
|
||||
|
||||
return x_batch, y_batch
|
||||
|
||||
|
||||
def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10, device="cpu"):
|
||||
samples_so_far = 0
|
||||
discriminator.train_custom()
|
||||
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
||||
input_t, target_t = input_t.to(device), target_t.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
output_t = discriminator(input_t)
|
||||
loss = F.nll_loss(output_t, target_t)
|
||||
loss.backward(retain_graph=True)
|
||||
optimizer.step()
|
||||
|
||||
samples_so_far += len(input_t)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
print(
|
||||
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
||||
epoch + 1,
|
||||
samples_so_far,
|
||||
len(data_loader.dataset),
|
||||
100 * samples_so_far / len(data_loader.dataset),
|
||||
loss.item(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def evaluate_performance(data_loader, discriminator, device="cpu"):
|
||||
discriminator.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for input_t, target_t in data_loader:
|
||||
input_t, target_t = input_t.to(device), target_t.to(device)
|
||||
output_t = discriminator(input_t)
|
||||
# sum up batch loss
|
||||
test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
|
||||
# get the index of the max log-probability
|
||||
pred_t = output_t.argmax(dim=1, keepdim=True)
|
||||
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
|
||||
|
||||
test_loss /= len(data_loader.dataset)
|
||||
|
||||
print(
|
||||
"Performance on test set: "
|
||||
"Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
|
||||
test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def predict(input_sentence, model, classes, cached=False, device="cpu"):
|
||||
input_t = model.tokenizer.encode(input_sentence)
|
||||
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
||||
if cached:
|
||||
input_t = model.avg_representation(input_t)
|
||||
|
||||
log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
|
||||
print("Input sentence:", input_sentence)
|
||||
print(
|
||||
"Predictions:",
|
||||
", ".join("{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in zip(classes, log_probs)),
|
||||
)
|
||||
|
||||
|
||||
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False, device="cpu"):
|
||||
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)
|
||||
|
||||
xs = []
|
||||
ys = []
|
||||
for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
|
||||
with torch.no_grad():
|
||||
x = x.to(device)
|
||||
avg_rep = discriminator.avg_representation(x).cpu().detach()
|
||||
avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
|
||||
xs += avg_rep_list
|
||||
ys += y.cpu().numpy().tolist()
|
||||
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset=Dataset(xs, ys), batch_size=batch_size, shuffle=shuffle, collate_fn=cached_collate_fn
|
||||
)
|
||||
|
||||
return data_loader
|
||||
|
||||
|
||||
def train_discriminator(
|
||||
dataset,
|
||||
dataset_fp=None,
|
||||
pretrained_model="gpt2-medium",
|
||||
epochs=10,
|
||||
batch_size=64,
|
||||
log_interval=10,
|
||||
save_model=False,
|
||||
cached=False,
|
||||
no_cuda=False,
|
||||
):
|
||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||
|
||||
print("Preprocessing {} dataset...".format(dataset))
|
||||
start = time.time()
|
||||
|
||||
if dataset == "SST":
|
||||
idx2class = ["positive", "negative", "very positive", "very negative", "neutral"]
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||
).to(device)
|
||||
|
||||
text = torchtext_data.Field()
|
||||
label = torchtext_data.Field(sequential=False)
|
||||
train_data, val_data, test_data = datasets.SST.splits(
|
||||
text,
|
||||
label,
|
||||
fine_grained=True,
|
||||
train_subtrees=True,
|
||||
)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
for i in trange(len(train_data), ascii=True):
|
||||
seq = TreebankWordDetokenizer().detokenize(vars(train_data[i])["text"])
|
||||
seq = discriminator.tokenizer.encode(seq)
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
x.append(seq)
|
||||
y.append(class2idx[vars(train_data[i])["label"]])
|
||||
train_dataset = Dataset(x, y)
|
||||
|
||||
test_x = []
|
||||
test_y = []
|
||||
for i in trange(len(test_data), ascii=True):
|
||||
seq = TreebankWordDetokenizer().detokenize(vars(test_data[i])["text"])
|
||||
seq = discriminator.tokenizer.encode(seq)
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
test_x.append(seq)
|
||||
test_y.append(class2idx[vars(test_data[i])["label"]])
|
||||
test_dataset = Dataset(test_x, test_y)
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
"embed_size": discriminator.embed_size,
|
||||
"pretrained_model": pretrained_model,
|
||||
"class_vocab": class2idx,
|
||||
"default_class": 2,
|
||||
}
|
||||
|
||||
elif dataset == "clickbait":
|
||||
idx2class = ["non_clickbait", "clickbait"]
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||
).to(device)
|
||||
|
||||
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
||||
data = []
|
||||
for i, line in enumerate(f):
|
||||
try:
|
||||
data.append(eval(line))
|
||||
except Exception:
|
||||
print("Error evaluating line {}: {}".format(i, line))
|
||||
continue
|
||||
x = []
|
||||
y = []
|
||||
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
||||
for i, line in enumerate(tqdm(f, ascii=True)):
|
||||
try:
|
||||
d = eval(line)
|
||||
seq = discriminator.tokenizer.encode(d["text"])
|
||||
|
||||
if len(seq) < max_length_seq:
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
else:
|
||||
print("Line {} is longer than maximum length {}".format(i, max_length_seq))
|
||||
continue
|
||||
x.append(seq)
|
||||
y.append(d["label"])
|
||||
except Exception:
|
||||
print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
|
||||
pass
|
||||
|
||||
full_dataset = Dataset(x, y)
|
||||
train_size = int(0.9 * len(full_dataset))
|
||||
test_size = len(full_dataset) - train_size
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
"embed_size": discriminator.embed_size,
|
||||
"pretrained_model": pretrained_model,
|
||||
"class_vocab": class2idx,
|
||||
"default_class": 1,
|
||||
}
|
||||
|
||||
elif dataset == "toxic":
|
||||
idx2class = ["non_toxic", "toxic"]
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||
).to(device)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
with open("datasets/toxic/toxic_train.txt") as f:
|
||||
for i, line in enumerate(tqdm(f, ascii=True)):
|
||||
try:
|
||||
d = eval(line)
|
||||
seq = discriminator.tokenizer.encode(d["text"])
|
||||
|
||||
if len(seq) < max_length_seq:
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
else:
|
||||
print("Line {} is longer than maximum length {}".format(i, max_length_seq))
|
||||
continue
|
||||
x.append(seq)
|
||||
y.append(int(np.sum(d["label"]) > 0))
|
||||
except Exception:
|
||||
print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
|
||||
pass
|
||||
|
||||
full_dataset = Dataset(x, y)
|
||||
train_size = int(0.9 * len(full_dataset))
|
||||
test_size = len(full_dataset) - train_size
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
"embed_size": discriminator.embed_size,
|
||||
"pretrained_model": pretrained_model,
|
||||
"class_vocab": class2idx,
|
||||
"default_class": 0,
|
||||
}
|
||||
|
||||
else: # if dataset == "generic":
|
||||
# This assumes the input dataset is a TSV with the following structure:
|
||||
# class \t text
|
||||
|
||||
if dataset_fp is None:
|
||||
raise ValueError("When generic dataset is selected, " "dataset_fp needs to be specified aswell.")
|
||||
|
||||
classes = set()
|
||||
with open(dataset_fp) as f:
|
||||
csv_reader = csv.reader(f, delimiter="\t")
|
||||
for row in tqdm(csv_reader, ascii=True):
|
||||
if row:
|
||||
classes.add(row[0])
|
||||
|
||||
idx2class = sorted(classes)
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||
).to(device)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
with open(dataset_fp) as f:
|
||||
csv_reader = csv.reader(f, delimiter="\t")
|
||||
for i, row in enumerate(tqdm(csv_reader, ascii=True)):
|
||||
if row:
|
||||
label = row[0]
|
||||
text = row[1]
|
||||
|
||||
try:
|
||||
seq = discriminator.tokenizer.encode(text)
|
||||
if len(seq) < max_length_seq:
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
|
||||
else:
|
||||
print("Line {} is longer than maximum length {}".format(i, max_length_seq))
|
||||
continue
|
||||
|
||||
x.append(seq)
|
||||
y.append(class2idx[label])
|
||||
|
||||
except Exception:
|
||||
print("Error tokenizing line {}, skipping it".format(i))
|
||||
pass
|
||||
|
||||
full_dataset = Dataset(x, y)
|
||||
train_size = int(0.9 * len(full_dataset))
|
||||
test_size = len(full_dataset) - train_size
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
"embed_size": discriminator.embed_size,
|
||||
"pretrained_model": pretrained_model,
|
||||
"class_vocab": class2idx,
|
||||
"default_class": 0,
|
||||
}
|
||||
|
||||
end = time.time()
|
||||
print("Preprocessed {} data points".format(len(train_dataset) + len(test_dataset)))
|
||||
print("Data preprocessing took: {:.3f}s".format(end - start))
|
||||
|
||||
if cached:
|
||||
print("Building representation cache...")
|
||||
|
||||
start = time.time()
|
||||
|
||||
train_loader = get_cached_data_loader(train_dataset, batch_size, discriminator, shuffle=True, device=device)
|
||||
|
||||
test_loader = get_cached_data_loader(test_dataset, batch_size, discriminator, device=device)
|
||||
|
||||
end = time.time()
|
||||
print("Building representation cache took: {:.3f}s".format(end - start))
|
||||
|
||||
else:
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
||||
)
|
||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)
|
||||
|
||||
if save_model:
|
||||
with open("{}_classifier_head_meta.json".format(dataset), "w") as meta_file:
|
||||
json.dump(discriminator_meta, meta_file)
|
||||
|
||||
optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
|
||||
|
||||
for epoch in range(epochs):
|
||||
start = time.time()
|
||||
print("\nEpoch", epoch + 1)
|
||||
|
||||
train_epoch(
|
||||
discriminator=discriminator,
|
||||
data_loader=train_loader,
|
||||
optimizer=optimizer,
|
||||
epoch=epoch,
|
||||
log_interval=log_interval,
|
||||
device=device,
|
||||
)
|
||||
evaluate_performance(data_loader=test_loader, discriminator=discriminator, device=device)
|
||||
|
||||
end = time.time()
|
||||
print("Epoch took: {:.3f}s".format(end - start))
|
||||
|
||||
print("\nExample prediction")
|
||||
predict(example_sentence, discriminator, idx2class, cached=cached, device=device)
|
||||
|
||||
if save_model:
|
||||
# torch.save(discriminator.state_dict(),
|
||||
# "{}_discriminator_{}.pt".format(
|
||||
# args.dataset, epoch + 1
|
||||
# ))
|
||||
torch.save(
|
||||
discriminator.get_classifier().state_dict(),
|
||||
"{}_classifier_head_epoch_{}.pt".format(dataset, epoch + 1),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Train a discriminator on top of GPT-2 representations")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="SST",
|
||||
choices=("SST", "clickbait", "toxic", "generic"),
|
||||
help="dataset to train the discriminator on."
|
||||
"In case of generic, the dataset is expected"
|
||||
"to be a TSBV file with structure: class \\t text",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_fp",
|
||||
type=str,
|
||||
default="",
|
||||
help="File path of the dataset to use. " "Needed only in case of generic datadset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
|
||||
)
|
||||
parser.add_argument("--epochs", type=int, default=10, metavar="N", help="Number of training epochs")
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_interval",
|
||||
type=int,
|
||||
default=10,
|
||||
metavar="N",
|
||||
help="how many batches to wait before logging training status",
|
||||
)
|
||||
parser.add_argument("--save_model", action="store_true", help="whether to save the model")
|
||||
parser.add_argument("--cached", action="store_true", help="whether to cache the input representations")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="use to turn off cuda")
|
||||
args = parser.parse_args()
|
||||
|
||||
train_discriminator(**(vars(args)))
|
||||
161
examples/research_projects/rag/README.md
Normal file
161
examples/research_projects/rag/README.md
Normal file
@@ -0,0 +1,161 @@
|
||||
# Intro
|
||||
|
||||
Authors: @patrickvonplaten and @lhoestq
|
||||
|
||||
Aimed at tackling the knowledge-intensive NLP tasks (think tasks a human wouldn't be expected to solve without access to external knowledge sources), RAG models are seq2seq models with access to a retrieval mechanism providing relevant context documents at training and evaluation time.
|
||||
|
||||
A RAG model encapsulates two core components: a question encoder and a generator.
|
||||
During a forward pass, we encode the input with the question encoder and pass it
|
||||
to the retriever to extract relevant context documents. The documents are then prepended to the input.
|
||||
Such contextualized inputs are passed to the generator.
|
||||
|
||||
Read more about RAG at https://arxiv.org/abs/2005.11401.
|
||||
|
||||
# Finetuning
|
||||
|
||||
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files:
|
||||
```bash
|
||||
train.source
|
||||
train.target
|
||||
val.source
|
||||
val.target
|
||||
test.source
|
||||
test.target
|
||||
```
|
||||
|
||||
A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
|
||||
|
||||
```bash
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8
|
||||
```
|
||||
We publish two `base` models which can serve as a starting point for finetuning on downstream tasks (use them as `model_name_or_path`):
|
||||
- [`facebook/rag-sequence-base`](https://huggingface.co/facebook/rag-sequence-base) - a base for finetuning `RagSequenceForGeneration` models,
|
||||
- [`facebook/rag-token-base`](https://huggingface.co/facebook/rag-token-base) - a base for finetuning `RagTokenForGeneration` models.
|
||||
|
||||
The `base` models initialize the question encoder with [`facebook/dpr-question_encoder-single-nq-base`](https://huggingface.co/facebook/dpr-question_encoder-single-nq-base) and the generator with [`facebook/bart-large`](https://huggingface.co/facebook/bart-large).
|
||||
|
||||
If you would like to initialize finetuning with a base model using different question encoder and generator architectures, you can build it with a consolidation script, e.g.:
|
||||
```
|
||||
python examples/rag/consolidate_rag_checkpoint.py \
|
||||
--model_type rag_sequence \
|
||||
--generator_name_or_path facebook/bart-large-cnn \
|
||||
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
|
||||
--dest path/to/checkpoint
|
||||
```
|
||||
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
|
||||
|
||||
|
||||
# Evaluation
|
||||
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.
|
||||
|
||||
The evaluation script expects paths to two files:
|
||||
- `evaluation_set` - a path to a file specifying the evaluation dataset, a single input per line.
|
||||
- `gold_data_path` - a path to a file contaning ground truth answers for datapoints from the `evaluation_set`, a single output per line. Check below for expected formats of the gold data files.
|
||||
|
||||
|
||||
## Retrieval evaluation
|
||||
For `retrieval` evaluation, we expect a gold data file where each line will consist of a tab-separated list of document titles constituting positive contexts for respective datapoints from the `evaluation_set`. E.g. given a question `who sings does he love me with reba` in the `evaluation_set`, a respective ground truth line could look as follows:
|
||||
```
|
||||
Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greatest Hits Volume Two (Reba McEntire album) Shoot for the Moon (album)
|
||||
```
|
||||
|
||||
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).
|
||||
|
||||
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
|
||||
```bash
|
||||
wget https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz && gzip -d biencoder-nq-dev.json.gz
|
||||
```
|
||||
|
||||
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
|
||||
```bash
|
||||
mkdir output # or wherever you want to save this
|
||||
python examples/rag/parse_dpr_relevance_data.py \
|
||||
--src_path biencoder-nq-dev.json \
|
||||
--evaluation_set output/biencoder-nq-dev.questions \
|
||||
--gold_data_path output/biencoder-nq-dev.pages
|
||||
```
|
||||
3. Run evaluation:
|
||||
```bash
|
||||
python examples/rag/eval_rag.py \
|
||||
--model_name_or_path facebook/rag-sequence-nq \
|
||||
--model_type rag_sequence \
|
||||
--evaluation_set output/biencoder-nq-dev.questions \
|
||||
--gold_data_path output/biencoder-nq-dev.pages \
|
||||
--predictions_path output/retrieval_preds.tsv \
|
||||
--eval_mode retrieval \
|
||||
--k 1
|
||||
```
|
||||
```bash
|
||||
# EXPLANATION
|
||||
python examples/rag/eval_rag.py \
|
||||
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
|
||||
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
|
||||
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
|
||||
--gold_data_path poutput/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set
|
||||
--predictions_path output/retrieval_preds.tsv \ # name of file where predictions will be stored
|
||||
--eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation
|
||||
--k 1 # parameter k for the precision@k metric
|
||||
|
||||
```
|
||||
## End-to-end evaluation
|
||||
|
||||
We support two formats of the gold data file (controlled by the `gold_data_mode` parameter):
|
||||
- `qa` - where a single line has the following format: `input [tab] output_list`, e.g.:
|
||||
```
|
||||
who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Yongge Dai']
|
||||
```
|
||||
- `ans` - where a single line contains a single expected answer, e.g.:
|
||||
```
|
||||
Xiu Li Dai
|
||||
```
|
||||
|
||||
Predictions of the model for the samples from the `evaluation_set` will be saved under the path specified by the `predictions_path` parameter.
|
||||
If this path already exists, the script will use saved predictions to calculate metrics.
|
||||
Add `--recalculate` parameter to force the script to perform inference from scratch.
|
||||
|
||||
An example e2e evaluation run could look as follows:
|
||||
```bash
|
||||
python examples/rag/eval_rag.py \
|
||||
--model_name_or_path facebook/rag-sequence-nq \
|
||||
--model_type rag_sequence \
|
||||
--evaluation_set path/to/test.source \
|
||||
--gold_data_path path/to/gold_data \
|
||||
--predictions_path path/to/e2e_preds.txt \
|
||||
--eval_mode e2e \
|
||||
--gold_data_mode qa \
|
||||
--n_docs 5 \ # You can experiment with retrieving different number of documents at evaluation time
|
||||
--print_predictions \
|
||||
--recalculate \ # adding this parameter will force recalculating predictions even if predictions_path already exists
|
||||
```
|
||||
|
||||
# Use your own knowledge source
|
||||
|
||||
By default, RAG uses the English Wikipedia as a knowledge source, known as the 'wiki_dpr' dataset.
|
||||
With `use_custom_knowledge_dataset.py` you can build your own knowledge source, *e.g.* for RAG.
|
||||
|
||||
For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
|
||||
```bash
|
||||
python examples/rag/use_own_knowledge_dataset.py \
|
||||
--csv_path path/to/my_csv \
|
||||
--output_dir path/to/my_knowledge_dataset \
|
||||
```
|
||||
|
||||
The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
|
||||
```bash
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8
|
||||
--index_name custom
|
||||
--passages_path path/to/data/my_knowledge_dataset
|
||||
--index_path path/to/my_knowledge_dataset_hnsw_index.faiss
|
||||
```
|
||||
5
examples/research_projects/rag/__init__.py
Normal file
5
examples/research_projects/rag/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))
|
||||
96
examples/research_projects/rag/_test_finetune_rag.py
Normal file
96
examples/research_projects/rag/_test_finetune_rag.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import finetune_rag
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class RagFinetuneExampleTests(TestCasePlus):
|
||||
def _create_dummy_data(self, data_dir):
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
contents = {"source": "What is love ?", "target": "life"}
|
||||
n_lines = {"train": 12, "val": 2, "test": 2}
|
||||
for split in ["train", "test", "val"]:
|
||||
for field in ["source", "target"]:
|
||||
content = "\n".join([contents[field]] * n_lines[split])
|
||||
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
|
||||
f.write(content)
|
||||
|
||||
def _run_finetune(self, gpus: int):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
output_dir = os.path.join(tmp_dir, "output")
|
||||
data_dir = os.path.join(tmp_dir, "data")
|
||||
self._create_dummy_data(data_dir=data_dir)
|
||||
|
||||
testargs = f"""
|
||||
--data_dir {data_dir} \
|
||||
--output_dir {output_dir} \
|
||||
--model_name_or_path facebook/rag-sequence-base \
|
||||
--model_type rag_sequence \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--val_check_interval 1.0 \
|
||||
--train_batch_size 2 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 25 \
|
||||
--max_target_length 25 \
|
||||
--val_max_target_length 25 \
|
||||
--test_max_target_length 25 \
|
||||
--label_smoothing 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
--weight_decay 0.001 \
|
||||
--adam_epsilon 1e-08 \
|
||||
--max_grad_norm 0.1 \
|
||||
--lr_scheduler polynomial \
|
||||
--learning_rate 3e-04 \
|
||||
--num_train_epochs 1 \
|
||||
--warmup_steps 4 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--distributed-port 8787 \
|
||||
--use_dummy_dataset 1 \
|
||||
""".split()
|
||||
|
||||
if gpus > 0:
|
||||
testargs.append(f"--gpus={gpus}")
|
||||
if is_apex_available():
|
||||
testargs.append("--fp16")
|
||||
else:
|
||||
testargs.append("--gpus=0")
|
||||
testargs.append("--distributed_backend=ddp_cpu")
|
||||
testargs.append("--num_processes=2")
|
||||
|
||||
cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
metrics_save_path = os.path.join(output_dir, "metrics.json")
|
||||
with open(metrics_save_path) as f:
|
||||
result = json.load(f)
|
||||
return result
|
||||
|
||||
@require_torch_gpu
|
||||
def test_finetune_gpu(self):
|
||||
result = self._run_finetune(gpus=1)
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_finetune_multigpu(self):
|
||||
result = self._run_finetune(gpus=2)
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
116
examples/research_projects/rag/callbacks_rag.py
Normal file
116
examples/research_projects/rag/callbacks_rag.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from utils_rag import save_json
|
||||
|
||||
|
||||
def count_trainable_parameters(model):
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
return params
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_checkpoint_callback(output_dir, metric):
|
||||
"""Saves the best model by validation EM score."""
|
||||
if metric == "rouge2":
|
||||
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
||||
elif metric == "bleu":
|
||||
exp = "{val_avg_bleu:.4f}-{step_count}"
|
||||
elif metric == "em":
|
||||
exp = "{val_avg_em:.4f}-{step_count}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=os.path.join(output_dir, exp),
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
save_top_k=3,
|
||||
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
)
|
||||
return checkpoint_callback
|
||||
|
||||
|
||||
def get_early_stopping_callback(metric, patience):
|
||||
return EarlyStopping(
|
||||
monitor=f"val_{metric}", # does this need avg?
|
||||
mode="min" if "loss" in metric else "max",
|
||||
patience=patience,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
class Seq2SeqLoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
@rank_zero_only
|
||||
def _write_logs(
|
||||
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||
) -> None:
|
||||
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
|
||||
metrics = trainer.callback_metrics
|
||||
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
|
||||
# Log results
|
||||
od = Path(pl_module.hparams.output_dir)
|
||||
if type_path == "test":
|
||||
results_file = od / "test_results.txt"
|
||||
generations_file = od / "test_generations.txt"
|
||||
else:
|
||||
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
|
||||
# If people want this it will be easy enough to add back.
|
||||
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
|
||||
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
|
||||
results_file.parent.mkdir(exist_ok=True)
|
||||
generations_file.parent.mkdir(exist_ok=True)
|
||||
with open(results_file, "a+") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key in ["log", "progress_bar", "preds"]:
|
||||
continue
|
||||
val = metrics[key]
|
||||
if isinstance(val, torch.Tensor):
|
||||
val = val.item()
|
||||
msg = f"{key}: {val:.6f}\n"
|
||||
writer.write(msg)
|
||||
|
||||
if not save_generations:
|
||||
return
|
||||
|
||||
if "preds" in metrics:
|
||||
content = "\n".join(metrics["preds"])
|
||||
generations_file.open("w+").write(content)
|
||||
|
||||
@rank_zero_only
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
try:
|
||||
npars = pl_module.model.model.num_parameters()
|
||||
except AttributeError:
|
||||
npars = pl_module.model.num_parameters()
|
||||
|
||||
n_trainable_pars = count_trainable_parameters(pl_module)
|
||||
# mp stands for million parameters
|
||||
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
return self._write_logs(trainer, pl_module, "test")
|
||||
|
||||
@rank_zero_only
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
# Uncommenting this will save val generations
|
||||
# return self._write_logs(trainer, pl_module, "valid")
|
||||
99
examples/research_projects/rag/consolidate_rag_checkpoint.py
Normal file
99
examples/research_projects/rag/consolidate_rag_checkpoint.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
A script creating a RAG checkpoint from a generator and a question encoder checkpoints.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer, RagConfig, RagSequenceForGeneration, RagTokenForGeneration
|
||||
|
||||
|
||||
def consolidate(
|
||||
model_type,
|
||||
generator_name_or_path: str,
|
||||
question_encoder_name_or_path: str,
|
||||
dest_dir: Path,
|
||||
config_name_or_path: str = None,
|
||||
generator_tokenizer_name_or_path: str = None,
|
||||
question_encoder_tokenizer_name_or_path: str = None,
|
||||
):
|
||||
|
||||
if config_name_or_path is None:
|
||||
config_name_or_path = "facebook/rag-token-base" if model_type == "rag_token" else "facebook/rag-sequence-base"
|
||||
|
||||
if generator_tokenizer_name_or_path is None:
|
||||
generator_tokenizer_name_or_path = generator_name_or_path
|
||||
|
||||
if question_encoder_tokenizer_name_or_path is None:
|
||||
question_encoder_tokenizer_name_or_path = question_encoder_name_or_path
|
||||
|
||||
model_class = RagTokenForGeneration if model_type == "rag_token" else RagSequenceForGeneration
|
||||
|
||||
# Save model.
|
||||
rag_config = RagConfig.from_pretrained(config_name_or_path)
|
||||
gen_config = AutoConfig.from_pretrained(generator_name_or_path)
|
||||
question_encoder_config = AutoConfig.from_pretrained(question_encoder_name_or_path)
|
||||
|
||||
rag_config.generator = gen_config
|
||||
rag_config.question_encoder = question_encoder_config
|
||||
|
||||
rag_model = model_class.from_pretrained_question_encoder_generator(
|
||||
question_encoder_name_or_path, generator_name_or_path, config=rag_config
|
||||
)
|
||||
rag_model.save_pretrained(dest_dir)
|
||||
|
||||
# Sanity check.
|
||||
model_class.from_pretrained(dest_dir)
|
||||
|
||||
# Save tokenizers.
|
||||
gen_tokenizer = AutoTokenizer.from_pretrained(generator_tokenizer_name_or_path)
|
||||
gen_tokenizer.save_pretrained(dest_dir / "generator_tokenizer/")
|
||||
question_encoder_tokenizer = AutoTokenizer.from_pretrained(question_encoder_tokenizer_name_or_path)
|
||||
question_encoder_tokenizer.save_pretrained(dest_dir / "question_encoder_tokenizer/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token"],
|
||||
required=True,
|
||||
type=str,
|
||||
help="RAG model type: rag_sequence, rag_token",
|
||||
)
|
||||
parser.add_argument("--dest", type=str, required=True, help="Path to the output checkpoint directory.")
|
||||
parser.add_argument("--generator_name_or_path", type=str, required=True, help="Generator model identifier")
|
||||
parser.add_argument(
|
||||
"--question_encoder_name_or_path", type=str, required=True, help="Question encoder model identifier"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--generator_tokenizer_name_or_path",
|
||||
type=str,
|
||||
help="Generator tokenizer identifier, if not specified, resolves to ``generator_name_or_path``",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--question_encoder_tokenizer_name_or_path",
|
||||
type=str,
|
||||
help="Question encoder tokenizer identifier, if not specified, resolves to ``question_encoder_name_or_path``",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name_or_path",
|
||||
type=str,
|
||||
help="Identifier of the model config to use, if not provided, resolves to a base config for a given ``model_type``",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dest_dir = Path(args.dest)
|
||||
dest_dir.mkdir(exist_ok=True)
|
||||
|
||||
consolidate(
|
||||
args.model_type,
|
||||
args.generator_name_or_path,
|
||||
args.question_encoder_name_or_path,
|
||||
dest_dir,
|
||||
args.config_name_or_path,
|
||||
args.generator_tokenizer_name_or_path,
|
||||
args.question_encoder_tokenizer_name_or_path,
|
||||
)
|
||||
139
examples/research_projects/rag/distributed_retriever.py
Normal file
139
examples/research_projects/rag/distributed_retriever.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from transformers import RagRetriever
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
"""
|
||||
A distributed retriever built on top of the ``torch.distributed`` communication package. During training all workers
|
||||
initialize their own instance of the retriever, however, only the main worker loads the index into memory. The index is stored
|
||||
in cpu memory. The index will also work well in a non-distributed setup.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.RagConfig`):
|
||||
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
|
||||
question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer that was used to tokenize the question.
|
||||
It is used to decode the question and then use the generator_tokenizer.
|
||||
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
index (:class:`~transformers.models.rag.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
|
||||
If specified, use this index instead of the one built using the configuration
|
||||
"""
|
||||
|
||||
_init_retrieval = False
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
||||
super().__init__(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
)
|
||||
self.process_group = None
|
||||
|
||||
def init_retrieval(self, distributed_port: int):
|
||||
"""
|
||||
Retriever initialization function, needs to be called from the training process. The function sets some common parameters
|
||||
and environment variables. On top of that, (only) the main process in the process group loads the index into memory.
|
||||
|
||||
Args:
|
||||
distributed_port (:obj:`int`):
|
||||
The port on which the main communication of the training run is carried out. We set the port for retrieval-related
|
||||
communication as ``distributed_port + 1``.
|
||||
"""
|
||||
|
||||
logger.info("initializing retrieval")
|
||||
|
||||
# initializing a separate process group for retrieval as the default
|
||||
# nccl backend doesn't support gather/scatter operations while gloo
|
||||
# is too slow to replace nccl for the core gpu communication
|
||||
if dist.is_initialized():
|
||||
logger.info("dist initialized")
|
||||
# needs to be set manually
|
||||
os.environ["GLOO_SOCKET_IFNAME"] = self._infer_socket_ifname()
|
||||
# avoid clash with the NCCL port
|
||||
os.environ["MASTER_PORT"] = str(distributed_port + 1)
|
||||
self.process_group = dist.new_group(ranks=None, backend="gloo")
|
||||
|
||||
# initialize retriever only on the main worker
|
||||
if not dist.is_initialized() or self._is_main():
|
||||
logger.info("dist not initialized / main")
|
||||
self.index.init_index()
|
||||
|
||||
# all processes wait untill the retriever is initialized by the main process
|
||||
if dist.is_initialized():
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
def _is_main(self):
|
||||
return dist.get_rank(group=self.process_group) == 0
|
||||
|
||||
def _scattered(self, scatter_list, target_shape, target_type=torch.float32):
|
||||
target_tensor = torch.empty(target_shape, dtype=target_type)
|
||||
dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group)
|
||||
return target_tensor
|
||||
|
||||
def _infer_socket_ifname(self):
|
||||
addrs = psutil.net_if_addrs()
|
||||
# a hacky way to deal with varying network interface names
|
||||
ifname = next((addr for addr in addrs if addr.startswith("e")), None)
|
||||
return ifname
|
||||
|
||||
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:
|
||||
"""
|
||||
Retrieves documents for specified ``question_hidden_states``. The main process, which has the access to the index stored in memory, gathers queries
|
||||
from all the processes in the main training process group, performs the retrieval and scatters back the results.
|
||||
|
||||
Args:
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
|
||||
A batch of query vectors to retrieve with.
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Output:
|
||||
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
|
||||
The retrieval embeddings of the retrieved docs per query.
|
||||
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
|
||||
The ids of the documents in the index
|
||||
doc_dicts (:obj:`List[dict]`):
|
||||
The retrieved_doc_embeds examples per query.
|
||||
"""
|
||||
|
||||
# single GPU training
|
||||
if not dist.is_initialized():
|
||||
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
|
||||
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
|
||||
|
||||
# distributed training
|
||||
world_size = dist.get_world_size(group=self.process_group)
|
||||
|
||||
# gather logic
|
||||
gather_list = None
|
||||
if self._is_main():
|
||||
gather_list = [torch.empty(question_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
|
||||
dist.gather(torch.tensor(question_hidden_states), dst=0, gather_list=gather_list, group=self.process_group)
|
||||
|
||||
# scatter logic
|
||||
n_queries = question_hidden_states.shape[0]
|
||||
scatter_ids = []
|
||||
scatter_vectors = []
|
||||
if self._is_main():
|
||||
assert len(gather_list) == world_size
|
||||
ids, vectors = self._main_retrieve(torch.cat(gather_list).numpy(), n_docs)
|
||||
ids, vectors = torch.tensor(ids), torch.tensor(vectors)
|
||||
scatter_ids = self._chunk_tensor(ids, n_queries)
|
||||
scatter_vectors = self._chunk_tensor(vectors, n_queries)
|
||||
doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64)
|
||||
retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]])
|
||||
|
||||
return retrieved_doc_embeds.numpy(), doc_ids.numpy(), self.index.get_doc_dicts(doc_ids)
|
||||
314
examples/research_projects/rag/eval_rag.py
Normal file
314
examples/research_projects/rag/eval_rag.py
Normal file
@@ -0,0 +1,314 @@
|
||||
""" Evaluation script for RAG models."""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
|
||||
from utils_rag import exact_match_score, f1_score # noqa: E402 # isort:skip
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
def infer_model_type(model_name_or_path):
|
||||
if "token" in model_name_or_path:
|
||||
return "rag_token"
|
||||
if "sequence" in model_name_or_path:
|
||||
return "rag_sequence"
|
||||
if "bart" in model_name_or_path:
|
||||
return "bart"
|
||||
return None
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
return max(metric_fn(prediction, gt) for gt in ground_truths)
|
||||
|
||||
|
||||
def get_scores(args, preds_path, gold_data_path):
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
answers = []
|
||||
|
||||
if args.gold_data_mode == "qa":
|
||||
data = pd.read_csv(gold_data_path, sep="\t", header=None)
|
||||
for answer_list in data[1]:
|
||||
ground_truths = ast.literal_eval(answer_list)
|
||||
answers.append(ground_truths)
|
||||
else:
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
answers = [[reference] for reference in references]
|
||||
|
||||
f1 = em = total = 0
|
||||
for prediction, ground_truths in zip(hypos, answers):
|
||||
total += 1
|
||||
em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
|
||||
|
||||
em = 100.0 * em / total
|
||||
f1 = 100.0 * f1 / total
|
||||
|
||||
logger.info(f"F1: {f1:.2f}")
|
||||
logger.info(f"EM: {em:.2f}")
|
||||
|
||||
|
||||
def get_precision_at_k(args, preds_path, gold_data_path):
|
||||
k = args.k
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
|
||||
em = total = 0
|
||||
for hypo, reference in zip(hypos, references):
|
||||
hypo_provenance = set(hypo.split("\t")[:k])
|
||||
ref_provenance = set(reference.split("\t"))
|
||||
total += 1
|
||||
em += len(hypo_provenance & ref_provenance) / k
|
||||
|
||||
em = 100.0 * em / total
|
||||
logger.info(f"Precision@{k}: {em: .2f}")
|
||||
|
||||
|
||||
def evaluate_batch_retrieval(args, rag_model, questions):
|
||||
def strip_title(title):
|
||||
if title.startswith('"'):
|
||||
title = title[1:]
|
||||
if title.endswith('"'):
|
||||
title = title[:-1]
|
||||
return title
|
||||
|
||||
retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)["input_ids"].to(args.device)
|
||||
|
||||
question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids)
|
||||
question_enc_pool_output = question_enc_outputs.pooler_output
|
||||
|
||||
result = rag_model.retriever(
|
||||
retriever_input_ids,
|
||||
question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
|
||||
prefix=rag_model.rag.generator.config.prefix,
|
||||
n_docs=rag_model.config.n_docs,
|
||||
return_tensors="pt",
|
||||
)
|
||||
all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
|
||||
provenance_strings = []
|
||||
for docs in all_docs:
|
||||
provenance = [strip_title(title) for title in docs["title"]]
|
||||
provenance_strings.append("\t".join(provenance))
|
||||
return provenance_strings
|
||||
|
||||
|
||||
def evaluate_batch_e2e(args, rag_model, questions):
|
||||
with torch.no_grad():
|
||||
inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions, return_tensors="pt", padding=True, truncation=True
|
||||
)
|
||||
|
||||
input_ids = inputs_dict.input_ids.to(args.device)
|
||||
attention_mask = inputs_dict.attention_mask.to(args.device)
|
||||
outputs = rag_model.generate( # rag_model overwrites generate
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=args.num_beams,
|
||||
min_length=args.min_length,
|
||||
max_length=args.max_length,
|
||||
early_stopping=False,
|
||||
num_return_sequences=1,
|
||||
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, dont allow it to generate more than one
|
||||
clean_up_tokenization=True,
|
||||
print_docs=args.print_docs,
|
||||
)
|
||||
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
if args.print_predictions:
|
||||
for q, a in zip(questions, answers):
|
||||
logger.info("Q: {} - A: {}".format(q, a))
|
||||
|
||||
return answers
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart"],
|
||||
type=str,
|
||||
help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
default=None,
|
||||
choices=["exact", "compressed", "legacy"],
|
||||
type=str,
|
||||
help="RAG model retriever type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the retrieval index",
|
||||
)
|
||||
parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained checkpoints or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_mode",
|
||||
choices=["e2e", "retrieval"],
|
||||
default="e2e",
|
||||
type=str,
|
||||
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
|
||||
)
|
||||
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
|
||||
parser.add_argument(
|
||||
"--evaluation_set",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a file containing evaluation samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a tab-separated file with gold samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_mode",
|
||||
default="qa",
|
||||
type=str,
|
||||
choices=["qa", "ans"],
|
||||
help="Format of the gold data file"
|
||||
"qa - a single line in the following format: question [tab] answer_list"
|
||||
"ans - a single line of the gold file contains the expected answer string",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predictions_path",
|
||||
type=str,
|
||||
default="predictions.txt",
|
||||
help="Name of the predictions file, to be stored in the checkpoints directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recalculate",
|
||||
help="Recalculate predictions even if the prediction file exists",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Number of beams to be used when generating answers",
|
||||
)
|
||||
parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers")
|
||||
parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers")
|
||||
|
||||
parser.add_argument(
|
||||
"--print_predictions",
|
||||
action="store_true",
|
||||
help="If True, prints predictions while evaluating.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_docs",
|
||||
action="store_true",
|
||||
help="If True, prints docs retried while generating.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
model_kwargs = {}
|
||||
if args.model_type is None:
|
||||
args.model_type = infer_model_type(args.model_name_or_path)
|
||||
assert args.model_type is not None
|
||||
if args.model_type.startswith("rag"):
|
||||
model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
|
||||
model_kwargs["n_docs"] = args.n_docs
|
||||
if args.index_name is not None:
|
||||
model_kwargs["index_name"] = args.index_name
|
||||
if args.index_path is not None:
|
||||
model_kwargs["index_path"] = args.index_path
|
||||
else:
|
||||
model_class = BartForConditionalGeneration
|
||||
|
||||
checkpoints = (
|
||||
[f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
|
||||
if args.eval_all_checkpoints
|
||||
else [args.model_name_or_path]
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
|
||||
evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
if os.path.exists(args.predictions_path) and (not args.recalculate):
|
||||
logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
continue
|
||||
|
||||
logger.info("***** Running evaluation for {} *****".format(checkpoint))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
logger.info(" Predictions will be stored under {}".format(args.predictions_path))
|
||||
|
||||
if args.model_type.startswith("rag"):
|
||||
retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
|
||||
model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
|
||||
model.retriever.init_retrieval()
|
||||
else:
|
||||
model = model_class.from_pretrained(checkpoint, **model_kwargs)
|
||||
model.to(args.device)
|
||||
|
||||
with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
|
||||
questions = []
|
||||
for line in tqdm(eval_file):
|
||||
questions.append(line.strip())
|
||||
if len(questions) == args.eval_batch_size:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers) + "\n")
|
||||
preds_file.flush()
|
||||
questions = []
|
||||
if len(questions) > 0:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers))
|
||||
preds_file.flush()
|
||||
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
512
examples/research_projects/rag/finetune_rag.py
Normal file
512
examples/research_projects/rag/finetune_rag.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
|
||||
from pytorch_lightning.cluster_environments import TorchElasticEnvironment
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
BartForConditionalGeneration,
|
||||
BatchEncoding,
|
||||
RagConfig,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenForGeneration,
|
||||
RagTokenizer,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||
get_checkpoint_callback,
|
||||
get_early_stopping_callback,
|
||||
Seq2SeqLoggingCallback,
|
||||
)
|
||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
from utils_rag import ( # noqa: E402 # isort:skip
|
||||
calculate_exact_match,
|
||||
flatten_list,
|
||||
get_git_info,
|
||||
is_rag_model,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
set_extra_model_params,
|
||||
Seq2SeqDataset,
|
||||
)
|
||||
|
||||
# need the parent dir module
|
||||
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
|
||||
# is no longer used, and is moved into DDPAccelerator instead.
|
||||
# We override DDPAccelerator to add our custom logic for initializing the
|
||||
# retriever.
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
|
||||
|
||||
|
||||
class CustomAccel(DDPAccelerator):
|
||||
def __init__(self, trainer=None, **kwargs):
|
||||
# Trainer is set later.
|
||||
super().__init__(trainer, **kwargs)
|
||||
|
||||
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
|
||||
logger.info("Custom init_ddp_connection.")
|
||||
module = self.trainer.model
|
||||
if self.cluster_environment is None:
|
||||
self.cluster_environment = TorchElasticEnvironment()
|
||||
self.distributed_port = module.hparams.distributed_port
|
||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
||||
if module.is_rag_model:
|
||||
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||
|
||||
|
||||
class GenerativeQAModule(BaseTransformer):
|
||||
mode = "generative_qa"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["em"]
|
||||
val_metric = "em"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
# when loading from a pytorch lightning checkpoint, hparams are passed as dict
|
||||
if isinstance(hparams, dict):
|
||||
hparams = AttrDict(hparams)
|
||||
if hparams.model_type == "rag_sequence":
|
||||
self.model_class = RagSequenceForGeneration
|
||||
elif hparams.model_type == "rag_token":
|
||||
self.model_class = RagTokenForGeneration
|
||||
elif hparams.model_type == "bart":
|
||||
self.model_class = BartForConditionalGeneration
|
||||
else:
|
||||
self.model_class = T5ForConditionalGeneration
|
||||
self.is_rag_model = is_rag_model(hparams.model_type)
|
||||
|
||||
config_class = RagConfig if self.is_rag_model else AutoConfig
|
||||
config = config_class.from_pretrained(hparams.model_name_or_path)
|
||||
|
||||
# set retriever parameters
|
||||
config.index_name = hparams.index_name or config.index_name
|
||||
config.passages_path = hparams.passages_path or config.passages_path
|
||||
config.index_path = hparams.index_path or config.index_path
|
||||
config.use_dummy_dataset = hparams.use_dummy_dataset
|
||||
|
||||
# set extra_model_params for generator configs and load_model
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
|
||||
if self.is_rag_model:
|
||||
if hparams.prefix is not None:
|
||||
config.generator.prefix = hparams.prefix
|
||||
config.label_smoothing = hparams.label_smoothing
|
||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||
prefix = config.question_encoder.prefix
|
||||
else:
|
||||
if hparams.prefix is not None:
|
||||
config.prefix = hparams.prefix
|
||||
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
prefix = config.prefix
|
||||
|
||||
tokenizer = (
|
||||
RagTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
if self.is_rag_model
|
||||
else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
)
|
||||
|
||||
super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)
|
||||
|
||||
save_git_info(self.hparams.output_dir)
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
|
||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||
pickle_save(self.hparams, self.hparams_save_path)
|
||||
self.step_count = 0
|
||||
self.metrics = defaultdict(list)
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
max_source_length=self.hparams.max_source_length,
|
||||
prefix=prefix or "",
|
||||
)
|
||||
n_observations_per_split = {
|
||||
"train": self.hparams.n_train,
|
||||
"val": self.hparams.n_val,
|
||||
"test": self.hparams.n_test,
|
||||
}
|
||||
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
|
||||
|
||||
self.target_lens = {
|
||||
"train": self.hparams.max_target_length,
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.distributed_port = self.hparams.distributed_port
|
||||
|
||||
# For single GPU training, init_ddp_connection is not called.
|
||||
# So we need to initialize the retrievers here.
|
||||
if hparams.gpus <= 1:
|
||||
self.model.retriever.init_retrieval(self.distributed_port)
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
|
||||
def ids_to_clean_text(self, generated_ids: List[int]):
|
||||
gen_text = self.tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
return lmap(str.strip, gen_text)
|
||||
|
||||
def _step(self, batch: dict) -> Tuple:
|
||||
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
|
||||
rag_kwargs = {}
|
||||
if isinstance(self.model, T5ForConditionalGeneration):
|
||||
decoder_input_ids = self.model._shift_right(target_ids)
|
||||
lm_labels = target_ids
|
||||
elif isinstance(self.model, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous()
|
||||
lm_labels = target_ids[:, 1:].clone()
|
||||
else:
|
||||
assert self.is_rag_model
|
||||
generator = self.model.rag.generator
|
||||
if isinstance(generator, T5ForConditionalGeneration):
|
||||
decoder_start_token_id = generator.config.decoder_start_token_id
|
||||
decoder_input_ids = (
|
||||
torch.cat(
|
||||
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
|
||||
dim=1,
|
||||
)
|
||||
if target_ids.shape[0] < self.target_lens["train"]
|
||||
else generator._shift_right(target_ids)
|
||||
)
|
||||
elif isinstance(generator, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids
|
||||
lm_labels = decoder_input_ids
|
||||
rag_kwargs["reduce_loss"] = True
|
||||
|
||||
assert decoder_input_ids is not None
|
||||
|
||||
outputs = self(
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
use_cache=False,
|
||||
labels=lm_labels,
|
||||
**rag_kwargs,
|
||||
)
|
||||
|
||||
loss = outputs["loss"]
|
||||
return (loss,)
|
||||
|
||||
@property
|
||||
def pad(self) -> int:
|
||||
raise NotImplementedError("pad not implemented")
|
||||
|
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
loss_tensors = self._step(batch)
|
||||
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
# tokens per batch
|
||||
tgt_pad_token_id = (
|
||||
self.tokenizer.generator.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
src_pad_token_id = (
|
||||
self.tokenizer.question_encoder.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
logs["tpb"] = (
|
||||
batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
|
||||
)
|
||||
|
||||
return {"loss": loss_tensors[0], "log": logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
return self._generative_step(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
gen_metrics = {
|
||||
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
|
||||
}
|
||||
metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
|
||||
gen_metrics.update({k: v.item() for k, v in losses.items()})
|
||||
|
||||
# fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
|
||||
metrics_tensor = metrics_tensor / dist.get_world_size()
|
||||
gen_metrics.update({self.val_metric: metrics_tensor.item()})
|
||||
|
||||
losses.update(gen_metrics)
|
||||
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
metrics["step_count"] = self.step_count
|
||||
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metrics_tensor}
|
||||
|
||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||
self.metrics[type_path].append(latest_metrics)
|
||||
save_json(self.metrics, self.metrics_save_path)
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_exact_match(preds, target)
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
start_time = time.time()
|
||||
batch = BatchEncoding(batch).to(device=self.model.device)
|
||||
generated_ids = self.model.generate(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
do_deduplication=False, # rag specific parameter
|
||||
use_cache=True,
|
||||
min_length=1,
|
||||
max_length=self.target_lens["val"],
|
||||
)
|
||||
|
||||
gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
|
||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
gen_metrics: Dict = self.calc_generative_metrics(preds, target)
|
||||
|
||||
summ_len = np.mean(lmap(len, generated_ids))
|
||||
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
|
||||
return base_metrics
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_epoch_end(outputs, prefix="test")
|
||||
|
||||
def get_dataset(self, type_path) -> Seq2SeqDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = Seq2SeqDataset(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
max_target_length=max_target_length,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=shuffle,
|
||||
num_workers=self.num_workers,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
return dataloader
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||
add_generic_args(parser, root_dir)
|
||||
parser.add_argument(
|
||||
"--max_source_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prefix added at the beginning of each text, typically used with T5-based models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early_stopping_patience",
|
||||
type=int,
|
||||
default=-1,
|
||||
required=False,
|
||||
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart", "t5"],
|
||||
type=str,
|
||||
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def add_retriever_specific_args(parser):
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--passages_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dummy_dataset",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(args=None, model=None) -> GenerativeQAModule:
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||
|
||||
args = args or parser.parse_args()
|
||||
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if model is None:
|
||||
model: GenerativeQAModule = GenerativeQAModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger_name == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
|
||||
es_callback = (
|
||||
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||
if args.early_stopping_patience >= 0
|
||||
else False
|
||||
)
|
||||
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
accelerator=CustomAccel() if args.gpus > 1 else None,
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
|
||||
if not args.do_predict:
|
||||
return model
|
||||
|
||||
# test() without a model tests using the best checkpoint automatically
|
||||
trainer.test()
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
34
examples/research_projects/rag/finetune_rag.sh
Executable file
34
examples/research_projects/rag/finetune_rag.sh
Executable file
@@ -0,0 +1,34 @@
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./examples/rag/finetune.sh --help to see all the possible options
|
||||
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--val_check_interval 0.25 \
|
||||
--train_batch_size 8 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 128 \
|
||||
--max_target_length 25 \
|
||||
--val_max_target_length 25 \
|
||||
--test_max_target_length 25 \
|
||||
--label_smoothing 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
--weight_decay 0.001 \
|
||||
--adam_epsilon 1e-08 \
|
||||
--max_grad_norm 0.1 \
|
||||
--lr_scheduler polynomial \
|
||||
--learning_rate 3e-05 \
|
||||
--num_train_epochs 100 \
|
||||
--warmup_steps 500 \
|
||||
--gradient_accumulation_steps 1
|
||||
391
examples/research_projects/rag/lightning_base.py
Normal file
391
examples/research_projects/rag/lightning_base.py
Normal file
@@ -0,0 +1,391 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
from transformers import (
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForPreTraining,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
from transformers.optimization import (
|
||||
Adafactor,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
)
|
||||
from transformers.utils.versions import require_version_examples
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
require_version_examples("pytorch_lightning>=1.0.4")
|
||||
|
||||
MODEL_MODES = {
|
||||
"base": AutoModel,
|
||||
"sequence-classification": AutoModelForSequenceClassification,
|
||||
"question-answering": AutoModelForQuestionAnswering,
|
||||
"pretraining": AutoModelForPreTraining,
|
||||
"token-classification": AutoModelForTokenClassification,
|
||||
"language-modeling": AutoModelWithLMHead,
|
||||
"summarization": AutoModelForSeq2SeqLM,
|
||||
"translation": AutoModelForSeq2SeqLM,
|
||||
}
|
||||
|
||||
|
||||
# update this and the import above to support new schedulers from transformers.optimization
|
||||
arg_to_scheduler = {
|
||||
"linear": get_linear_schedule_with_warmup,
|
||||
"cosine": get_cosine_schedule_with_warmup,
|
||||
"cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
"polynomial": get_polynomial_decay_schedule_with_warmup,
|
||||
# '': get_constant_schedule, # not supported for now
|
||||
# '': get_constant_schedule_with_warmup, # not supported for now
|
||||
}
|
||||
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
|
||||
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
|
||||
|
||||
|
||||
class BaseTransformer(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
hparams: argparse.Namespace,
|
||||
num_labels=None,
|
||||
mode="base",
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
# TODO: move to self.save_hyperparameters()
|
||||
# self.save_hyperparameters()
|
||||
# can also expand arguments into trainer signature for easier reading
|
||||
|
||||
self.save_hyperparameters(hparams)
|
||||
self.step_count = 0
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||
if config is None:
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||
cache_dir=cache_dir,
|
||||
**config_kwargs,
|
||||
)
|
||||
else:
|
||||
self.config: PretrainedConfig = config
|
||||
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
if getattr(self.hparams, p, None):
|
||||
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
|
||||
setattr(self.config, p, getattr(self.hparams, p))
|
||||
|
||||
if tokenizer is None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
else:
|
||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||
self.model_type = MODEL_MODES[mode]
|
||||
if model is None:
|
||||
self.model = self.model_type.from_pretrained(
|
||||
self.hparams.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||
config=self.config,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
else:
|
||||
self.model = model
|
||||
|
||||
def load_hf_checkpoint(self, *args, **kwargs):
|
||||
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||
|
||||
def get_lr_scheduler(self):
|
||||
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
|
||||
scheduler = get_schedule_func(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps()
|
||||
)
|
||||
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
|
||||
return scheduler
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Prepare optimizer and schedule (linear warmup and decay)"""
|
||||
model = self.model
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": self.hparams.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
if self.hparams.adafactor:
|
||||
optimizer = Adafactor(
|
||||
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
|
||||
)
|
||||
|
||||
else:
|
||||
optimizer = AdamW(
|
||||
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
|
||||
)
|
||||
self.opt = optimizer
|
||||
|
||||
scheduler = self.get_lr_scheduler()
|
||||
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def test_step(self, batch, batch_nb):
|
||||
return self.validation_step(batch, batch_nb)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_end(outputs)
|
||||
|
||||
def total_steps(self) -> int:
|
||||
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
|
||||
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
|
||||
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
|
||||
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
|
||||
|
||||
def setup(self, mode):
|
||||
if mode == "test":
|
||||
self.dataset_size = len(self.test_dataloader().dataset)
|
||||
else:
|
||||
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
||||
self.dataset_size = len(self.train_dataloader().dataset)
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
|
||||
raise NotImplementedError("You must implement this for your task")
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.train_loader
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)
|
||||
|
||||
def _feature_file(self, mode):
|
||||
return os.path.join(
|
||||
self.hparams.data_dir,
|
||||
"cached_{}_{}_{}".format(
|
||||
mode,
|
||||
list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
|
||||
str(self.hparams.max_seq_length),
|
||||
),
|
||||
)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("best_tfmr")
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_layerdrop",
|
||||
type=float,
|
||||
help="Encoder layer dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_layerdrop",
|
||||
type=float,
|
||||
help="Decoder layer dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dropout",
|
||||
type=float,
|
||||
help="Dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention_dropout",
|
||||
type=float,
|
||||
help="Attention dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
default="linear",
|
||||
choices=arg_to_scheduler_choices,
|
||||
metavar=arg_to_scheduler_metavar,
|
||||
type=str,
|
||||
help="Learning rate scheduler",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
|
||||
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
|
||||
parser.add_argument("--train_batch_size", default=32, type=int)
|
||||
parser.add_argument("--eval_batch_size", default=32, type=int)
|
||||
parser.add_argument("--adafactor", action="store_true")
|
||||
|
||||
|
||||
class LoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
|
||||
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
rank_zero_info("***** Validation results *****")
|
||||
metrics = trainer.callback_metrics
|
||||
# Log results
|
||||
for key in sorted(metrics):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
rank_zero_info("***** Test results *****")
|
||||
metrics = trainer.callback_metrics
|
||||
# Log and save results to file
|
||||
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
|
||||
def add_generic_args(parser, root_dir) -> None:
|
||||
# To allow all pl args uncomment the following line
|
||||
# parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O2",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
|
||||
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
dest="accumulate_grad_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||
)
|
||||
|
||||
|
||||
def generic_train(
|
||||
model: BaseTransformer,
|
||||
args: argparse.Namespace,
|
||||
early_stopping_callback=None,
|
||||
logger=True, # can pass WandbLogger() here
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
# init model
|
||||
odir = Path(model.hparams.output_dir)
|
||||
odir.mkdir(exist_ok=True)
|
||||
|
||||
# add custom checkpoints
|
||||
if checkpoint_callback is None:
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
||||
)
|
||||
if early_stopping_callback:
|
||||
extra_callbacks.append(early_stopping_callback)
|
||||
if logging_callback is None:
|
||||
logging_callback = LoggingCallback()
|
||||
|
||||
train_params = {}
|
||||
|
||||
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
||||
if args.fp16:
|
||||
train_params["precision"] = 16
|
||||
train_params["amp_level"] = args.fp16_opt_level
|
||||
|
||||
if args.gpus > 1:
|
||||
train_params["distributed_backend"] = "ddp"
|
||||
|
||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
|
||||
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
weights_summary=None,
|
||||
callbacks=[logging_callback] + extra_callbacks,
|
||||
logger=logger,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
**train_params,
|
||||
)
|
||||
|
||||
if args.do_train:
|
||||
trainer.fit(model)
|
||||
|
||||
return trainer
|
||||
47
examples/research_projects/rag/parse_dpr_relevance_data.py
Normal file
47
examples/research_projects/rag/parse_dpr_relevance_data.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
This script reads DPR retriever training data and parses each datapoint. We save a line per datapoint.
|
||||
Each line consists of the query followed by a tab-separated list of Wikipedia page titles constituting
|
||||
positive contexts for a given query.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--src_path",
|
||||
type=str,
|
||||
default="biencoder-nq-dev.json",
|
||||
help="Path to raw DPR training data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evaluation_set",
|
||||
type=str,
|
||||
help="where to store parsed evaluation_set file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_path",
|
||||
type=str,
|
||||
help="where to store parsed gold_data_path file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.src_path, "r") as src_file, open(args.evaluation_set, "w") as eval_file, open(
|
||||
args.gold_data_path, "w"
|
||||
) as gold_file:
|
||||
dpr_records = json.load(src_file)
|
||||
for dpr_record in tqdm(dpr_records):
|
||||
question = dpr_record["question"]
|
||||
contexts = [context["title"] for context in dpr_record["positive_ctxs"]]
|
||||
eval_file.write(question + "\n")
|
||||
gold_file.write("\t".join(contexts) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
examples/research_projects/rag/requirements.txt
Normal file
6
examples/research_projects/rag/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
faiss-cpu >= 1.6.3
|
||||
datasets >= 1.0.1
|
||||
psutil >= 5.7.0
|
||||
torch >= 1.4.0
|
||||
transformers
|
||||
pytorch-lightning==1.0.4
|
||||
@@ -0,0 +1,2 @@
|
||||
Aaron Aaron Aaron ( or ; "Ahärôn") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman ("prophet") to the Pharaoh. Part of the Law (Torah) that Moses received from God at Sinai granted Aaron the priesthood for himself and his male descendants, and he became the first High Priest of the Israelites. Aaron died before the Israelites crossed the North Jordan river and he was buried on Mount Hor (Numbers 33:39; Deuteronomy 10:6 says he died and was buried at Moserah). Aaron is also mentioned in the New Testament of the Bible. According to the Book of Exodus, Aaron first functioned as Moses' assistant. Because Moses complained that he could not speak well, God appointed Aaron as Moses' "prophet" (Exodus 4:10-17; 7:1). At the command of Moses, he let his rod turn into a snake. Then he stretched out his rod in order to bring on the first three plagues. After that, Moses tended to act and speak for himself. During the journey in the wilderness, Aaron was not always prominent or active. At the battle with Amalek, he was chosen with Hur to support the hand of Moses that held the "rod of God". When the revelation was given to Moses at biblical Mount Sinai, he headed the elders of Israel who accompanied Moses on the way to the summit.
|
||||
"Pokémon" Pokémon , also known as in Japan, is a media franchise managed by The Pokémon Company, a Japanese consortium between Nintendo, Game Freak, and Creatures. The franchise copyright is shared by all three companies, but Nintendo is the sole owner of the trademark. The franchise was created by Satoshi Tajiri in 1995, and is centered on fictional creatures called "Pokémon", which humans, known as Pokémon Trainers, catch and train to battle each other for sport. The English slogan for the franchise is "Gotta Catch 'Em All". Works within the franchise are set in the Pokémon universe. The franchise began as "Pokémon Red" and "Green" (released outside of Japan as "Pokémon Red" and "Blue"), a pair of video games for the original Game Boy that were developed by Game Freak and published by Nintendo in February 1996. "Pokémon" has since gone on to become the highest-grossing media franchise of all time, with over in revenue up until March 2017. The original video game series is the second best-selling video game franchise (behind Nintendo's "Mario" franchise) with more than 300million copies sold and over 800million mobile downloads. In addition, the "Pokémon" franchise includes the world's top-selling toy brand, the top-selling trading card game with over 25.7billion cards sold, an anime television series that has become the most successful video game adaptation with over 20 seasons and 1,000 episodes in 124 countries, as well as an anime film series, a , books, manga comics, music, and merchandise. The franchise is also represented in other Nintendo media, such as the "Super Smash Bros." series. In November 2005, 4Kids Entertainment, which had managed the non-game related licensing of "Pokémon", announced that it had agreed not to renew the "Pokémon" representation agreement. The Pokémon Company International oversees all "Pokémon" licensing outside Asia.
|
||||
|
Can't render this file because it contains an unexpected character in line 1 and column 35.
|
224
examples/research_projects/rag/test_distributed_retriever.py
Normal file
224
examples/research_projects/rag/test_distributed_retriever.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
import faiss
|
||||
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
|
||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
|
||||
|
||||
def require_distributed_retrieval(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a set of dependencies necessary for pefrorm retrieval with
|
||||
:class:`~transformers.RagRetriever`.
|
||||
|
||||
These tests are skipped when respective libraries are not installed.
|
||||
|
||||
"""
|
||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
@require_distributed_retrieval
|
||||
class RagRetrieverTest(TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.retrieval_vector_size = 8
|
||||
|
||||
# DPR tok
|
||||
vocab_tokens = [
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[PAD]",
|
||||
"[MASK]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
"wa",
|
||||
"un",
|
||||
"runn",
|
||||
"##ing",
|
||||
",",
|
||||
"low",
|
||||
"lowest",
|
||||
]
|
||||
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
|
||||
os.makedirs(dpr_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
# BART tok
|
||||
vocab = [
|
||||
"l",
|
||||
"o",
|
||||
"w",
|
||||
"e",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"i",
|
||||
"d",
|
||||
"n",
|
||||
"\u0120",
|
||||
"\u0120l",
|
||||
"\u0120n",
|
||||
"\u0120lo",
|
||||
"\u0120low",
|
||||
"er",
|
||||
"\u0120lowest",
|
||||
"\u0120newer",
|
||||
"\u0120wider",
|
||||
"<unk>",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||
|
||||
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
|
||||
os.makedirs(bart_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
|
||||
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
|
||||
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
|
||||
|
||||
def get_bart_tokenizer(self) -> BartTokenizer:
|
||||
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_dummy_dataset(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
"text": ["foo", "bar"],
|
||||
"title": ["Foo", "Bar"],
|
||||
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
return dataset
|
||||
|
||||
def get_dummy_pytorch_distributed_retriever(
|
||||
self, init_retrieval: bool, port=12345
|
||||
) -> RagPyTorchDistributedRetriever:
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
)
|
||||
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = dataset
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="custom",
|
||||
)
|
||||
if from_disk:
|
||||
config.passages_path = os.path.join(self.tmpdirname, "dataset")
|
||||
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
|
||||
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
|
||||
dataset.drop_index("embeddings")
|
||||
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
|
||||
del dataset
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
)
|
||||
else:
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
index=CustomHFIndex(config.retrieval_vector_size, dataset),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_hf_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
204
examples/research_projects/rag/use_own_knowledge_dataset.py
Normal file
204
examples/research_projects/rag/use_own_knowledge_dataset.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Features, Sequence, Value, load_dataset
|
||||
|
||||
import faiss
|
||||
from transformers import (
|
||||
DPRContextEncoder,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
HfArgumentParser,
|
||||
RagRetriever,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenizer,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
torch.set_grad_enabled(False)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def split_text(text: str, n=100, character=" ") -> List[str]:
|
||||
"""Split the text every ``n``-th occurrence of ``character``"""
|
||||
text = text.split(character)
|
||||
return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
|
||||
|
||||
|
||||
def split_documents(documents: dict) -> dict:
|
||||
"""Split documents into passages"""
|
||||
titles, texts = [], []
|
||||
for title, text in zip(documents["title"], documents["text"]):
|
||||
if text is not None:
|
||||
for passage in split_text(text):
|
||||
titles.append(title if title is not None else "")
|
||||
texts.append(passage)
|
||||
return {"title": titles, "text": texts}
|
||||
|
||||
|
||||
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
|
||||
"""Compute the DPR embeddings of document passages"""
|
||||
input_ids = ctx_tokenizer(
|
||||
documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
|
||||
)["input_ids"]
|
||||
embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
|
||||
return {"embeddings": embeddings.detach().cpu().numpy()}
|
||||
|
||||
|
||||
def main(
|
||||
rag_example_args: "RagExampleArguments",
|
||||
processing_args: "ProcessingArguments",
|
||||
index_hnsw_args: "IndexHnswArguments",
|
||||
):
|
||||
|
||||
######################################
|
||||
logger.info("Step 1 - Create the dataset")
|
||||
######################################
|
||||
|
||||
# The dataset needed for RAG must have three columns:
|
||||
# - title (string): title of the document
|
||||
# - text (string): text of a passage of the document
|
||||
# - embeddings (array of dimension d): DPR representation of the passage
|
||||
|
||||
# Let's say you have documents in tab-separated csv files with columns "title" and "text"
|
||||
assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"
|
||||
|
||||
# You can load a Dataset object this way
|
||||
dataset = load_dataset(
|
||||
"csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
|
||||
)
|
||||
|
||||
# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
|
||||
|
||||
# Then split the documents into passages of 100 words
|
||||
dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)
|
||||
|
||||
# And compute the embeddings
|
||||
ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
|
||||
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
|
||||
new_features = Features(
|
||||
{"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
|
||||
) # optional, save as float32 instead of float64 to save space
|
||||
dataset = dataset.map(
|
||||
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
|
||||
batched=True,
|
||||
batch_size=processing_args.batch_size,
|
||||
features=new_features,
|
||||
)
|
||||
|
||||
# And finally save your dataset
|
||||
passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
|
||||
dataset.save_to_disk(passages_path)
|
||||
# from datasets import load_from_disk
|
||||
# dataset = load_from_disk(passages_path) # to reload the dataset
|
||||
|
||||
######################################
|
||||
logger.info("Step 2 - Index the dataset")
|
||||
######################################
|
||||
|
||||
# Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
|
||||
index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
|
||||
dataset.add_faiss_index("embeddings", custom_index=index)
|
||||
|
||||
# And save the index
|
||||
index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
|
||||
dataset.get_index("embeddings").save(index_path)
|
||||
# dataset.load_faiss_index("embeddings", index_path) # to reload the index
|
||||
|
||||
######################################
|
||||
logger.info("Step 3 - Load RAG")
|
||||
######################################
|
||||
|
||||
# Easy way to load the model
|
||||
retriever = RagRetriever.from_pretrained(
|
||||
rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset
|
||||
)
|
||||
model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever)
|
||||
tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)
|
||||
|
||||
# For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
|
||||
# retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)
|
||||
|
||||
######################################
|
||||
logger.info("Step 4 - Have fun")
|
||||
######################################
|
||||
|
||||
question = rag_example_args.question or "What does Moses' rod turn into ?"
|
||||
input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
|
||||
generated = model.generate(input_ids)
|
||||
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
|
||||
logger.info("Q: " + question)
|
||||
logger.info("A: " + generated_string)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RagExampleArguments:
|
||||
csv_path: str = field(
|
||||
default=str(Path(__file__).parent / "test_data" / "my_knowledge_dataset.csv"),
|
||||
metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
|
||||
)
|
||||
question: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
|
||||
)
|
||||
rag_model_name: str = field(
|
||||
default="facebook/rag-sequence-nq",
|
||||
metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
|
||||
)
|
||||
dpr_ctx_encoder_model_name: str = field(
|
||||
default="facebook/dpr-ctx_encoder-multiset-base",
|
||||
metadata={
|
||||
"help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
|
||||
},
|
||||
)
|
||||
output_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingArguments:
|
||||
num_proc: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The number of processes to use to split the documents into passages. Default is single process."
|
||||
},
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=16,
|
||||
metadata={
|
||||
"help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexHnswArguments:
|
||||
d: int = field(
|
||||
default=768,
|
||||
metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
|
||||
)
|
||||
m: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The number of bi-directional links created for every new element during the HNSW index construction."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
|
||||
rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
|
||||
main(rag_example_args, processing_args, index_hnsw_args)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user