Remove research projects (#36645)
* Remove research projects * Add new README to explain where the projects went * Trigger tests * Cleanup all references to research_projects
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
<!---
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -16,13 +16,5 @@ limitations under the License.
|
||||
|
||||
# Research projects
|
||||
|
||||
This folder contains various research projects using 🤗 Transformers. They are not maintained and require a specific
|
||||
version of 🤗 Transformers that is indicated in the requirements file of each folder. Updating them to the most recent version of the library will require some work.
|
||||
|
||||
To use any of them, just run the command
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
inside the folder of your choice.
|
||||
|
||||
If you need help with any of those, contact the author(s), indicated at the top of the `README` of each folder.
|
||||
This directory previously contained various research projects using 🤗 Transformers. They have been moved
|
||||
to a separate repo, so you can now find them at https://github.com/huggingface/transformers-research-projects/
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
## Adversarial evaluation of model performances
|
||||
|
||||
Here is an example on evaluating a model using adversarial evaluation of natural language inference with the Heuristic Analysis for NLI Systems (HANS) dataset [McCoy et al., 2019](https://arxiv.org/abs/1902.01007). The example was gracefully provided by [Nafise Sadat Moosavi](https://github.com/ns-moosavi).
|
||||
|
||||
The HANS dataset can be downloaded from [this location](https://github.com/tommccoy1/hans).
|
||||
|
||||
This is an example of using test_hans.py:
|
||||
|
||||
```bash
|
||||
export HANS_DIR=path-to-hans
|
||||
export MODEL_TYPE=type-of-the-model-e.g.-bert-roberta-xlnet-etc
|
||||
export MODEL_PATH=path-to-the-model-directory-that-is-trained-on-NLI-e.g.-by-using-run_glue.py
|
||||
|
||||
python run_hans.py \
|
||||
--task_name hans \
|
||||
--model_type $MODEL_TYPE \
|
||||
--do_eval \
|
||||
--data_dir $HANS_DIR \
|
||||
--model_name_or_path $MODEL_PATH \
|
||||
--max_seq_length 128 \
|
||||
--output_dir $MODEL_PATH \
|
||||
```
|
||||
|
||||
This will create the hans_predictions.txt file in MODEL_PATH, which can then be evaluated using hans/evaluate_heur_output.py from the HANS dataset.
|
||||
|
||||
The results of the BERT-base model that is trained on MNLI using batch size 8 and the random seed 42 on the HANS dataset is as follows:
|
||||
|
||||
```bash
|
||||
Heuristic entailed results:
|
||||
lexical_overlap: 0.9702
|
||||
subsequence: 0.9942
|
||||
constituent: 0.9962
|
||||
|
||||
Heuristic non-entailed results:
|
||||
lexical_overlap: 0.199
|
||||
subsequence: 0.0396
|
||||
constituent: 0.118
|
||||
```
|
||||
@@ -1 +0,0 @@
|
||||
transformers == 4.48.0
|
||||
@@ -1,242 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Finetuning the library models for sequence classification on HANS."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
task_name: str = field(
|
||||
metadata={"help": "The name of the task to train selected in the list: " + ", ".join(hans_processors.keys())}
|
||||
)
|
||||
data_dir: str = field(
|
||||
metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": (
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
)
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
|
||||
|
||||
def hans_data_collator(features: List[InputFeatures]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Data collator that removes the "pairID" key if present.
|
||||
"""
|
||||
batch = default_data_collator(features)
|
||||
_ = batch.pop("pairID", None)
|
||||
return batch
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
|
||||
" --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
training_args.fp16,
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(training_args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
# Set seed
|
||||
set_seed(training_args.seed)
|
||||
|
||||
try:
|
||||
num_labels = hans_tasks_num_labels[data_args.task_name]
|
||||
except KeyError:
|
||||
raise ValueError("Task not found: %s" % (data_args.task_name))
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=data_args.task_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
HansDataset(
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
task=data_args.task_name,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
)
|
||||
if training_args.do_train
|
||||
else None
|
||||
)
|
||||
eval_dataset = (
|
||||
HansDataset(
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
task=data_args.task_name,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
evaluate=True,
|
||||
)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=hans_data_collator,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
trainer.train(
|
||||
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
|
||||
)
|
||||
trainer.save_model()
|
||||
# For convenience, we also re-save the tokenizer to the same directory,
|
||||
# so that you can share your model easily on huggingface.co/models =)
|
||||
if trainer.is_world_master():
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
output = trainer.predict(eval_dataset)
|
||||
preds = output.predictions
|
||||
preds = np.argmax(preds, axis=1)
|
||||
|
||||
pair_ids = [ex.pairID for ex in eval_dataset]
|
||||
output_eval_file = os.path.join(training_args.output_dir, "hans_predictions.txt")
|
||||
label_list = eval_dataset.get_labels()
|
||||
if trainer.is_world_master():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
writer.write("pairID,gold_label\n")
|
||||
for pid, pred in zip(pair_ids, preds):
|
||||
writer.write("ex" + str(pid) + "," + label_list[int(pred)] + "\n")
|
||||
|
||||
trainer._log(output.metrics)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,339 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import tqdm
|
||||
from filelock import FileLock
|
||||
|
||||
from transformers import (
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
DataProcessor,
|
||||
PreTrainedTokenizer,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
XLMRobertaTokenizer,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputExample:
|
||||
"""
|
||||
A single training/test example for simple sequence classification.
|
||||
|
||||
Args:
|
||||
guid: Unique id for the example.
|
||||
text_a: string. The untokenized text of the first sequence. For single
|
||||
sequence tasks, only this sequence must be specified.
|
||||
text_b: (Optional) string. The untokenized text of the second sequence.
|
||||
Only must be specified for sequence pair tasks.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
pairID: (Optional) string. Unique identifier for the pair of sentences.
|
||||
"""
|
||||
|
||||
guid: str
|
||||
text_a: str
|
||||
text_b: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
pairID: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputFeatures:
|
||||
"""
|
||||
A single set of features of data.
|
||||
Property names are the same names as the corresponding inputs to a model.
|
||||
|
||||
Args:
|
||||
input_ids: Indices of input sequence tokens in the vocabulary.
|
||||
attention_mask: Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens.
|
||||
token_type_ids: (Optional) Segment token indices to indicate first and second
|
||||
portions of the inputs. Only some models use them.
|
||||
label: (Optional) Label corresponding to the input. Int for classification problems,
|
||||
float for regression problems.
|
||||
pairID: (Optional) Unique identifier for the pair of sentences.
|
||||
"""
|
||||
|
||||
input_ids: List[int]
|
||||
attention_mask: Optional[List[int]] = None
|
||||
token_type_ids: Optional[List[int]] = None
|
||||
label: Optional[Union[int, float]] = None
|
||||
pairID: Optional[int] = None
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class HansDataset(Dataset):
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
task: str,
|
||||
max_seq_length: Optional[int] = None,
|
||||
overwrite_cache=False,
|
||||
evaluate: bool = False,
|
||||
):
|
||||
processor = hans_processors[task]()
|
||||
|
||||
cached_features_file = os.path.join(
|
||||
data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
tokenizer.__class__.__name__,
|
||||
str(max_seq_length),
|
||||
task,
|
||||
),
|
||||
)
|
||||
label_list = processor.get_labels()
|
||||
if tokenizer.__class__ in (
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
XLMRobertaTokenizer,
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
):
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
self.label_list = label_list
|
||||
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||
|
||||
examples = (
|
||||
processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
||||
)
|
||||
|
||||
logger.info("Training examples: %s", len(examples))
|
||||
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(self.features, cached_features_file)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
def get_labels(self):
|
||||
return self.label_list
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
class TFHansDataset:
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
task: str,
|
||||
max_seq_length: Optional[int] = 128,
|
||||
overwrite_cache=False,
|
||||
evaluate: bool = False,
|
||||
):
|
||||
processor = hans_processors[task]()
|
||||
label_list = processor.get_labels()
|
||||
if tokenizer.__class__ in (
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
XLMRobertaTokenizer,
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
):
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
self.label_list = label_list
|
||||
|
||||
examples = processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
||||
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
||||
|
||||
def gen():
|
||||
for ex_index, ex in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||
|
||||
yield (
|
||||
{
|
||||
"example_id": 0,
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"token_type_ids": ex.token_type_ids,
|
||||
},
|
||||
ex.label,
|
||||
)
|
||||
|
||||
self.dataset = tf.data.Dataset.from_generator(
|
||||
gen,
|
||||
(
|
||||
{
|
||||
"example_id": tf.int32,
|
||||
"input_ids": tf.int32,
|
||||
"attention_mask": tf.int32,
|
||||
"token_type_ids": tf.int32,
|
||||
},
|
||||
tf.int64,
|
||||
),
|
||||
(
|
||||
{
|
||||
"example_id": tf.TensorShape([]),
|
||||
"input_ids": tf.TensorShape([None, None]),
|
||||
"attention_mask": tf.TensorShape([None, None]),
|
||||
"token_type_ids": tf.TensorShape([None, None]),
|
||||
},
|
||||
tf.TensorShape([]),
|
||||
),
|
||||
)
|
||||
|
||||
def get_dataset(self):
|
||||
return self.dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
def get_labels(self):
|
||||
return self.label_list
|
||||
|
||||
|
||||
class HansProcessor(DataProcessor):
|
||||
"""Processor for the HANS data set."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_train_set.txt")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_evaluation_set.txt")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class.
|
||||
Note that we follow the standard three labels for MNLI
|
||||
(see :class:`~transformers.data.processors.utils.MnliProcessor`)
|
||||
but the HANS evaluation groups `contradiction` and `neutral` into `non-entailment` (label 0) while
|
||||
`entailment` is label 1."""
|
||||
return ["contradiction", "entailment", "neutral"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for i, line in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[5]
|
||||
text_b = line[6]
|
||||
pairID = line[7][2:] if line[7].startswith("ex") else line[7]
|
||||
label = line[0]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, pairID=pairID))
|
||||
return examples
|
||||
|
||||
|
||||
def hans_convert_examples_to_features(
|
||||
examples: List[InputExample],
|
||||
label_list: List[str],
|
||||
max_length: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
):
|
||||
"""
|
||||
Loads a data file into a list of ``InputFeatures``
|
||||
|
||||
Args:
|
||||
examples: List of ``InputExamples`` containing the examples.
|
||||
label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method.
|
||||
max_length: Maximum example length.
|
||||
tokenizer: Instance of a tokenizer that will tokenize the examples.
|
||||
|
||||
Returns:
|
||||
A list of task-specific ``InputFeatures`` which can be fed to the model.
|
||||
|
||||
"""
|
||||
|
||||
label_map = {label: i for i, label in enumerate(label_list)}
|
||||
|
||||
features = []
|
||||
for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d" % (ex_index))
|
||||
|
||||
inputs = tokenizer(
|
||||
example.text_a,
|
||||
example.text_b,
|
||||
add_special_tokens=True,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
|
||||
label = label_map[example.label] if example.label in label_map else 0
|
||||
|
||||
pairID = int(example.pairID)
|
||||
|
||||
features.append(InputFeatures(**inputs, label=label, pairID=pairID))
|
||||
|
||||
for i, example in enumerate(examples[:5]):
|
||||
logger.info("*** Example ***")
|
||||
logger.info(f"guid: {example}")
|
||||
logger.info(f"features: {features[i]}")
|
||||
|
||||
return features
|
||||
|
||||
|
||||
hans_tasks_num_labels = {
|
||||
"hans": 3,
|
||||
}
|
||||
|
||||
hans_processors = {
|
||||
"hans": HansProcessor,
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
# Patience-based Early Exit
|
||||
|
||||
Patience-based Early Exit (PABEE) is a plug-and-play inference method for pretrained language models.
|
||||
We have already implemented it on BERT and ALBERT. Basically, you can make your LM faster and more robust with PABEE. It can even improve the performance of ALBERT on GLUE. The only sacrifice is that the batch size can only be 1.
|
||||
Learn more in the paper ["BERT Loses Patience: Fast and Robust Inference with Early Exit"](https://arxiv.org/abs/2006.04152) and the official [GitHub repo](https://github.com/JetRunner/PABEE).
|
||||
|
||||

|
||||
|
||||
## Training
|
||||
|
||||
You can fine-tune a pretrained language model (you can choose from BERT and ALBERT) and train the internal classifiers by:
|
||||
```bash
|
||||
export GLUE_DIR=/path/to/glue_data
|
||||
export TASK_NAME=MRPC
|
||||
|
||||
python ./run_glue_with_pabee.py \
|
||||
--model_type albert \
|
||||
--model_name_or_path google-bert/bert-base-uncased/albert/albert-base-v2 \
|
||||
--task_name $TASK_NAME \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir "$GLUE_DIR/$TASK_NAME" \
|
||||
--max_seq_length 128 \
|
||||
--per_gpu_train_batch_size 32 \
|
||||
--per_gpu_eval_batch_size 32 \
|
||||
--learning_rate 2e-5 \
|
||||
--save_steps 50 \
|
||||
--logging_steps 50 \
|
||||
--num_train_epochs 5 \
|
||||
--output_dir /path/to/save/ \
|
||||
--evaluate_during_training
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
You can inference with different patience settings by:
|
||||
```bash
|
||||
export GLUE_DIR=/path/to/glue_data
|
||||
export TASK_NAME=MRPC
|
||||
|
||||
python ./run_glue_with_pabee.py \
|
||||
--model_type albert \
|
||||
--model_name_or_path /path/to/save/ \
|
||||
--task_name $TASK_NAME \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir "$GLUE_DIR/$TASK_NAME" \
|
||||
--max_seq_length 128 \
|
||||
--per_gpu_eval_batch_size 1 \
|
||||
--learning_rate 2e-5 \
|
||||
--logging_steps 50 \
|
||||
--num_train_epochs 15 \
|
||||
--output_dir /path/to/save/ \
|
||||
--eval_all_checkpoints \
|
||||
--patience 3,4,5,6,7,8
|
||||
```
|
||||
where `patience` can be a list of patience settings, separated by a comma. It will help determine which patience works best.
|
||||
|
||||
When evaluating on a regression task (STS-B), you may add `--regression_threshold 0.1` to define the regression threshold.
|
||||
|
||||
## Results
|
||||
On the GLUE dev set:
|
||||
|
||||
| Model | \#Param | Speed | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST\-2 | STS\-B |
|
||||
|--------------|---------|--------|-------|-------|-------|-------|-------|-------|--------|--------|
|
||||
| ALBERT\-base | 12M | | 58\.9 | 84\.6 | 89\.5 | 91\.7 | 89\.6 | 78\.6 | 92\.8 | 89\.5 |
|
||||
| \+PABEE | 12M | 1\.57x | 61\.2 | 85\.1 | 90\.0 | 91\.8 | 89\.6 | 80\.1 | 93\.0 | 90\.1 |
|
||||
|
||||
| Model | \#Param | Speed\-up | MNLI | SST\-2 | STS\-B |
|
||||
|---------------|---------|-----------|-------|--------|--------|
|
||||
| BERT\-base | 108M | | 84\.5 | 92\.1 | 88\.9 |
|
||||
| \+PABEE | 108M | 1\.62x | 83\.6 | 92\.0 | 88\.7 |
|
||||
| ALBERT\-large | 18M | | 86\.4 | 94\.9 | 90\.4 |
|
||||
| \+PABEE | 18M | 2\.42x | 86\.8 | 95\.2 | 90\.6 |
|
||||
|
||||
|
||||
## Citation
|
||||
If you find this resource useful, please consider citing the following paper:
|
||||
```bibtex
|
||||
@misc{zhou2020bert,
|
||||
title={BERT Loses Patience: Fast and Robust Inference with Early Exit},
|
||||
author={Wangchunshu Zhou and Canwen Xu and Tao Ge and Julian McAuley and Ke Xu and Furu Wei},
|
||||
year={2020},
|
||||
eprint={2006.04152},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
@@ -1,320 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Google AI, Google Brain, the HuggingFace Inc. team and Microsoft Corporation.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch ALBERT model with Patience-based Early Exit."""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from transformers.models.albert.modeling_albert import (
|
||||
ALBERT_INPUTS_DOCSTRING,
|
||||
ALBERT_START_DOCSTRING,
|
||||
AlbertModel,
|
||||
AlbertPreTrainedModel,
|
||||
AlbertTransformer,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlbertTransformerWithPabee(AlbertTransformer):
|
||||
def adaptive_forward(self, hidden_states, current_layer, attention_mask=None, head_mask=None):
|
||||
if current_layer == 0:
|
||||
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
||||
else:
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
|
||||
|
||||
# Index of the hidden group
|
||||
group_idx = int(current_layer / (self.config.num_hidden_layers / self.config.num_hidden_groups))
|
||||
|
||||
layer_group_output = self.albert_layer_groups[group_idx](
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
||||
)
|
||||
hidden_states = layer_group_output[0]
|
||||
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare ALBERT Model transformer with PABEE outputting raw hidden-states without any specific head on top.",
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class AlbertModelWithPabee(AlbertModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.encoder = AlbertTransformerWithPabee(config)
|
||||
|
||||
self.init_weights()
|
||||
self.patience = 0
|
||||
self.inference_instances_num = 0
|
||||
self.inference_layers_num = 0
|
||||
|
||||
self.regression_threshold = 0
|
||||
|
||||
def set_regression_threshold(self, threshold):
|
||||
self.regression_threshold = threshold
|
||||
|
||||
def set_patience(self, patience):
|
||||
self.patience = patience
|
||||
|
||||
def reset_stats(self):
|
||||
self.inference_instances_num = 0
|
||||
self.inference_layers_num = 0
|
||||
|
||||
def log_stats(self):
|
||||
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
||||
message = (
|
||||
f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up ="
|
||||
f" {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
|
||||
)
|
||||
print(message)
|
||||
|
||||
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_dropout=None,
|
||||
output_layers=None,
|
||||
regression=False,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during pre-training.
|
||||
|
||||
This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = embedding_output
|
||||
|
||||
if self.training:
|
||||
res = []
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
encoder_outputs = self.encoder.adaptive_forward(
|
||||
encoder_outputs,
|
||||
current_layer=i,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
)
|
||||
|
||||
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
||||
logits = output_layers[i](output_dropout(pooled_output))
|
||||
res.append(logits)
|
||||
elif self.patience == 0: # Use all layers for inference
|
||||
encoder_outputs = self.encoder(encoder_outputs, extended_attention_mask, head_mask=head_mask)
|
||||
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
||||
res = [output_layers[self.config.num_hidden_layers - 1](pooled_output)]
|
||||
else:
|
||||
patient_counter = 0
|
||||
patient_result = None
|
||||
calculated_layer_num = 0
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
calculated_layer_num += 1
|
||||
encoder_outputs = self.encoder.adaptive_forward(
|
||||
encoder_outputs,
|
||||
current_layer=i,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
)
|
||||
|
||||
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
||||
logits = output_layers[i](pooled_output)
|
||||
if regression:
|
||||
labels = logits.detach()
|
||||
if patient_result is not None:
|
||||
patient_labels = patient_result.detach()
|
||||
if (patient_result is not None) and torch.abs(patient_result - labels) < self.regression_threshold:
|
||||
patient_counter += 1
|
||||
else:
|
||||
patient_counter = 0
|
||||
else:
|
||||
labels = logits.detach().argmax(dim=1)
|
||||
if patient_result is not None:
|
||||
patient_labels = patient_result.detach().argmax(dim=1)
|
||||
if (patient_result is not None) and torch.all(labels.eq(patient_labels)):
|
||||
patient_counter += 1
|
||||
else:
|
||||
patient_counter = 0
|
||||
|
||||
patient_result = logits
|
||||
if patient_counter == self.patience:
|
||||
break
|
||||
res = [patient_result]
|
||||
self.inference_layers_num += calculated_layer_num
|
||||
self.inference_instances_num += 1
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Albert Model transformer with PABEE and a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class AlbertForSequenceClassificationWithPabee(AlbertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.albert = AlbertModelWithPabee(config)
|
||||
self.dropout = nn.Dropout(config.classifier_dropout_prob)
|
||||
self.classifiers = nn.ModuleList(
|
||||
[nn.Linear(config.hidden_size, self.config.num_labels) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
||||
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import AlbertTokenizer
|
||||
from pabee import AlbertForSequenceClassificationWithPabee
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
tokenizer = AlbertTokenizer.from_pretrained('albert/albert-base-v2')
|
||||
model = AlbertForSequenceClassificationWithPabee.from_pretrained('albert/albert-base-v2')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
logits = self.albert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_dropout=self.dropout,
|
||||
output_layers=self.classifiers,
|
||||
regression=self.num_labels == 1,
|
||||
)
|
||||
|
||||
outputs = (logits[-1],)
|
||||
|
||||
if labels is not None:
|
||||
total_loss = None
|
||||
total_weights = 0
|
||||
for ix, logits_item in enumerate(logits):
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits_item.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits_item.view(-1, self.num_labels), labels.view(-1))
|
||||
if total_loss is None:
|
||||
total_loss = loss
|
||||
else:
|
||||
total_loss += loss * (ix + 1)
|
||||
total_weights += ix + 1
|
||||
outputs = (total_loss / total_weights,) + outputs
|
||||
|
||||
return outputs
|
||||
@@ -1,345 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and Microsoft Corporation.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model with Patience-based Early Exit."""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
BERT_START_DOCSTRING,
|
||||
BertEncoder,
|
||||
BertModel,
|
||||
BertPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BertEncoderWithPabee(BertEncoder):
|
||||
def adaptive_forward(self, hidden_states, current_layer, attention_mask=None, head_mask=None):
|
||||
layer_outputs = self.layer[current_layer](hidden_states, attention_mask, head_mask[current_layer])
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Bert Model transformer with PABEE outputting raw hidden-states without any specific head on top.",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class BertModelWithPabee(BertModel):
|
||||
"""
|
||||
|
||||
The model can behave as an encoder (with only self-attention) as well
|
||||
as a decoder, in which case a layer of cross-attention is added between
|
||||
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
|
||||
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
|
||||
To behave as a decoder the model needs to be initialized with the
|
||||
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
|
||||
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
|
||||
|
||||
.. _`Attention is all you need`:
|
||||
https://arxiv.org/abs/1706.03762
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.encoder = BertEncoderWithPabee(config)
|
||||
|
||||
self.init_weights()
|
||||
self.patience = 0
|
||||
self.inference_instances_num = 0
|
||||
self.inference_layers_num = 0
|
||||
|
||||
self.regression_threshold = 0
|
||||
|
||||
def set_regression_threshold(self, threshold):
|
||||
self.regression_threshold = threshold
|
||||
|
||||
def set_patience(self, patience):
|
||||
self.patience = patience
|
||||
|
||||
def reset_stats(self):
|
||||
self.inference_instances_num = 0
|
||||
self.inference_layers_num = 0
|
||||
|
||||
def log_stats(self):
|
||||
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
||||
message = (
|
||||
f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up ="
|
||||
f" {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
|
||||
)
|
||||
print(message)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_dropout=None,
|
||||
output_layers=None,
|
||||
regression=False,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during pre-training.
|
||||
|
||||
This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = embedding_output
|
||||
|
||||
if self.training:
|
||||
res = []
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
encoder_outputs = self.encoder.adaptive_forward(
|
||||
encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask
|
||||
)
|
||||
|
||||
pooled_output = self.pooler(encoder_outputs)
|
||||
logits = output_layers[i](output_dropout(pooled_output))
|
||||
res.append(logits)
|
||||
elif self.patience == 0: # Use all layers for inference
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
pooled_output = self.pooler(encoder_outputs[0])
|
||||
res = [output_layers[self.config.num_hidden_layers - 1](pooled_output)]
|
||||
else:
|
||||
patient_counter = 0
|
||||
patient_result = None
|
||||
calculated_layer_num = 0
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
calculated_layer_num += 1
|
||||
encoder_outputs = self.encoder.adaptive_forward(
|
||||
encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask
|
||||
)
|
||||
|
||||
pooled_output = self.pooler(encoder_outputs)
|
||||
logits = output_layers[i](pooled_output)
|
||||
if regression:
|
||||
labels = logits.detach()
|
||||
if patient_result is not None:
|
||||
patient_labels = patient_result.detach()
|
||||
if (patient_result is not None) and torch.abs(patient_result - labels) < self.regression_threshold:
|
||||
patient_counter += 1
|
||||
else:
|
||||
patient_counter = 0
|
||||
else:
|
||||
labels = logits.detach().argmax(dim=1)
|
||||
if patient_result is not None:
|
||||
patient_labels = patient_result.detach().argmax(dim=1)
|
||||
if (patient_result is not None) and torch.all(labels.eq(patient_labels)):
|
||||
patient_counter += 1
|
||||
else:
|
||||
patient_counter = 0
|
||||
|
||||
patient_result = logits
|
||||
if patient_counter == self.patience:
|
||||
break
|
||||
res = [patient_result]
|
||||
self.inference_layers_num += calculated_layer_num
|
||||
self.inference_instances_num += 1
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bert Model transformer with PABEE and a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class BertForSequenceClassificationWithPabee(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.bert = BertModelWithPabee(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifiers = nn.ModuleList(
|
||||
[nn.Linear(config.hidden_size, self.config.num_labels) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import BertTokenizer, BertForSequenceClassification
|
||||
from pabee import BertForSequenceClassificationWithPabee
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-uncased')
|
||||
model = BertForSequenceClassificationWithPabee.from_pretrained('google-bert/bert-base-uncased')
|
||||
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
logits = self.bert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_dropout=self.dropout,
|
||||
output_layers=self.classifiers,
|
||||
regression=self.num_labels == 1,
|
||||
)
|
||||
|
||||
outputs = (logits[-1],)
|
||||
|
||||
if labels is not None:
|
||||
total_loss = None
|
||||
total_weights = 0
|
||||
for ix, logits_item in enumerate(logits):
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits_item.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits_item.view(-1, self.num_labels), labels.view(-1))
|
||||
if total_loss is None:
|
||||
total_loss = loss
|
||||
else:
|
||||
total_loss += loss * (ix + 1)
|
||||
total_weights += ix + 1
|
||||
outputs = (total_loss / total_weights,) + outputs
|
||||
|
||||
return outputs
|
||||
@@ -1 +0,0 @@
|
||||
transformers == 4.38.0
|
||||
@@ -1,751 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and Microsoft Corporation.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Training and inference using the library models for sequence classification on GLUE (Bert, Albert) with PABEE."""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
|
||||
from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AlbertConfig,
|
||||
AlbertTokenizer,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import glue_compute_metrics as compute_metrics
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
from transformers import glue_output_modes as output_modes
|
||||
from transformers import glue_processors as processors
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForSequenceClassificationWithPabee, BertTokenizer),
|
||||
"albert": (AlbertConfig, AlbertForSequenceClassificationWithPabee, AlbertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to global_step of last saved checkpoint from model path
|
||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", global_step)
|
||||
logger.info(
|
||||
" Will skip the first %d steps in the first epoch",
|
||||
steps_trained_in_current_epoch,
|
||||
)
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained,
|
||||
int(args.num_train_epochs),
|
||||
desc="Epoch",
|
||||
disable=args.local_rank not in [-1, 0],
|
||||
)
|
||||
set_seed(args) # Added here for reproducibility
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3],
|
||||
}
|
||||
inputs["token_type_ids"] = batch[2]
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
logs = {}
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
eval_key = "eval_{}".format(key)
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||
learning_rate_scalar = scheduler.get_lr()[0]
|
||||
logs["learning_rate"] = learning_rate_scalar
|
||||
logs["loss"] = loss_scalar
|
||||
logging_loss = tr_loss
|
||||
|
||||
for key, value in logs.items():
|
||||
tb_writer.add_scalar(key, value, global_step)
|
||||
print(json.dumps({**logs, **{"step": global_step}}))
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix="", patience=0):
|
||||
if args.model_type == "albert":
|
||||
model.albert.set_regression_threshold(args.regression_threshold)
|
||||
model.albert.set_patience(patience)
|
||||
model.albert.reset_stats()
|
||||
elif args.model_type == "bert":
|
||||
model.bert.set_regression_threshold(args.regression_threshold)
|
||||
model.bert.set_patience(patience)
|
||||
model.bert.reset_stats()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3],
|
||||
}
|
||||
inputs["token_type_ids"] = batch[2]
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif args.output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
print(" %s = %s" % (key, str(result[key])))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
if args.eval_all_checkpoints and patience != 0:
|
||||
if args.model_type == "albert":
|
||||
model.albert.log_stats()
|
||||
elif args.model_type == "bert":
|
||||
model.bert.log_stats()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
processor = processors[task]()
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = (
|
||||
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
label_list=label_list,
|
||||
max_length=args.max_seq_length,
|
||||
output_mode=output_mode,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
if output_mode == "classification":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--patience",
|
||||
default="0",
|
||||
type=str,
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--regression_threshold",
|
||||
default=0,
|
||||
type=float,
|
||||
required=False,
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training",
|
||||
action="store_true",
|
||||
help="Run evaluation during training at each logging step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case",
|
||||
action="store_true",
|
||||
help="Set this flag if you are using an uncased model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--per_gpu_train_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
default=5e-5,
|
||||
type=float,
|
||||
help="The initial learning rate for Adam.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs",
|
||||
default=3.0,
|
||||
type=float,
|
||||
help="Total number of training epochs to perform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Save checkpoint every X updates steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir",
|
||||
action="store_true",
|
||||
help="Overwrite the content of the output directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache",
|
||||
action="store_true",
|
||||
help="Overwrite the cached training and evaluation sets",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help=(
|
||||
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
|
||||
"See details at https://nvidia.github.io/apex/amp.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="For distributed training: local_rank",
|
||||
)
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name]()
|
||||
args.output_mode = output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
if args.patience != "0" and args.per_gpu_eval_batch_size != 1:
|
||||
raise ValueError("The eval batch size must be 1 with PABEE inference on.")
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
print("Total Model Parameters:", sum(param.numel() for param in model.parameters()))
|
||||
output_layers_param_num = sum(param.numel() for param in model.classifiers.parameters())
|
||||
print("Output Layers Parameters:", output_layers_param_num)
|
||||
single_output_layer_param_num = sum(param.numel() for param in model.classifiers[0].parameters())
|
||||
print(
|
||||
"Added Output Layers Parameters:",
|
||||
output_layers_param_num - single_output_layer_param_num,
|
||||
)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
patience_list = [int(x) for x in args.patience.split(",")]
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = [
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
]
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
|
||||
print(f"Evaluation for checkpoint {prefix}")
|
||||
for patience in patience_list:
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix, patience=patience)
|
||||
result = {k + "_{}".format(global_step): v for k, v in result.items()}
|
||||
results.update(result)
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,51 +0,0 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_glue_with_pabee
|
||||
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_setup_file():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f")
|
||||
args = parser.parse_args()
|
||||
return args.f
|
||||
|
||||
|
||||
class PabeeTests(TestCasePlus):
|
||||
def test_run_glue(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_glue_with_pabee.py
|
||||
--model_type albert
|
||||
--model_name_or_path albert/albert-base-v2
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--task_name mrpc
|
||||
--do_train
|
||||
--do_eval
|
||||
--per_gpu_train_batch_size=2
|
||||
--per_gpu_eval_batch_size=1
|
||||
--learning_rate=2e-5
|
||||
--max_steps=50
|
||||
--warmup_steps=2
|
||||
--seed=42
|
||||
--max_seq_length=128
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_glue_with_pabee.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.75)
|
||||
@@ -1,61 +0,0 @@
|
||||
# Text Summarization with Pretrained Encoders
|
||||
|
||||
This folder contains part of the code necessary to reproduce the results on abstractive summarization from the article [Text Summarization with Pretrained Encoders](https://arxiv.org/pdf/1908.08345.pdf) by [Yang Liu](https://nlp-yang.github.io/) and [Mirella Lapata](https://homepages.inf.ed.ac.uk/mlap/). It can also be used to summarize any document.
|
||||
|
||||
The original code can be found on the Yang Liu's [github repository](https://github.com/nlpyang/PreSumm).
|
||||
|
||||
The model is loaded with the pre-trained weights for the abstractive summarization model trained on the CNN/Daily Mail dataset with an extractive and then abstractive tasks.
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/transformers && cd transformers
|
||||
pip install .
|
||||
pip install nltk py-rouge
|
||||
cd examples/seq2seq/bertabs
|
||||
```
|
||||
|
||||
## Reproduce the authors' ROUGE score
|
||||
|
||||
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
|
||||
|
||||
```bash
|
||||
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
|
||||
```
|
||||
|
||||
And move all the stories to the same folder. We will refer as `$DATA_PATH` the path to where you uncompressed both archive. Then run the following in the same folder as `run_summarization.py`:
|
||||
|
||||
```bash
|
||||
python run_summarization.py \
|
||||
--documents_dir $DATA_PATH \
|
||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||
--no_cuda false \
|
||||
--batch_size 4 \
|
||||
--min_length 50 \
|
||||
--max_length 200 \
|
||||
--beam_size 5 \
|
||||
--alpha 0.95 \
|
||||
--block_trigram true \
|
||||
--compute_rouge true
|
||||
```
|
||||
|
||||
The scripts executes on GPU if one is available and if `no_cuda` is not set to `true`. Inference on multiple GPUs is not supported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes 30 hours to compute with a single Tesla V100 GPU and a batch size of 10 (300,000 texts to summarize).
|
||||
|
||||
## Summarize any text
|
||||
|
||||
Put the documents that you would like to summarize in a folder (the path to which is referred to as `$DATA_PATH` below) and run the following in the same folder as `run_summarization.py`:
|
||||
|
||||
```bash
|
||||
python run_summarization.py \
|
||||
--documents_dir $DATA_PATH \
|
||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||
--no_cuda false \
|
||||
--batch_size 4 \
|
||||
--min_length 50 \
|
||||
--max_length 200 \
|
||||
--beam_size 5 \
|
||||
--alpha 0.95 \
|
||||
--block_trigram true \
|
||||
```
|
||||
|
||||
You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py` and tell it where to fetch the reference summaries.
|
||||
@@ -1,98 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""BertAbs configuration"""
|
||||
|
||||
import logging
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BERTABS_FINETUNED_CONFIG_MAP = {
|
||||
"bertabs-finetuned-cnndm": "https://huggingface.co/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class BertAbsConfig(PretrainedConfig):
|
||||
r"""Class to store the configuration of the BertAbs model.
|
||||
|
||||
Arguments:
|
||||
vocab_size: int
|
||||
Number of tokens in the vocabulary.
|
||||
max_pos: int
|
||||
The maximum sequence length that this model will be used with.
|
||||
enc_layer: int
|
||||
The number of hidden layers in the Transformer encoder.
|
||||
enc_hidden_size: int
|
||||
The size of the encoder's layers.
|
||||
enc_heads: int
|
||||
The number of attention heads for each attention layer in the encoder.
|
||||
enc_ff_size: int
|
||||
The size of the encoder's feed-forward layers.
|
||||
enc_dropout: int
|
||||
The dropout probability for all fully connected layers in the
|
||||
embeddings, layers, pooler and also the attention probabilities in
|
||||
the encoder.
|
||||
dec_layer: int
|
||||
The number of hidden layers in the decoder.
|
||||
dec_hidden_size: int
|
||||
The size of the decoder's layers.
|
||||
dec_heads: int
|
||||
The number of attention heads for each attention layer in the decoder.
|
||||
dec_ff_size: int
|
||||
The size of the decoder's feed-forward layers.
|
||||
dec_dropout: int
|
||||
The dropout probability for all fully connected layers in the
|
||||
embeddings, layers, pooler and also the attention probabilities in
|
||||
the decoder.
|
||||
"""
|
||||
|
||||
model_type = "bertabs"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
max_pos=512,
|
||||
enc_layers=6,
|
||||
enc_hidden_size=512,
|
||||
enc_heads=8,
|
||||
enc_ff_size=512,
|
||||
enc_dropout=0.2,
|
||||
dec_layers=6,
|
||||
dec_hidden_size=768,
|
||||
dec_heads=8,
|
||||
dec_ff_size=2048,
|
||||
dec_dropout=0.2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_pos = max_pos
|
||||
|
||||
self.enc_layers = enc_layers
|
||||
self.enc_hidden_size = enc_hidden_size
|
||||
self.enc_heads = enc_heads
|
||||
self.enc_ff_size = enc_ff_size
|
||||
self.enc_dropout = enc_dropout
|
||||
|
||||
self.dec_layers = dec_layers
|
||||
self.dec_hidden_size = dec_hidden_size
|
||||
self.dec_heads = dec_heads
|
||||
self.dec_ff_size = dec_ff_size
|
||||
self.dec_dropout = dec_dropout
|
||||
@@ -1,185 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BertExtAbs's checkpoints.
|
||||
|
||||
The script looks like it is doing something trivial but it is not. The "weights"
|
||||
proposed by the authors are actually the entire model pickled. We need to load
|
||||
the model within the original codebase to be able to only save its `state_dict`.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from model_bertabs import BertAbsSummarizer
|
||||
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||
|
||||
|
||||
BertAbsConfig = namedtuple(
|
||||
"BertAbsConfig",
|
||||
[
|
||||
"temp_dir",
|
||||
"large",
|
||||
"use_bert_emb",
|
||||
"finetune_bert",
|
||||
"encoder",
|
||||
"share_emb",
|
||||
"max_pos",
|
||||
"enc_layers",
|
||||
"enc_hidden_size",
|
||||
"enc_heads",
|
||||
"enc_ff_size",
|
||||
"enc_dropout",
|
||||
"dec_layers",
|
||||
"dec_hidden_size",
|
||||
"dec_heads",
|
||||
"dec_ff_size",
|
||||
"dec_dropout",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
"""Copy/paste and tweak the pre-trained weights provided by the creators
|
||||
of BertAbs for the internal architecture.
|
||||
"""
|
||||
|
||||
# Instantiate the authors' model with the pre-trained weights
|
||||
config = BertAbsConfig(
|
||||
temp_dir=".",
|
||||
finetune_bert=False,
|
||||
large=False,
|
||||
share_emb=True,
|
||||
use_bert_emb=False,
|
||||
encoder="bert",
|
||||
max_pos=512,
|
||||
enc_layers=6,
|
||||
enc_hidden_size=512,
|
||||
enc_heads=8,
|
||||
enc_ff_size=512,
|
||||
enc_dropout=0.2,
|
||||
dec_layers=6,
|
||||
dec_hidden_size=768,
|
||||
dec_heads=8,
|
||||
dec_ff_size=2048,
|
||||
dec_dropout=0.2,
|
||||
)
|
||||
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
|
||||
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
|
||||
original.eval()
|
||||
|
||||
new_model = BertAbsSummarizer(config, torch.device("cpu"))
|
||||
new_model.eval()
|
||||
|
||||
# -------------------
|
||||
# Convert the weights
|
||||
# -------------------
|
||||
|
||||
logging.info("convert the model")
|
||||
new_model.bert.load_state_dict(original.bert.state_dict())
|
||||
new_model.decoder.load_state_dict(original.decoder.state_dict())
|
||||
new_model.generator.load_state_dict(original.generator.state_dict())
|
||||
|
||||
# ----------------------------------
|
||||
# Make sure the outpus are identical
|
||||
# ----------------------------------
|
||||
|
||||
logging.info("Make sure that the models' outputs are identical")
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
# prepare the model inputs
|
||||
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
|
||||
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
|
||||
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
|
||||
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
|
||||
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
|
||||
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
|
||||
|
||||
# failsafe to make sure the weights reset does not affect the
|
||||
# loaded weights.
|
||||
assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0
|
||||
|
||||
# forward pass
|
||||
src = encoder_input_ids
|
||||
tgt = decoder_input_ids
|
||||
segs = token_type_ids = None
|
||||
clss = None
|
||||
mask_src = encoder_attention_mask = None
|
||||
mask_tgt = decoder_attention_mask = None
|
||||
mask_cls = None
|
||||
|
||||
# The original model does not apply the generator layer immediatly but rather in
|
||||
# the beam search (where it combines softmax + linear layer). Since we already
|
||||
# apply the softmax in our generation process we only apply the linear layer here.
|
||||
# We make sure that the outputs of the full stack are identical
|
||||
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
|
||||
output_original_generator = original.generator(output_original_model)
|
||||
|
||||
output_converted_model = new_model(
|
||||
encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
|
||||
)[0]
|
||||
output_converted_generator = new_model.generator(output_converted_model)
|
||||
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
|
||||
print("Maximum absolute difference between weights: {:.2f}".format(maximum_absolute_difference))
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
|
||||
print("Maximum absolute difference between weights: {:.2f}".format(maximum_absolute_difference))
|
||||
|
||||
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
|
||||
if are_identical:
|
||||
logging.info("all weights are equal up to 1e-3")
|
||||
else:
|
||||
raise ValueError("the weights are different. The new model is likely different from the original one.")
|
||||
|
||||
# The model has been saved with torch.save(model) and this is bound to the exact
|
||||
# directory structure. We save the state_dict instead.
|
||||
logging.info("saving the model's state dictionary")
|
||||
torch.save(
|
||||
new_model.state_dict(), "./bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bertabs_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path the official PyTorch dump.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_bertabs_checkpoints(
|
||||
args.bertabs_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +0,0 @@
|
||||
transformers == 4.38.0
|
||||
|
||||
# For ROUGE
|
||||
nltk
|
||||
py-rouge
|
||||
@@ -1,347 +0,0 @@
|
||||
#! /usr/bin/python3
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from modeling_bertabs import BertAbs, build_predictor
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from .utils_summarization import (
|
||||
CNNDMDataset,
|
||||
build_mask,
|
||||
compute_token_type_ids,
|
||||
encode_for_summarization,
|
||||
truncate_or_pad,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased", do_lower_case=True)
|
||||
model = BertAbs.from_pretrained("remi/bertabs-finetuned-extractive-abstractive-summarization")
|
||||
model.to(args.device)
|
||||
model.eval()
|
||||
|
||||
symbols = {
|
||||
"BOS": tokenizer.vocab["[unused0]"],
|
||||
"EOS": tokenizer.vocab["[unused1]"],
|
||||
"PAD": tokenizer.vocab["[PAD]"],
|
||||
}
|
||||
|
||||
if args.compute_rouge:
|
||||
reference_summaries = []
|
||||
generated_summaries = []
|
||||
|
||||
import nltk
|
||||
import rouge
|
||||
|
||||
nltk.download("punkt")
|
||||
rouge_evaluator = rouge.Rouge(
|
||||
metrics=["rouge-n", "rouge-l"],
|
||||
max_n=2,
|
||||
limit_length=True,
|
||||
length_limit=args.beam_size,
|
||||
length_limit_type="words",
|
||||
apply_avg=True,
|
||||
apply_best=False,
|
||||
alpha=0.5, # Default F1_score
|
||||
weight_factor=1.2,
|
||||
stemming=True,
|
||||
)
|
||||
|
||||
# these (unused) arguments are defined to keep the compatibility
|
||||
# with the legacy code and will be deleted in a next iteration.
|
||||
args.result_path = ""
|
||||
args.temp_dir = ""
|
||||
|
||||
data_iterator = build_data_iterator(args, tokenizer)
|
||||
predictor = build_predictor(args, tokenizer, symbols, model)
|
||||
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Number examples = %d", len(data_iterator.dataset))
|
||||
logger.info(" Batch size = %d", args.batch_size)
|
||||
logger.info("")
|
||||
logger.info("***** Beam Search parameters *****")
|
||||
logger.info(" Beam size = %d", args.beam_size)
|
||||
logger.info(" Minimum length = %d", args.min_length)
|
||||
logger.info(" Maximum length = %d", args.max_length)
|
||||
logger.info(" Alpha (length penalty) = %.2f", args.alpha)
|
||||
logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))
|
||||
|
||||
for batch in tqdm(data_iterator):
|
||||
batch_data = predictor.translate_batch(batch)
|
||||
translations = predictor.from_batch(batch_data)
|
||||
summaries = [format_summary(t) for t in translations]
|
||||
save_summaries(summaries, args.summaries_output_dir, batch.document_names)
|
||||
|
||||
if args.compute_rouge:
|
||||
reference_summaries += batch.tgt_str
|
||||
generated_summaries += summaries
|
||||
|
||||
if args.compute_rouge:
|
||||
scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
|
||||
str_scores = format_rouge_scores(scores)
|
||||
save_rouge_scores(str_scores)
|
||||
print(str_scores)
|
||||
|
||||
|
||||
def save_summaries(summaries, path, original_document_name):
|
||||
"""Write the summaries in fies that are prefixed by the original
|
||||
files' name with the `_summary` appended.
|
||||
|
||||
Attributes:
|
||||
original_document_names: List[string]
|
||||
Name of the document that was summarized.
|
||||
path: string
|
||||
Path were the summaries will be written
|
||||
summaries: List[string]
|
||||
The summaries that we produced.
|
||||
"""
|
||||
for summary, document_name in zip(summaries, original_document_name):
|
||||
# Prepare the summary file's name
|
||||
if "." in document_name:
|
||||
bare_document_name = ".".join(document_name.split(".")[:-1])
|
||||
extension = document_name.split(".")[-1]
|
||||
name = bare_document_name + "_summary." + extension
|
||||
else:
|
||||
name = document_name + "_summary"
|
||||
|
||||
file_path = os.path.join(path, name)
|
||||
with open(file_path, "w") as output:
|
||||
output.write(summary)
|
||||
|
||||
|
||||
def format_summary(translation):
|
||||
"""Transforms the output of the `from_batch` function
|
||||
into nicely formatted summaries.
|
||||
"""
|
||||
raw_summary, _, _ = translation
|
||||
summary = (
|
||||
raw_summary.replace("[unused0]", "")
|
||||
.replace("[unused3]", "")
|
||||
.replace("[PAD]", "")
|
||||
.replace("[unused1]", "")
|
||||
.replace(r" +", " ")
|
||||
.replace(" [unused2] ", ". ")
|
||||
.replace("[unused2]", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def format_rouge_scores(scores):
|
||||
return """\n
|
||||
****** ROUGE SCORES ******
|
||||
|
||||
** ROUGE 1
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}
|
||||
|
||||
** ROUGE 2
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}
|
||||
|
||||
** ROUGE L
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}""".format(
|
||||
scores["rouge-1"]["f"],
|
||||
scores["rouge-1"]["p"],
|
||||
scores["rouge-1"]["r"],
|
||||
scores["rouge-2"]["f"],
|
||||
scores["rouge-2"]["p"],
|
||||
scores["rouge-2"]["r"],
|
||||
scores["rouge-l"]["f"],
|
||||
scores["rouge-l"]["p"],
|
||||
scores["rouge-l"]["r"],
|
||||
)
|
||||
|
||||
|
||||
def save_rouge_scores(str_scores):
|
||||
with open("rouge_scores.txt", "w") as output:
|
||||
output.write(str_scores)
|
||||
|
||||
|
||||
#
|
||||
# LOAD the dataset
|
||||
#
|
||||
|
||||
|
||||
def build_data_iterator(args, tokenizer):
|
||||
dataset = load_and_cache_examples(args, tokenizer)
|
||||
sampler = SequentialSampler(dataset)
|
||||
|
||||
def collate_fn(data):
|
||||
return collate(data, tokenizer, block_size=512, device=args.device)
|
||||
|
||||
iterator = DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
return iterator
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer):
|
||||
dataset = CNNDMDataset(args.documents_dir)
|
||||
return dataset
|
||||
|
||||
|
||||
def collate(data, tokenizer, block_size, device):
|
||||
"""Collate formats the data passed to the data loader.
|
||||
|
||||
In particular we tokenize the data batch after batch to avoid keeping them
|
||||
all in memory. We output the data as a namedtuple to fit the original BertAbs's
|
||||
API.
|
||||
"""
|
||||
data = [x for x in data if not len(x[1]) == 0] # remove empty_files
|
||||
names = [name for name, _, _ in data]
|
||||
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
||||
|
||||
encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
|
||||
encoded_stories = torch.tensor(
|
||||
[truncate_or_pad(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
|
||||
)
|
||||
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
||||
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
||||
|
||||
batch = Batch(
|
||||
document_names=names,
|
||||
batch_size=len(encoded_stories),
|
||||
src=encoded_stories.to(device),
|
||||
segs=encoder_token_type_ids.to(device),
|
||||
mask_src=encoder_mask.to(device),
|
||||
tgt_str=summaries,
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def decode_summary(summary_tokens, tokenizer):
|
||||
"""Decode the summary and return it in a format
|
||||
suitable for evaluation.
|
||||
"""
|
||||
summary_tokens = summary_tokens.to("cpu").numpy()
|
||||
summary = tokenizer.decode(summary_tokens)
|
||||
sentences = summary.split(".")
|
||||
sentences = [s + "." for s in sentences]
|
||||
return sentences
|
||||
|
||||
|
||||
def main():
|
||||
"""The main function defines the interface with the users."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--documents_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The folder where the documents to summarize are located.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summaries_output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="The folder in which the summaries should be written. Defaults to the folder where the documents are",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compute_rouge",
|
||||
default=False,
|
||||
type=bool,
|
||||
required=False,
|
||||
help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.",
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--no_cuda",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
# BEAM SEARCH arguments
|
||||
parser.add_argument(
|
||||
"--min_length",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Minimum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
default=200,
|
||||
type=int,
|
||||
help="Maixmum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=5,
|
||||
type=int,
|
||||
help="The number of beams to start with for each example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
default=0.95,
|
||||
type=float,
|
||||
help="The value of alpha for the length penalty in the beam search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_trigram",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Select device (distributed not available)
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
|
||||
# Check the existence of directories
|
||||
if not args.summaries_output_dir:
|
||||
args.summaries_output_dir = args.documents_dir
|
||||
|
||||
if not documents_dir_is_valid(args.documents_dir):
|
||||
raise FileNotFoundError(
|
||||
"We could not find the directory you specified for the documents to summarize, or it was empty. Please"
|
||||
" specify a valid path."
|
||||
)
|
||||
os.makedirs(args.summaries_output_dir, exist_ok=True)
|
||||
|
||||
evaluate(args)
|
||||
|
||||
|
||||
def documents_dir_is_valid(path):
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
|
||||
file_list = os.listdir(path)
|
||||
if len(file_list) == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,98 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .utils_summarization import build_mask, compute_token_type_ids, process_story, truncate_or_pad
|
||||
|
||||
|
||||
class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.block_size = 10
|
||||
|
||||
def test_fit_to_block_sequence_too_small(self):
|
||||
"""Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
sequence = [1, 2, 3, 4]
|
||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_fit_exactly(self):
|
||||
"""Do nothing if the sequence is the right size."""
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_too_big(self):
|
||||
"""Truncate the sequence if it is too long."""
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
"""Processing a story with no highlights returns an empty list for the summary."""
|
||||
raw_story = """It was the year of Our Lord one thousand seven hundred and
|
||||
seventy-five.\n\nSpiritual revelations were conceded to England at that
|
||||
favoured period, as at this."""
|
||||
_, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_empty_story(self):
|
||||
"""An empty story returns an empty collection of lines."""
|
||||
raw_story = ""
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(story_lines, [])
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_story_with_missing_period(self):
|
||||
raw_story = (
|
||||
"It was the year of Our Lord one thousand seven hundred and "
|
||||
"seventy-five\n\nSpiritual revelations were conceded to England "
|
||||
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
|
||||
)
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
|
||||
expected_story_lines = [
|
||||
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
|
||||
"Spiritual revelations were conceded to England at that favoured period, as at this.",
|
||||
]
|
||||
self.assertEqual(expected_story_lines, story_lines)
|
||||
|
||||
expected_summary_lines = ["It was the best of times."]
|
||||
self.assertEqual(expected_summary_lines, summary_lines)
|
||||
|
||||
def test_build_mask_no_padding(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4])
|
||||
expected = torch.tensor([1, 1, 1, 1])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 23).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask_with_padding_equal_to_one(self):
|
||||
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
|
||||
|
||||
def test_compute_token_type_ids(self):
|
||||
separator = 101
|
||||
batch = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]])
|
||||
expected = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]])
|
||||
|
||||
result = compute_token_type_ids(batch, separator)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
@@ -1,167 +0,0 @@
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# ------------
|
||||
# Data loading
|
||||
# ------------
|
||||
|
||||
|
||||
class CNNDMDataset(Dataset):
|
||||
"""Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
The class will process the documents that are located in the specified
|
||||
folder. The preprocessing will work on any document that is reasonably
|
||||
formatted. On the CNN/DailyMail dataset it will extract both the story
|
||||
and the summary.
|
||||
|
||||
CNN/Daily News:
|
||||
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
|
||||
stored in different files; the summary appears at the end of the story as
|
||||
sentences that are prefixed by the special `@highlight` line. To process
|
||||
the data, untar both datasets in the same folder, and pass the path to this
|
||||
folder as the "data_dir argument. The formatting code was inspired by [2].
|
||||
|
||||
[1] https://cs.nyu.edu/~kcho/
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
|
||||
def __init__(self, path="", prefix="train"):
|
||||
"""We initialize the class by listing all the documents to summarize.
|
||||
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
|
||||
"""
|
||||
assert os.path.isdir(path)
|
||||
|
||||
self.documents = []
|
||||
story_filenames_list = os.listdir(path)
|
||||
for story_filename in story_filenames_list:
|
||||
if "summary" in story_filename:
|
||||
continue
|
||||
path_to_story = os.path.join(path, story_filename)
|
||||
if not os.path.isfile(path_to_story):
|
||||
continue
|
||||
self.documents.append(path_to_story)
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of documents."""
|
||||
return len(self.documents)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
document_path = self.documents[idx]
|
||||
document_name = document_path.split("/")[-1]
|
||||
with open(document_path, encoding="utf-8") as source:
|
||||
raw_story = source.read()
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
return document_name, story_lines, summary_lines
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
"""Extract the story and summary from a story file.
|
||||
|
||||
Arguments:
|
||||
raw_story (str): content of the story file as an utf-8 encoded string.
|
||||
|
||||
Raises:
|
||||
IndexError: If the story is empty or contains no highlights.
|
||||
"""
|
||||
nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
|
||||
|
||||
# for some unknown reason some lines miss a period, add it
|
||||
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
||||
|
||||
# gather article lines
|
||||
story_lines = []
|
||||
lines = deque(nonempty_lines)
|
||||
while True:
|
||||
try:
|
||||
element = lines.popleft()
|
||||
if element.startswith("@highlight"):
|
||||
break
|
||||
story_lines.append(element)
|
||||
except IndexError:
|
||||
# if "@highlight" is absent from the file we pop
|
||||
# all elements until there is None, raising an exception.
|
||||
return story_lines, []
|
||||
|
||||
# gather summary lines
|
||||
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
|
||||
|
||||
return story_lines, summary_lines
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u2019", ")"]
|
||||
if line.startswith("@highlight"):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + "."
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Encoding and preprocessing
|
||||
# --------------------------
|
||||
|
||||
|
||||
def truncate_or_pad(sequence, block_size, pad_token_id):
|
||||
"""Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter we append padding token to the right of the sequence.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
sequence.extend([pad_token_id] * (block_size - len(sequence)))
|
||||
return sequence
|
||||
|
||||
|
||||
def build_mask(sequence, pad_token_id):
|
||||
"""Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1."""
|
||||
mask = torch.ones_like(sequence)
|
||||
idx_pad_tokens = sequence == pad_token_id
|
||||
mask[idx_pad_tokens] = 0
|
||||
return mask
|
||||
|
||||
|
||||
def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
"""Encode the story and summary lines, and join them
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
||||
story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
|
||||
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
||||
summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]
|
||||
|
||||
return story_token_ids, summary_token_ids
|
||||
|
||||
|
||||
def compute_token_type_ids(batch, separator_token_id):
|
||||
"""Segment embeddings as described in [1]
|
||||
|
||||
The values {0,1} were found in the repository [2].
|
||||
|
||||
Attributes:
|
||||
batch: torch.Tensor, size [batch_size, block_size]
|
||||
Batch of input.
|
||||
separator_token_id: int
|
||||
The value of the token that separates the segments.
|
||||
|
||||
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
|
||||
arXiv preprint arXiv:1908.08345 (2019).
|
||||
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
|
||||
"""
|
||||
batch_embeddings = []
|
||||
for sequence in batch:
|
||||
sentence_num = -1
|
||||
embeddings = []
|
||||
for s in sequence:
|
||||
if s == separator_token_id:
|
||||
sentence_num += 1
|
||||
embeddings.append(sentence_num % 2)
|
||||
batch_embeddings.append(embeddings)
|
||||
return torch.tensor(batch_embeddings)
|
||||
@@ -1 +0,0 @@
|
||||
transformers == 4.38.0
|
||||
@@ -1,453 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2018 CMU and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Bertology: this script shows how you can explore the internals of the models in the library to:
|
||||
- compute the entropy of the head attentions
|
||||
- compute the importance of each head
|
||||
- prune (remove) the low importance head.
|
||||
Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, SequentialSampler, Subset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
GlueDataset,
|
||||
default_data_collator,
|
||||
glue_compute_metrics,
|
||||
glue_output_modes,
|
||||
glue_processors,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def entropy(p):
|
||||
"""Compute the entropy of a probability distribution"""
|
||||
plogp = p * torch.log(p)
|
||||
plogp[p == 0] = 0
|
||||
return -plogp.sum(dim=-1)
|
||||
|
||||
|
||||
def print_2d_tensor(tensor):
|
||||
"""Print a 2D tensor"""
|
||||
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
||||
for row in range(len(tensor)):
|
||||
if tensor.dtype != torch.long:
|
||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data))
|
||||
else:
|
||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))
|
||||
|
||||
|
||||
def compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
|
||||
):
|
||||
"""This method shows how to compute:
|
||||
- head attention entropy
|
||||
- head importance scores according to http://arxiv.org/abs/1905.10650
|
||||
"""
|
||||
# Prepare our tensors
|
||||
n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
||||
head_importance = torch.zeros(n_layers, n_heads).to(args.device)
|
||||
attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
|
||||
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(n_layers, n_heads).to(args.device)
|
||||
|
||||
head_mask.requires_grad_(requires_grad=True)
|
||||
# If actually pruned attention multi-head, set head mask to None to avoid shape mismatch
|
||||
if actually_pruned:
|
||||
head_mask = None
|
||||
|
||||
preds = None
|
||||
labels = None
|
||||
tot_tokens = 0.0
|
||||
|
||||
for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = v.to(args.device)
|
||||
|
||||
# Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
|
||||
outputs = model(**inputs, head_mask=head_mask)
|
||||
loss, logits, all_attentions = (
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
outputs[-1],
|
||||
) # Loss and logits are the first, attention the last
|
||||
loss.backward() # Backpropagate to populate the gradients in the head mask
|
||||
|
||||
if compute_entropy:
|
||||
for layer, attn in enumerate(all_attentions):
|
||||
masked_entropy = entropy(attn.detach()) * inputs["attention_mask"].float().unsqueeze(1)
|
||||
attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()
|
||||
|
||||
if compute_importance:
|
||||
head_importance += head_mask.grad.abs().detach()
|
||||
|
||||
# Also store our logits/labels if we want to compute metrics afterwards
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
labels = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
labels = np.append(labels, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
tot_tokens += inputs["attention_mask"].float().detach().sum().data
|
||||
|
||||
# Normalize
|
||||
attn_entropy /= tot_tokens
|
||||
head_importance /= tot_tokens
|
||||
# Layerwise importance normalization
|
||||
if not args.dont_normalize_importance_by_layer:
|
||||
exponent = 2
|
||||
norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
|
||||
head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
|
||||
|
||||
if not args.dont_normalize_global_importance:
|
||||
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
|
||||
|
||||
# Print/save matrices
|
||||
np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
|
||||
np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())
|
||||
|
||||
logger.info("Attention entropies")
|
||||
print_2d_tensor(attn_entropy)
|
||||
logger.info("Head importance scores")
|
||||
print_2d_tensor(head_importance)
|
||||
logger.info("Head ranked by importance scores")
|
||||
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
||||
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
|
||||
head_importance.numel(), device=args.device
|
||||
)
|
||||
head_ranks = head_ranks.view_as(head_importance)
|
||||
print_2d_tensor(head_ranks)
|
||||
|
||||
return attn_entropy, head_importance, preds, labels
|
||||
|
||||
|
||||
def mask_heads(args, model, eval_dataloader):
|
||||
"""This method shows how to mask head (set some heads to zero), to test the effect on the network,
|
||||
based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
original_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
|
||||
|
||||
new_head_mask = torch.ones_like(head_importance)
|
||||
num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))
|
||||
|
||||
current_score = original_score
|
||||
while current_score >= original_score * args.masking_threshold:
|
||||
head_mask = new_head_mask.clone() # save current head mask
|
||||
# heads from least important to most - keep only not-masked heads
|
||||
head_importance[head_mask == 0.0] = float("Inf")
|
||||
current_heads_to_mask = head_importance.view(-1).sort()[1]
|
||||
|
||||
if len(current_heads_to_mask) <= num_to_mask:
|
||||
break
|
||||
|
||||
# mask heads
|
||||
current_heads_to_mask = current_heads_to_mask[:num_to_mask]
|
||||
logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
|
||||
new_head_mask = new_head_mask.view(-1)
|
||||
new_head_mask[current_heads_to_mask] = 0.0
|
||||
new_head_mask = new_head_mask.view_as(head_mask)
|
||||
new_head_mask = new_head_mask.clone().detach()
|
||||
print_2d_tensor(new_head_mask)
|
||||
|
||||
# Compute metric and head importance again
|
||||
_, head_importance, preds, labels = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
logger.info(
|
||||
"Masking: current score: %f, remaining heads %d (%.1f percents)",
|
||||
current_score,
|
||||
new_head_mask.sum(),
|
||||
new_head_mask.sum() / new_head_mask.numel() * 100,
|
||||
)
|
||||
|
||||
logger.info("Final head mask")
|
||||
print_2d_tensor(head_mask)
|
||||
np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())
|
||||
|
||||
return head_mask
|
||||
|
||||
|
||||
def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
"""This method shows how to prune head (remove heads weights) based on
|
||||
the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
# Try pruning and test time speedup
|
||||
# Pruning is like masking but we actually remove the masked weights
|
||||
before_time = datetime.now()
|
||||
_, _, preds, labels = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
score_masking = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
original_time = datetime.now() - before_time
|
||||
|
||||
original_num_params = sum(p.numel() for p in model.parameters())
|
||||
heads_to_prune = {
|
||||
layer: (1 - head_mask[layer].long()).nonzero().squeeze().tolist() for layer in range(len(head_mask))
|
||||
}
|
||||
|
||||
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
||||
model.prune_heads(heads_to_prune)
|
||||
pruned_num_params = sum(p.numel() for p in model.parameters())
|
||||
|
||||
before_time = datetime.now()
|
||||
_, _, preds, labels = compute_heads_importance(
|
||||
args,
|
||||
model,
|
||||
eval_dataloader,
|
||||
compute_entropy=False,
|
||||
compute_importance=False,
|
||||
head_mask=None,
|
||||
actually_pruned=True,
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
new_time = datetime.now() - before_time
|
||||
|
||||
logger.info(
|
||||
"Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
|
||||
original_num_params,
|
||||
pruned_num_params,
|
||||
pruned_num_params / original_num_params * 100,
|
||||
)
|
||||
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
||||
logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time / new_time * 100)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(glue_processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained config name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dont_normalize_global_importance",
|
||||
action="store_true",
|
||||
help="Don't normalize all importance scores between 0 and 1",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_threshold",
|
||||
default=0.9,
|
||||
type=float,
|
||||
help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
|
||||
)
|
||||
parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")
|
||||
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, sequences shorter padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup devices and distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
else:
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
args.device = torch.device("cuda", args.local_rank)
|
||||
args.n_gpu = 1
|
||||
torch.distributed.init_process_group(backend="nccl") # Initializes the distributed backend
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Set seeds
|
||||
set_seed(args.seed)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in glue_processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = glue_processors[args.task_name]()
|
||||
args.output_mode = glue_output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
output_attentions=True,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
# Distributed and parallel training
|
||||
model.to(args.device)
|
||||
if args.local_rank != -1:
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
elif args.n_gpu > 1:
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Print/save training arguments
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Prepare dataset for the GLUE task
|
||||
eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
|
||||
if args.data_subset > 0:
|
||||
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
|
||||
)
|
||||
|
||||
# Compute head entropy and importance score
|
||||
compute_heads_importance(args, model, eval_dataloader)
|
||||
|
||||
# Try head masking (set heads to zero until the score goes under a threshole)
|
||||
# and head pruning (remove masked heads and see the effect on the network)
|
||||
if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
|
||||
head_mask = mask_heads(args, model, eval_dataloader)
|
||||
prune_heads(args, model, eval_dataloader, head_mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,391 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""This script is adapted from the Bertology pruning code (https://github.com/huggingface/transformers/blob/783d7d2629e97c5f0c5f9ef01b8c66410275c204/examples/research_projects/bertology/run_bertology.py)
|
||||
to prune GPT-like models. The author is @altsoph.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def save_model(model, dirpath):
|
||||
# save results
|
||||
if os.path.exists(dirpath):
|
||||
if os.path.exists(os.path.join(dirpath, "config.json")) and os.path.isfile(
|
||||
os.path.join(dirpath, "config.json")
|
||||
):
|
||||
os.remove(os.path.join(dirpath, "config.json"))
|
||||
if os.path.exists(os.path.join(dirpath, "pytorch_model.bin")) and os.path.isfile(
|
||||
os.path.join(dirpath, "pytorch_model.bin")
|
||||
):
|
||||
os.remove(os.path.join(dirpath, "pytorch_model.bin"))
|
||||
else:
|
||||
os.makedirs(dirpath)
|
||||
model.save_pretrained(dirpath)
|
||||
|
||||
|
||||
def entropy(p, unlogit=False):
|
||||
"""Compute the entropy of a probability distribution"""
|
||||
exponent = 2
|
||||
if unlogit:
|
||||
p = torch.pow(p, exponent)
|
||||
plogp = p * torch.log(p)
|
||||
plogp[p == 0] = 0
|
||||
return -plogp.sum(dim=-1)
|
||||
|
||||
|
||||
def print_2d_tensor(tensor):
|
||||
"""Print a 2D tensor"""
|
||||
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
||||
for row in range(len(tensor)):
|
||||
if tensor.dtype != torch.long:
|
||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data))
|
||||
else:
|
||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))
|
||||
|
||||
|
||||
def compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
|
||||
):
|
||||
"""This method shows how to compute:
|
||||
- head attention entropy
|
||||
- head importance scores according to http://arxiv.org/abs/1905.10650
|
||||
"""
|
||||
# Prepare our tensors
|
||||
n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
||||
head_importance = torch.zeros(n_layers, n_heads).to(args.device)
|
||||
attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
|
||||
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(n_layers, n_heads).to(args.device)
|
||||
|
||||
head_mask.requires_grad_(requires_grad=True)
|
||||
# If actually pruned attention multi-head, set head mask to None to avoid shape mismatch
|
||||
if actually_pruned:
|
||||
head_mask = None
|
||||
|
||||
tot_tokens = 0.0
|
||||
total_loss = 0.0
|
||||
for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||
inputs = tuple(t.to(args.device) for t in inputs)
|
||||
(input_ids,) = inputs
|
||||
|
||||
# Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
|
||||
outputs = model(input_ids, labels=input_ids, head_mask=head_mask)
|
||||
# (loss), lm_logits, presents, (all hidden_states), (attentions)
|
||||
loss, _, all_attentions = (
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
outputs[-1],
|
||||
) # Loss and logits are the first, attention the last
|
||||
loss.backward() # Backpropagate to populate the gradients in the head mask
|
||||
total_loss += loss.detach().cpu().numpy()
|
||||
if compute_entropy:
|
||||
for layer, attn in enumerate(all_attentions):
|
||||
masked_entropy = entropy(attn.detach(), True)
|
||||
attn_entropy[layer] += masked_entropy.sum(-1).sum(0).sum(0).detach()
|
||||
|
||||
if compute_importance:
|
||||
head_importance += head_mask.grad.abs().detach()
|
||||
tot_tokens += torch.ones_like(input_ids).float().detach().sum().data
|
||||
|
||||
# Normalize
|
||||
attn_entropy /= tot_tokens
|
||||
head_importance /= tot_tokens
|
||||
# Layerwise importance normalization
|
||||
if not args.dont_normalize_importance_by_layer:
|
||||
exponent = 2
|
||||
norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
|
||||
head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
|
||||
|
||||
if not args.dont_normalize_global_importance:
|
||||
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
|
||||
|
||||
# Print matrices
|
||||
if compute_entropy:
|
||||
logger.info("Attention entropies")
|
||||
print_2d_tensor(attn_entropy)
|
||||
if compute_importance:
|
||||
logger.info("Head importance scores")
|
||||
print_2d_tensor(head_importance)
|
||||
logger.info("Head ranked by importance scores")
|
||||
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
||||
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
|
||||
head_importance.numel(), device=args.device
|
||||
)
|
||||
head_ranks = head_ranks.view_as(head_importance)
|
||||
print_2d_tensor(head_ranks)
|
||||
return attn_entropy, head_importance, total_loss
|
||||
|
||||
|
||||
def mask_heads(args, model, eval_dataloader):
|
||||
"""This method shows how to mask head (set some heads to zero), to test the effect on the network,
|
||||
based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
_, head_importance, loss = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
|
||||
original_score = 1 / loss # instead of downsteam score use the LM loss
|
||||
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
|
||||
|
||||
new_head_mask = torch.ones_like(head_importance)
|
||||
num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))
|
||||
|
||||
current_score = original_score
|
||||
while current_score >= original_score * args.masking_threshold:
|
||||
head_mask = new_head_mask.clone().detach() # save current head mask
|
||||
# heads from least important to most - keep only not-masked heads
|
||||
head_importance[head_mask == 0.0] = float("Inf")
|
||||
current_heads_to_mask = head_importance.view(-1).sort()[1]
|
||||
|
||||
if len(current_heads_to_mask) <= num_to_mask:
|
||||
print("BREAK BY num_to_mask")
|
||||
break
|
||||
|
||||
# mask heads
|
||||
current_heads_to_mask = current_heads_to_mask[:num_to_mask]
|
||||
logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
|
||||
new_head_mask = new_head_mask.view(-1)
|
||||
new_head_mask[current_heads_to_mask] = 0.0
|
||||
new_head_mask = new_head_mask.view_as(head_mask)
|
||||
new_head_mask = new_head_mask.clone().detach()
|
||||
print_2d_tensor(new_head_mask)
|
||||
|
||||
# Compute metric and head importance again
|
||||
_, head_importance, loss = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
|
||||
)
|
||||
current_score = 1 / loss
|
||||
logger.info(
|
||||
"Masking: current score: %f, remaining heads %d (%.1f percents)",
|
||||
current_score,
|
||||
new_head_mask.sum(),
|
||||
new_head_mask.sum() / new_head_mask.numel() * 100,
|
||||
)
|
||||
|
||||
logger.info("Final head mask")
|
||||
print_2d_tensor(head_mask)
|
||||
np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())
|
||||
|
||||
return head_mask
|
||||
|
||||
|
||||
def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
"""This method shows how to prune head (remove heads weights) based on
|
||||
the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
# Try pruning and test time speedup
|
||||
# Pruning is like masking but we actually remove the masked weights
|
||||
before_time = datetime.now()
|
||||
_, _, loss = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
|
||||
)
|
||||
score_masking = 1 / loss
|
||||
original_time = datetime.now() - before_time
|
||||
|
||||
original_num_params = sum(p.numel() for p in model.parameters())
|
||||
heads_to_prune = {
|
||||
layer: (1 - head_mask[layer].long()).nonzero().squeeze().tolist() for layer in range(len(head_mask))
|
||||
}
|
||||
|
||||
for k, v in heads_to_prune.items():
|
||||
if isinstance(v, int):
|
||||
heads_to_prune[k] = [
|
||||
v,
|
||||
]
|
||||
|
||||
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
||||
model.prune_heads(heads_to_prune)
|
||||
pruned_num_params = sum(p.numel() for p in model.parameters())
|
||||
|
||||
before_time = datetime.now()
|
||||
_, _, loss = compute_heads_importance(
|
||||
args,
|
||||
model,
|
||||
eval_dataloader,
|
||||
compute_entropy=False,
|
||||
compute_importance=False,
|
||||
head_mask=None,
|
||||
actually_pruned=True,
|
||||
)
|
||||
|
||||
score_pruning = 1 / loss
|
||||
new_time = datetime.now() - before_time
|
||||
|
||||
logger.info(
|
||||
"Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
|
||||
original_num_params,
|
||||
pruned_num_params,
|
||||
pruned_num_params / original_num_params * 100,
|
||||
)
|
||||
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
||||
logger.info("Pruning: speed ratio (original timing / new timing): %f percents", original_time / new_time * 100)
|
||||
save_model(model, args.output_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained config name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dont_normalize_global_importance",
|
||||
action="store_true",
|
||||
help="Don't normalize all importance scores between 0 and 1",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_threshold",
|
||||
default=0.9,
|
||||
type=float,
|
||||
help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
|
||||
)
|
||||
parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")
|
||||
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, sequences shorter padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup devices and distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
else:
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
args.device = torch.device("cuda", args.local_rank)
|
||||
args.n_gpu = 1
|
||||
torch.distributed.init_process_group(backend="nccl") # Initializes the distributed backend
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
||||
|
||||
# Distributed and parallel training
|
||||
model.to(args.device)
|
||||
if args.local_rank != -1:
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
elif args.n_gpu > 1:
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Print/save training arguments
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Prepare dataset
|
||||
numpy_data = np.concatenate(
|
||||
[
|
||||
np.loadtxt(args.data_dir, dtype=np.int64),
|
||||
]
|
||||
)
|
||||
train_tensor_dataset = (torch.from_numpy(numpy_data),)
|
||||
train_data = TensorDataset(*train_tensor_dataset)
|
||||
train_sampler = RandomSampler(train_data)
|
||||
eval_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size)
|
||||
|
||||
# Compute head entropy and importance score
|
||||
compute_heads_importance(args, model, eval_dataloader)
|
||||
|
||||
# Try head masking (set heads to zero until the score goes under a threshole)
|
||||
# and head pruning (remove masked heads and see the effect on the network)
|
||||
if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
|
||||
head_mask = mask_heads(args, model, eval_dataloader)
|
||||
prune_heads(args, model, eval_dataloader, head_mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,316 +0,0 @@
|
||||
# CodeParrot 🦜
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/lvwerra/repo-images/raw/main/code-highlighting-streamlit.png" alt="drawing" width="350"/>
|
||||
</p>
|
||||
|
||||
## What is this about?
|
||||
This is an open-source effort to train and evaluate code generation models. CodeParrot 🦜 is a GPT-2 model trained from scratch on Python code. The highlights of this project are:
|
||||
- initialize and train a GPT-2 language model from scratch for code generation
|
||||
- train a custom tokenizer adapted for Python code
|
||||
- clean and deduplicate a large (>100GB) dataset with `datasets`
|
||||
- train with `accelerate` on multiple GPUs using data parallelism and mixed precision
|
||||
- continuously push checkpoints to the hub with `huggingface_hub`
|
||||
- stream the dataset with `datasets` during training to avoid disk bottlenecks
|
||||
- apply the `code_eval` metric in `datasets` to evaluate on [OpenAI's _HumanEval_ benchmark](https://huggingface.co/datasets/openai_humaneval)
|
||||
- showcase examples for downstream tasks with code models in [examples](https://github.com/huggingface/transformers/tree/main/examples/research_projects/codeparrot/examples) folder:
|
||||
- Algorithmic complexity prediction
|
||||
- Code generation from english text
|
||||
- Code explanation
|
||||
|
||||
## Installation
|
||||
To install the dependencies simply run the following command:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
To reproduce the results you can follow the scripts in the following sections. Note that we don't always show all possible arguments to the scripts. To get the full list of arguments with descriptions you can run the following command on any script:
|
||||
|
||||
```bash
|
||||
python scripts/some_script.py --help
|
||||
```
|
||||
|
||||
Before you run any of the scripts make sure you are logged in and can push to the hub:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
Additionally, sure you have git-lfs installed. You can find instructions for how to install it [here](https://git-lfs.github.com/).
|
||||
|
||||
## Dataset
|
||||
The source of the dataset is the GitHub dump available on Google's [BigQuery](https://cloud.google.com/blog/topics/public-datasets/github-on-bigquery-analyze-all-the-open-source-code). The database was queried for all Python files with less than 1MB in size resulting in a 180GB dataset with over 20M files. The dataset is available on the Hugging Face Hub [here](https://huggingface.co/datasets/transformersbook/codeparrot).
|
||||
|
||||
### Preprocessing
|
||||
The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374) and some new ones:
|
||||
|
||||
- exact deduplication using each file's hash after having removed whistespaces.
|
||||
- near deduplication using MinHash and Jaccard similarity. MinHash with a Jaccard threshold (default=0.85) is first used to create duplicate clusters. Then these clusters are then reduced to unique files based on the exact Jaccard similarity. See `deduplicate_dataset` in `minhash_deduplication.py` for a detailed description.
|
||||
- filtering files with max line length > 1000
|
||||
- filtering files with mean line length > 100
|
||||
- fraction of alphanumeric characters < 0.25
|
||||
- containing the word "auto-generated" or similar in the first 5 lines
|
||||
- filtering with a probability of 0.7 of files with a mention of "test file" or "configuration file" or similar in the first 5 lines
|
||||
- filtering with a probability of 0.7 of files with high occurrence of the keywords "test " or "config"
|
||||
- filtering with a probability of 0.7 of files without a mention of the keywords `def` , `for`, `while` and `class`
|
||||
- filtering files that use the assignment operator `=` less than 5 times
|
||||
- filtering files with ratio between number of characters and number of tokens after tokenization < 1.5 (the average ratio is 3.6)
|
||||
|
||||
The script to process the full dataset can be found in `scripts/preprocessing.py`. Executing the script on 16 vCPUs takes roughly 3h and removes 70% of the original dataset. The cleaned [train](https://huggingface.co/datasets/codeparrot/codeparrot-clean-train-v2) and [validation](https://huggingface.co/datasets/codeparrot/codeparrot-clean-valid-v2) splits are also available on the Hub if you want to skip this step or use the data for another project.
|
||||
|
||||
To execute the preprocessing run the following command:
|
||||
```bash
|
||||
python scripts/preprocessing.py \
|
||||
--dataset_name transformersbook/codeparrot \
|
||||
--output_dir codeparrot-clean
|
||||
```
|
||||
During preprocessing the dataset is downloaded and stored locally as well as caches of the computations. Make sure you have more than 500GB free disk space to execute it.
|
||||
|
||||
### Pretokenization
|
||||
The tokenization of the data might be slow during the training especially for small models. We provide code to pretokenize the data beforehand in `scripts/pretokenizing.py`, but this step is optional. The dataset is downloaded and stored locally and the tokenized data is pushed to the hub. The tokenized clean [train](https://huggingface.co/datasets/codeparrot/tokenized-codeparrot-train) and [validation](https://huggingface.co/datasets/codeparrot/tokenized-codeparrot-valid) datasets are available if you want to use them directly.
|
||||
|
||||
To execute the pretokenization, for the clean train data for instance, run the following command:
|
||||
```bash
|
||||
python scripts/pretokenizing.py \
|
||||
--dataset_name codeparrot/codeparrot-clean-train \
|
||||
--tokenized_data_repo tokenized-codeparrot-train
|
||||
```
|
||||
|
||||
## Tokenizer
|
||||
Before training a new model for code we create a new tokenizer that is efficient at code tokenization. To train the tokenizer you can run the following command:
|
||||
```bash
|
||||
python scripts/bpe_training.py \
|
||||
--base_tokenizer openai-community/gpt2 \
|
||||
--dataset_name codeparrot/codeparrot-clean-train
|
||||
```
|
||||
|
||||
_Note:_ We originally trained the tokenizer on the unprocessed train split of the dataset `transformersbook/codeparrot-train`.
|
||||
|
||||
## Training
|
||||
The models are randomly initialized and trained from scratch. To initialize a new model you can run:
|
||||
|
||||
```bash
|
||||
python scripts/initialize_model.py \
|
||||
--config_name openai-community/gpt2-large \
|
||||
--tokenizer_name codeparrot/codeparrot \
|
||||
--model_name codeparrot \
|
||||
--push_to_hub True
|
||||
```
|
||||
This will initialize a new model with the architecture and configuration of `openai-community/gpt2-large` and use the tokenizer to appropriately size the input embeddings. Finally, the initilaized model is pushed the hub.
|
||||
|
||||
We can either pass the name of a text dataset or a pretokenized dataset which speeds up training a bit.
|
||||
Now that the tokenizer and model are also ready we can start training the model. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/codeparrot/codeparrot-small/) and [1.5B](https://huggingface.co/codeparrot/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine which takes 1 day and 1 week, respectively.
|
||||
|
||||
First you need to configure `accelerate` and login to Weights & Biases:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
wandb login
|
||||
```
|
||||
|
||||
Note that during the `accelerate` configuration we enabled FP16. Then to train the large model you can run
|
||||
|
||||
```bash
|
||||
accelerate launch scripts/codeparrot_training.py
|
||||
```
|
||||
|
||||
If you want to train the small model you need to make some modifications:
|
||||
|
||||
```bash
|
||||
accelerate launch scripts/codeparrot_training.py \
|
||||
--model_ckpt codeparrot/codeparrot-small \
|
||||
--train_batch_size 12 \
|
||||
--valid_batch_size 12 \
|
||||
--learning_rate 5e-4 \
|
||||
--num_warmup_steps 2000 \
|
||||
--gradient_accumulation 1 \
|
||||
--gradient_checkpointing False \
|
||||
--max_train_steps 150000 \
|
||||
--save_checkpoint_steps 15000
|
||||
```
|
||||
|
||||
Recall that you can see the full set of possible options with descriptions (for all scripts) by running:
|
||||
|
||||
```bash
|
||||
python scripts/codeparrot_training.py --help
|
||||
```
|
||||
|
||||
Instead of streaming the dataset from the hub you can also stream it from disk. This can be helpful for long training runs where the connection can be interrupted sometimes. To stream locally you simply need to clone the datasets and replace the dataset name with their path. In this example we store the data in a folder called `data`:
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
mkdir data
|
||||
git -C "./data" clone https://huggingface.co/datasets/codeparrot/codeparrot-clean-train
|
||||
git -C "./data" clone https://huggingface.co/datasets/codeparrot/codeparrot-clean-valid
|
||||
```
|
||||
|
||||
And then pass the paths to the datasets when we run the training script:
|
||||
|
||||
```bash
|
||||
accelerate launch scripts/codeparrot_training.py \
|
||||
--model_ckpt codeparrot/codeparrot-small \
|
||||
--dataset_name_train ./data/codeparrot-clean-train \
|
||||
--dataset_name_valid ./data/codeparrot-clean-valid \
|
||||
--train_batch_size 12 \
|
||||
--valid_batch_size 12 \
|
||||
--learning_rate 5e-4 \
|
||||
--num_warmup_steps 2000 \
|
||||
--gradient_accumulation 1 \
|
||||
--gradient_checkpointing False \
|
||||
--max_train_steps 150000 \
|
||||
--save_checkpoint_steps 15000
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
For evaluating the language modeling loss on the validation set or any other dataset you can use the following command:
|
||||
```bash
|
||||
python scripts/validation_loss.py \
|
||||
--model_ckpt codeparrot/codeparrot \
|
||||
--dataset_name codeparrot/codeparrot-clean-valid
|
||||
```
|
||||
In addition we evaluate the model on OpenAI's _HumanEval_ benchmark. You can run the evaluation with the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch scripts/human_eval.py --model_ckpt codeparrot/codeparrot \
|
||||
--do_sample True \
|
||||
--temperature 0.2 \
|
||||
--top_p 0.95 \
|
||||
--n_samples=200 \
|
||||
--HF_ALLOW_CODE_EVAL="0"
|
||||
```
|
||||
|
||||
The results as well as reference values are shown in the following table:
|
||||
|
||||
| Model | pass@1 | pass@10 | pass@100|
|
||||
|-------|--------|---------|---------|
|
||||
|CodeParrot 🦜 (110M) | 3.80% | 6.57% | 12.78% |
|
||||
|CodeParrot 🦜 (1.5B) | 3.99% | 8.69% | 17.88% |
|
||||
|||||
|
||||
|Codex (25M)| 3.21% | 7.1% | 12.89%|
|
||||
|Codex (85M)| 8.22% | 12.81% | 22.40% |
|
||||
|Codex (300M)| 13.17%| 20.37% | 36.27% |
|
||||
|Codex (12B)| 28.81%| 46.81% | 72.31% |
|
||||
|||||
|
||||
|GPT-neo (125M)| 0.75% | 1.88% | 2.97% |
|
||||
|GPT-neo (1.5B)| 4.79% | 7.47% | 16.30% |
|
||||
|GPT-neo (2.7B)| 6.41% | 11.27% | 21.37% |
|
||||
|GPT-J (6B)| 11.62% | 15.74% | 27.74% |
|
||||
|
||||
The numbers were obtained by sampling with `T = [0.2, 0.6, 0.8]` and picking the best value for each metric. Both CodeParrot 🦜 models are still underfitted and longer training would likely improve the performance.
|
||||
|
||||
## Demo
|
||||
Give the model a shot yourself! There are three demos to interact with CodeParrot 🦜:
|
||||
- [Code generation](https://huggingface.co/spaces/codeparrot/codeparrot-generation)
|
||||
- [Code highlighting](https://huggingface.co/spaces/codeparrot/codeparrot-highlighting)
|
||||
- [Comparison to other code models](https://huggingface.co/spaces/codeparrot/loubnabnl/code-generation-models)
|
||||
|
||||
## Training with Megatron
|
||||
[Megatron](https://github.com/NVIDIA/Megatron-LM) is a framework developed by NVIDIA for training large transformer models. While the CodeParrot code is easy to follow and modify to your needs the Megatron framework lets you train models faster. Below we explain how to use it.
|
||||
|
||||
### Setup
|
||||
You can pull an NVIDIA PyTorch Container that comes with all the required installations from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). See [documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for more details:
|
||||
|
||||
With the following Docker command you can run the container (`xx.xx` denotes your Docker version), and clone [Megatron repository](https://github.com/NVIDIA/Megatron-LM) into it:
|
||||
```bash
|
||||
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:xx.xx-py3
|
||||
git clone https://github.com/NVIDIA/Megatron-LM
|
||||
```
|
||||
|
||||
You also need to add the vocabulary file and merges table of the tokenizer that you trained on code into the container. You can also find these files in [vocab.json](https://huggingface.co/codeparrot/codeparrot/raw/main/vocab.json) and [merges.txt](https://huggingface.co/codeparrot/codeparrot/raw/main/merges.txt).
|
||||
```bash
|
||||
sudo docker cp vocab.json CONTAINER_ID:/workspace/Megatron-LM
|
||||
sudo docker cp merges.txt CONTAINER_ID:/workspace/Megatron-LM
|
||||
```
|
||||
|
||||
### Data preprocessing
|
||||
The training data requires preprocessing. First, you need to convert it into a loose json format, with one json containing a text sample per line. In python this can be done this way:
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
train_data = load_dataset('codeparrot/codeparrot-clean-train', split='train')
|
||||
train_data.to_json("codeparrot_data.json", lines=True)
|
||||
```
|
||||
|
||||
The data is then tokenized, shuffled and processed into a binary format for training using the following command:
|
||||
```bash
|
||||
pip install nltk
|
||||
cd Megatron-LM
|
||||
python tools/preprocess_data.py \
|
||||
--input codeparrot_data.json \
|
||||
--output-prefix codeparrot \
|
||||
--vocab vocab.json \
|
||||
--dataset-impl mmap \
|
||||
--tokenizer-type GPT2BPETokenizer \
|
||||
--merge-file merges.txt \
|
||||
--json-keys content \
|
||||
--workers 32 \
|
||||
--chunk-size 25 \
|
||||
--append-eod
|
||||
```
|
||||
This outputs two files `codeparrot_content_document.idx` and `codeparrot_content_document.bin` which are used in the training.
|
||||
|
||||
### Training
|
||||
You can configure the model architecture and training parameters as shown below, or put it in a bash script that you will run. This runs on 8 GPUs the 110M parameter CodeParrot pretraining, with the same settings as before. Note that the data is partitioned by default into a 969:30:1 ratio for training/validation/test sets.
|
||||
```bash
|
||||
GPUS_PER_NODE=8
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
CHECKPOINT_PATH=/workspace/Megatron-LM/experiments/codeparrot-small
|
||||
VOCAB_FILE=vocab.json
|
||||
MERGE_FILE=merges.txt
|
||||
DATA_PATH=codeparrot_content_document
|
||||
GPT_ARGS="--num-layers 12
|
||||
--hidden-size 768
|
||||
--num-attention-heads 12
|
||||
--seq-length 1024
|
||||
--max-position-embeddings 1024
|
||||
--micro-batch-size 12
|
||||
--global-batch-size 192
|
||||
--lr 0.0005
|
||||
--train-iters 150000
|
||||
--lr-decay-iters 150000
|
||||
--lr-decay-style cosine
|
||||
--lr-warmup-iters 2000
|
||||
--weight-decay .1
|
||||
--adam-beta2 .999
|
||||
--fp16
|
||||
--log-interval 10
|
||||
--save-interval 2000
|
||||
--eval-interval 200
|
||||
--eval-iters 10
|
||||
"
|
||||
TENSORBOARD_ARGS="--tensorboard-dir experiments/tensorboard"
|
||||
python3 -m torch.distributed.launch $DISTRIBUTED_ARGS \
|
||||
pretrain_gpt.py \
|
||||
--tensor-model-parallel-size 1 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
$GPT_ARGS \
|
||||
--vocab-file $VOCAB_FILE \
|
||||
--merge-file $MERGE_FILE \
|
||||
--save $CHECKPOINT_PATH \
|
||||
--load $CHECKPOINT_PATH \
|
||||
--data-path $DATA_PATH \
|
||||
$TENSORBOARD_ARGS
|
||||
```
|
||||
The training takes almost 12 hours in this setting.
|
||||
|
||||
### Convert model to `transformers`
|
||||
After training we want to use the model in `transformers` e.g. to evaluate it on HumanEval. You can convert it to `transformers` following [this](https://huggingface.co/nvidia/megatron-gpt2-345m) tutorial. For instance, after the training is finished you can copy the weights of the last iteration 150k and convert the `model_optim_rng.pt` file to a `pytorch_model.bin` file that is supported by `transformers`.
|
||||
|
||||
```bash
|
||||
mkdir -p nvidia/megatron-codeparrot-small
|
||||
sudo docker cp CONTAINER_ID:/workspace/Megatron-LM/experiments/codeparrot-small/iter_0150000/mp_rank_00/model_optim_rng.pt nvidia/megatron-codeparrot-small
|
||||
git clone https://github.com/huggingface/transformers.git
|
||||
git clone https://github.com/NVIDIA/Megatron-LM.git
|
||||
export PYTHONPATH=Megatron-LM
|
||||
python transformers/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py nvidia/megatron-codeparrot-small/model_optim_rng.pt
|
||||
```
|
||||
Be careful, you will need to replace the generated vocabulary file and merges table after the conversion, with the original ones if you plan to load the tokenizer from there.
|
||||
|
||||
## Further Resources
|
||||
A detailed description of the project can be found in the chapter "Training Transformers from Scratch" in the upcoming O'Reilly book [Natural Language Processing with Transformers](https://learning.oreilly.com/library/view/natural-language-processing/9781098103231/).
|
||||
|
||||
This example was provided by [Leandro von Werra](www.github.com/lvwerra).
|
||||
@@ -1,58 +0,0 @@
|
||||
# Examples
|
||||
In this folder we showcase some examples to use code models for downstream tasks.
|
||||
|
||||
## Complexity prediction
|
||||
In this task we want to predict the complexity of Java programs in [CodeComplex](https://huggingface.co/datasets/codeparrot/codecomplex) dataset. Using Hugging Face `trainer`, we finetuned [multilingual CodeParrot](https://huggingface.co/codeparrot/codeparrot-small-multi) and [UniXcoder](https://huggingface.co/microsoft/unixcoder-base-nine) on it, and we used the latter to build this Java complexity prediction [space](https://huggingface.co/spaces/codeparrot/code-complexity-predictor) on Hugging Face hub.
|
||||
|
||||
To fine-tune a model on this dataset you can use the following commands:
|
||||
|
||||
```python
|
||||
python train_complexity_predictor.py \
|
||||
--model_ckpt microsoft/unixcoder-base-nine \
|
||||
--num_epochs 60 \
|
||||
--num_warmup_steps 10 \
|
||||
--batch_size 8 \
|
||||
--learning_rate 5e-4
|
||||
```
|
||||
|
||||
## Code generation: text to python
|
||||
In this task we want to train a model to generate code from english text. We finetuned Codeparrot-small on [github-jupyter-text-to-code](https://huggingface.co/datasets/codeparrot/github-jupyter-text-to-code), a dataset where the samples are a succession of docstrings and their Python code, originally extracted from Jupyter notebooks parsed in this [dataset](https://huggingface.co/datasets/codeparrot/github-jupyter-parsed).
|
||||
|
||||
To fine-tune a model on this dataset we use the same [script](https://github.com/huggingface/transformers/blob/main/examples/research_projects/codeparrot/scripts/codeparrot_training.py) as the pretraining of codeparrot:
|
||||
|
||||
```python
|
||||
accelerate launch scripts/codeparrot_training.py \
|
||||
--model_ckpt codeparrot/codeparrot-small \
|
||||
--dataset_name_train codeparrot/github-jupyter-text-to-code \
|
||||
--dataset_name_valid codeparrot/github-jupyter-text-to-code \
|
||||
--train_batch_size 12 \
|
||||
--valid_batch_size 12 \
|
||||
--learning_rate 5e-4 \
|
||||
--num_warmup_steps 100 \
|
||||
--gradient_accumulation 1 \
|
||||
--gradient_checkpointing False \
|
||||
--max_train_steps 3000 \
|
||||
--save_checkpoint_steps 200 \
|
||||
--save_dir jupyter-text-to-python
|
||||
```
|
||||
|
||||
## Code explanation: python to text
|
||||
In this task we want to train a model to explain python code. We finetuned Codeparrot-small on [github-jupyter-code-to-text](https://huggingface.co/datasets/codeparrot/github-jupyter-code-to-text), a dataset where the samples are a succession of Python code and its explanation as a docstring, we just inverted the order of text and code pairs in github-jupyter-code-to-text dataset and added the delimiters "Explanation:" and "End of explanation" inside the doctrings.
|
||||
|
||||
To fine-tune a model on this dataset we use the same [script](https://github.com/huggingface/transformers/blob/main/examples/research_projects/codeparrot/scripts/codeparrot_training.py) as the pretraining of codeparrot:
|
||||
|
||||
```python
|
||||
accelerate launch scripts/codeparrot_training.py \
|
||||
--model_ckpt codeparrot/codeparrot-small \
|
||||
--dataset_name_train codeparrot/github-jupyter-code-to-text \
|
||||
--dataset_name_valid codeparrot/github-jupyter-code-to-text \
|
||||
--train_batch_size 12 \
|
||||
--valid_batch_size 12 \
|
||||
--learning_rate 5e-4 \
|
||||
--num_warmup_steps 100 \
|
||||
--gradient_accumulation 1 \
|
||||
--gradient_checkpointing False \
|
||||
--max_train_steps 3000 \
|
||||
--save_checkpoint_steps 200 \
|
||||
--save_dir jupyter-python-to-text
|
||||
```
|
||||
@@ -1,5 +0,0 @@
|
||||
datasets==2.3.2
|
||||
transformers==4.48.0
|
||||
wandb==0.13.1
|
||||
evaluate==0.2.2
|
||||
scikit-learn==1.5.0
|
||||
@@ -1,132 +0,0 @@
|
||||
import argparse
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from datasets import ClassLabel, DatasetDict, load_dataset
|
||||
from evaluate import load
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_ckpt", type=str, default="microsoft/unixcoder-base-nine")
|
||||
parser.add_argument("--num_epochs", type=int, default=5)
|
||||
parser.add_argument("--batch_size", type=int, default=6)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--freeze", type=bool, default=True)
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-4)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
|
||||
parser.add_argument("--num_warmup_steps", type=int, default=10)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01)
|
||||
parser.add_argument("--output_dir", type=str, default="./results")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
metric = load("accuracy")
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, labels = eval_pred
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
return metric.compute(predictions=predictions, references=labels)
|
||||
|
||||
|
||||
class CustomCallback(TrainerCallback):
|
||||
def __init__(self, trainer) -> None:
|
||||
super().__init__()
|
||||
self._trainer = trainer
|
||||
|
||||
def on_epoch_end(self, args, state, control, **kwargs):
|
||||
if control.should_evaluate:
|
||||
control_copy = deepcopy(control)
|
||||
self._trainer.evaluate(eval_dataset=self._trainer.train_dataset, metric_key_prefix="train")
|
||||
return control_copy
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
set_seed(args.seed)
|
||||
|
||||
dataset = load_dataset("codeparrot/codecomplex", split="train")
|
||||
train_test = dataset.train_test_split(test_size=0.2)
|
||||
test_validation = train_test["test"].train_test_split(test_size=0.5)
|
||||
train_test_validation = DatasetDict(
|
||||
{
|
||||
"train": train_test["train"],
|
||||
"test": test_validation["train"],
|
||||
"valid": test_validation["test"],
|
||||
}
|
||||
)
|
||||
|
||||
print("Loading tokenizer and model")
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = AutoModelForSequenceClassification.from_pretrained(args.model_ckpt, num_labels=7)
|
||||
model.config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
if args.freeze:
|
||||
for param in model.roberta.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
labels = ClassLabel(num_classes=7, names=list(set(train_test_validation["train"]["complexity"])))
|
||||
|
||||
def tokenize(example):
|
||||
inputs = tokenizer(example["src"], truncation=True, max_length=1024)
|
||||
label = labels.str2int(example["complexity"])
|
||||
return {
|
||||
"input_ids": inputs["input_ids"],
|
||||
"attention_mask": inputs["attention_mask"],
|
||||
"label": label,
|
||||
}
|
||||
|
||||
tokenized_datasets = train_test_validation.map(
|
||||
tokenize,
|
||||
batched=True,
|
||||
remove_columns=train_test_validation["train"].column_names,
|
||||
)
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=args.output_dir,
|
||||
learning_rate=args.learning_rate,
|
||||
lr_scheduler_type=args.lr_scheduler_type,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
logging_strategy="epoch",
|
||||
per_device_train_batch_size=args.batch_size,
|
||||
per_device_eval_batch_size=args.batch_size,
|
||||
num_train_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
weight_decay=0.01,
|
||||
metric_for_best_model="accuracy",
|
||||
run_name="complexity-java",
|
||||
report_to="wandb",
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["valid"],
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
print("Training...")
|
||||
trainer.add_callback(CustomCallback(trainer))
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,9 +0,0 @@
|
||||
transformers==4.38.0
|
||||
datasets==1.16.0
|
||||
wandb==0.12.0
|
||||
tensorboard==2.6.0
|
||||
torch==2.2.0
|
||||
huggingface-hub==0.1.0
|
||||
git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe
|
||||
datasketch==1.5.7
|
||||
dpu_utils
|
||||
@@ -1,220 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
"""
|
||||
Configuration for training model.
|
||||
"""
|
||||
|
||||
model_ckpt: Optional[str] = field(
|
||||
default="codeparrot/codeparrot", metadata={"help": "Model name or path of model to be trained."}
|
||||
)
|
||||
save_dir: Optional[str] = field(
|
||||
default="./", metadata={"help": "Save dir where model repo is cloned and models updates are saved to."}
|
||||
)
|
||||
dataset_name_train: Optional[str] = field(
|
||||
default="codeparrot/codeparrot-clean-train", metadata={"help": "Name or path of training dataset."}
|
||||
)
|
||||
dataset_name_valid: Optional[str] = field(
|
||||
default="codeparrot/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
|
||||
)
|
||||
train_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for training."})
|
||||
valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."})
|
||||
weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."})
|
||||
shuffle_buffer: Optional[int] = field(
|
||||
default=10000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
|
||||
)
|
||||
learning_rate: Optional[float] = field(default=2e-4, metadata={"help": "Learning rate fo training."})
|
||||
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "Learning rate."})
|
||||
num_warmup_steps: Optional[int] = field(
|
||||
default=750, metadata={"help": "Number of warmup steps in the learning rate schedule."}
|
||||
)
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=16, metadata={"help": "Number of gradient accumulation steps."}
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Use gradient checkpointing to reduce memory footprint."}
|
||||
)
|
||||
max_train_steps: Optional[int] = field(default=50000, metadata={"help": "Maximum number of training steps."})
|
||||
max_eval_steps: Optional[int] = field(
|
||||
default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
|
||||
)
|
||||
seq_length: Optional[int] = field(default=1024, metadata={"help": "Sequence lengths used for training."})
|
||||
seed: Optional[int] = field(default=1, metadata={"help": "Training seed."})
|
||||
save_checkpoint_steps: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
|
||||
)
|
||||
resume_from_checkpoint: Optional[str] = field(
|
||||
default=None, metadata={"help": "States path if the training should continue from a checkpoint folder."}
|
||||
)
|
||||
tokenized: Optional[bool] = field(default=False, metadata={"help": "If True the data is pretokenized."})
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationArguments:
|
||||
"""
|
||||
Configuration for evaluating model.
|
||||
"""
|
||||
|
||||
model_ckpt: Optional[str] = field(
|
||||
default="codeparrot/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
|
||||
)
|
||||
dataset_name: Optional[str] = field(
|
||||
default="codeparrot/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
|
||||
)
|
||||
batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size used for evaluation."})
|
||||
max_eval_steps: Optional[int] = field(
|
||||
default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
|
||||
)
|
||||
seq_length: Optional[int] = field(default=1024, metadata={"help": "Length of sequences to be evaluated."})
|
||||
seed: Optional[int] = field(default=1, metadata={"help": "Random seed used for evaluation."})
|
||||
|
||||
|
||||
@dataclass
|
||||
class HumanEvalArguments:
|
||||
"""
|
||||
Configuration for running evaluation on HumanEval dataset.
|
||||
"""
|
||||
|
||||
model_ckpt: Optional[str] = field(
|
||||
default="codeparrot/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
|
||||
)
|
||||
num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
|
||||
num_tasks: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of human-eval tasks to run. If not included all tasks are evaluated."},
|
||||
)
|
||||
do_sample: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Sample from the language model's output distribution."}
|
||||
)
|
||||
temperature: Optional[float] = field(default=0.2, metadata={"help": "Sampling temperature used for generation."})
|
||||
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "Maximum number of newly generated tokens."})
|
||||
top_k: Optional[int] = field(default=0, metadata={"help": "Top-k parameter used for generation."})
|
||||
top_p: Optional[float] = field(default=0.95, metadata={"help": "Top-p parameter used for nucleus sampling."})
|
||||
batch_size: Optional[int] = field(default=10, metadata={"help": "Number of generations to run in parallel."})
|
||||
n_samples: Optional[int] = field(
|
||||
default=200, metadata={"help": "Number of completions to generate for each sample."}
|
||||
)
|
||||
seed: Optional[int] = field(default=1, metadata={"help": "Random seed used for evaluation."})
|
||||
output_file: Optional[str] = field(
|
||||
default="eval_results.json", metadata={"help": "Random seed used for evaluation."}
|
||||
)
|
||||
HF_ALLOW_CODE_EVAL: Optional[str] = field(
|
||||
default="0", metadata={"help": "Allow `code_eval` to execute Python code on machine"}
|
||||
)
|
||||
device_int: Optional[int] = field(
|
||||
default=-1,
|
||||
metadata={
|
||||
"help": (
|
||||
"Determine which device to run the `text-generation` Pipeline on. -1 is CPU and any zero or positive"
|
||||
" number corresponds to which GPU device id to run on."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessingArguments:
|
||||
"""
|
||||
Configuration for preprocessing data.
|
||||
"""
|
||||
|
||||
num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The number of CPU cores to use for parallel preprocessing. Default uses the maximum available."
|
||||
},
|
||||
)
|
||||
dataset_name: Optional[str] = field(
|
||||
default="transformersbook/codeparrot", metadata={"help": "Folder or name of dataset to process."}
|
||||
)
|
||||
output_dir: Optional[str] = field(
|
||||
default="codeparrot-clean", metadata={"help": "Folder to save processed dataset."}
|
||||
)
|
||||
samples_per_file: Optional[int] = field(
|
||||
default=100_000, metadata={"help": "Number of files to save per JSON output file."}
|
||||
)
|
||||
text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
|
||||
line_max: Optional[float] = field(
|
||||
default=1000, metadata={"help": "Maximum line length in file, otherwise file is filtered."}
|
||||
)
|
||||
line_mean: Optional[float] = field(
|
||||
default=100, metadata={"help": "Maximum mean line length in file, otherwise file is filtered."}
|
||||
)
|
||||
alpha_frac: Optional[float] = field(
|
||||
default=0.25, metadata={"help": "Maximum fraction of non-alphanumeric characters, otherwise file is filtered."}
|
||||
)
|
||||
min_token_ratio: Optional[float] = field(
|
||||
default=1.5, metadata={"help": "Minimum character token ratio for the file, otherwise file is filtered."}
|
||||
)
|
||||
filter_proba: Optional[float] = field(
|
||||
default=0.7, metadata={"help": "Probability for filtering config, test and uncommon files."}
|
||||
)
|
||||
tokenizer: Optional[str] = field(
|
||||
default="codeparrot/codeparrot",
|
||||
metadata={"help": "Name or path to the tokenizer."},
|
||||
)
|
||||
near_deduplication: Optional[bool] = field(
|
||||
default=False, metadata={"help": "If True, near-duplicate samples are removed."}
|
||||
)
|
||||
jaccard_threshold: Optional[float] = field(
|
||||
default=0.85, metadata={"help": "Jaccard threshold for near-duplicate samples."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizerTrainingArguments:
|
||||
"""
|
||||
Configuration for tokenizer training.
|
||||
"""
|
||||
|
||||
base_tokenizer: Optional[str] = field(
|
||||
default="openai-community/gpt2", metadata={"help": "Base tokenizer to build new tokenizer from."}
|
||||
)
|
||||
dataset_name: Optional[str] = field(
|
||||
default="transformersbook/codeparrot-train", metadata={"help": "Dataset to train tokenizer on."}
|
||||
)
|
||||
text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
|
||||
vocab_size: Optional[int] = field(default=200_000, metadata={"help": "Number of examples to train tokenizer on."})
|
||||
n_examples: Optional[int] = field(
|
||||
default=32768, metadata={"help": "Number of examples to train the tokenizer on."}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(default="codeparrot", metadata={"help": "Name of new tokenizer."})
|
||||
push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})
|
||||
|
||||
|
||||
@dataclass
|
||||
class PretokenizationArguments:
|
||||
"""
|
||||
Configuration for data pretokenization.
|
||||
"""
|
||||
|
||||
tokenizer_dir: Optional[str] = field(
|
||||
default="codeparrot/codeparrot", metadata={"help": "Name or path to the tokenizer."}
|
||||
)
|
||||
dataset_name: Optional[str] = field(
|
||||
default="codeparrot/codeparrot-clean-train", metadata={"help": "Name or path to the dataset to pretokenize."}
|
||||
)
|
||||
tokenized_data_repo: Optional[str] = field(
|
||||
default="tokenized-codeparrot-train", metadata={"help": "Repo name of the pretokenized data."}
|
||||
)
|
||||
num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitializationArguments:
|
||||
"""
|
||||
Configuration for initializing new model.
|
||||
"""
|
||||
|
||||
config_name: Optional[str] = field(
|
||||
default="openai-community/gpt2-large", metadata={"help": "Configuration to use for model initialization."}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default="codeparrot/codeparrot", metadata={"help": "Tokenizer attached to model."}
|
||||
)
|
||||
model_name: Optional[str] = field(default="codeparrot", metadata={"help": "Name of the created model."})
|
||||
push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})
|
||||
@@ -1,32 +0,0 @@
|
||||
from arguments import TokenizerTrainingArguments
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
|
||||
|
||||
# Iterator for Training
|
||||
def batch_iterator(batch_size=10):
|
||||
for _ in tqdm(range(0, args.n_examples, batch_size)):
|
||||
yield [next(iter_dataset)[args.text_column] for _ in range(batch_size)]
|
||||
|
||||
|
||||
# Configuration
|
||||
parser = HfArgumentParser(TokenizerTrainingArguments)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Base tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.base_tokenizer)
|
||||
base_vocab = list(bytes_to_unicode().values())
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset(args.dataset_name, split="train", streaming=True)
|
||||
iter_dataset = iter(dataset)
|
||||
|
||||
|
||||
# Training and saving
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(
|
||||
batch_iterator(), vocab_size=args.vocab_size, initial_alphabet=base_vocab
|
||||
)
|
||||
new_tokenizer.save_pretrained(args.tokenizer_name, push_to_hub=args.push_to_hub)
|
||||
@@ -1,328 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.utils import ProjectConfiguration
|
||||
from arguments import TrainingArguments
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
|
||||
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
|
||||
|
||||
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
"""
|
||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||
Args:
|
||||
tokenizer (Tokenizer): The processor used for processing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
|
||||
seq_length (int): Length of token sequences to return.
|
||||
num_of_sequences (int): Number of token sequences to keep in buffer.
|
||||
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
|
||||
tokenized (bool): If true we use a pretokenized dataset.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
dataset,
|
||||
infinite=False,
|
||||
seq_length=1024,
|
||||
num_of_sequences=1024,
|
||||
chars_per_token=3.6,
|
||||
tokenized=False,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.concat_token_id = tokenizer.bos_token_id
|
||||
self.dataset = dataset
|
||||
self.seq_length = seq_length
|
||||
self.epoch = 0
|
||||
self.infinite = infinite
|
||||
self.current_size = 0
|
||||
self.tokenized = tokenized
|
||||
|
||||
if self.tokenized:
|
||||
self.max_buffer_size = seq_length * num_of_sequences
|
||||
self.content_field = "input_ids"
|
||||
else:
|
||||
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
|
||||
self.content_field = "content"
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
buffer, buffer_len = [], 0
|
||||
while True:
|
||||
if buffer_len >= self.max_buffer_size:
|
||||
break
|
||||
try:
|
||||
buffer.append(next(iterator)[self.content_field])
|
||||
buffer_len += len(buffer[-1])
|
||||
except StopIteration:
|
||||
if self.infinite:
|
||||
iterator = iter(self.dataset)
|
||||
self.epoch += 1
|
||||
logger.info(f"Dataset epoch: {self.epoch}")
|
||||
else:
|
||||
more_examples = False
|
||||
break
|
||||
if self.tokenized:
|
||||
tokenized_inputs = buffer
|
||||
else:
|
||||
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
|
||||
all_token_ids = []
|
||||
for tokenized_input in tokenized_inputs:
|
||||
all_token_ids.extend(tokenized_input + [self.concat_token_id])
|
||||
for i in range(0, len(all_token_ids), self.seq_length):
|
||||
input_ids = all_token_ids[i : i + self.seq_length]
|
||||
if len(input_ids) == self.seq_length:
|
||||
self.current_size += 1
|
||||
yield torch.tensor(input_ids)
|
||||
|
||||
def shuffle(self, buffer_size=1000):
|
||||
return ShufflerIterDataPipe(self, buffer_size=buffer_size)
|
||||
|
||||
|
||||
def setup_logging(args):
|
||||
project_name = args.model_ckpt.split("/")[-1]
|
||||
logger = logging.getLogger(__name__)
|
||||
log_dir = Path(args.save_dir) / "log/"
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
filename = f"debug_{accelerator.process_index}.log"
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
|
||||
)
|
||||
if accelerator.is_main_process: # we only want to setup logging once
|
||||
accelerator.init_trackers(project_name, vars(args))
|
||||
run_name = accelerator.trackers[0].run.name
|
||||
logger.setLevel(logging.INFO)
|
||||
datasets.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
run_name = ""
|
||||
logger.setLevel(logging.ERROR)
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
return logger, run_name
|
||||
|
||||
|
||||
def create_dataloaders(args):
|
||||
ds_kwargs = {"streaming": True}
|
||||
train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
|
||||
valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
|
||||
train_dataset = ConstantLengthDataset(
|
||||
tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
|
||||
)
|
||||
valid_dataset = ConstantLengthDataset(
|
||||
tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
|
||||
)
|
||||
train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
|
||||
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
|
||||
return train_dataloader, eval_dataloader
|
||||
|
||||
|
||||
def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]):
|
||||
params_with_wd, params_without_wd = [], []
|
||||
for n, p in model.named_parameters():
|
||||
if any(nd in n for nd in no_decay):
|
||||
params_without_wd.append(p)
|
||||
else:
|
||||
params_with_wd.append(p)
|
||||
return [
|
||||
{"params": params_with_wd, "weight_decay": args.weight_decay},
|
||||
{"params": params_without_wd, "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
|
||||
def log_metrics(step, metrics):
|
||||
logger.info(f"Step {step}: {metrics}")
|
||||
if accelerator.is_main_process:
|
||||
accelerator.log(metrics, step)
|
||||
|
||||
|
||||
def compute_tflops(elapsed_time, accelerator, args):
|
||||
# TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
|
||||
config_model = accelerator.unwrap_model(model).config
|
||||
checkpoint_factor = 4 if args.gradient_checkpointing else 3
|
||||
batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
|
||||
factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
|
||||
flops_per_iteration = factor * (
|
||||
1.0
|
||||
+ (args.seq_length / (6.0 * config_model.n_embd))
|
||||
+ (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
|
||||
)
|
||||
tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))
|
||||
return tflops
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
model.eval()
|
||||
losses = []
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(batch, labels=batch)
|
||||
loss = outputs.loss.repeat(args.valid_batch_size)
|
||||
losses.append(accelerator.gather(loss))
|
||||
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
|
||||
break
|
||||
losses = torch.cat(losses)
|
||||
loss = losses[: eval_dataloader.dataset.current_size].mean()
|
||||
try:
|
||||
perplexity = torch.exp(loss)
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
return loss.item(), perplexity.item()
|
||||
|
||||
|
||||
# Settings
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Accelerator
|
||||
config = ProjectConfiguration(project_dir=args.save_dir, logging_dir="log")
|
||||
accelerator = Accelerator(log_with=["wandb", "tensorboard"], project_config=config)
|
||||
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
|
||||
|
||||
args = Namespace(**vars(args), **acc_state)
|
||||
samples_per_step = accelerator.state.num_processes * args.train_batch_size
|
||||
set_seed(args.seed)
|
||||
|
||||
# Clone model repository
|
||||
if accelerator.is_main_process:
|
||||
hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
|
||||
|
||||
# Logging
|
||||
logger, run_name = setup_logging(args)
|
||||
logger.info(accelerator.state)
|
||||
|
||||
# Checkout new branch on repo
|
||||
if accelerator.is_main_process:
|
||||
hf_repo.git_checkout(run_name, create_branch_ok=True)
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained(args.save_dir)
|
||||
if args.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.save_dir)
|
||||
|
||||
# Load dataset and dataloader
|
||||
train_dataloader, eval_dataloader = create_dataloaders(args)
|
||||
|
||||
# Prepare the optimizer and learning rate scheduler
|
||||
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
|
||||
lr_scheduler = get_scheduler(
|
||||
name=args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.num_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
|
||||
def get_lr():
|
||||
return optimizer.param_groups[0]["lr"]
|
||||
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
|
||||
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
|
||||
accelerator.load_state(args.resume_from_checkpoint)
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = [f.name for f in os.scandir(args.save_dir) if f.is_dir() and "step" in str(f)]
|
||||
dirs.sort(key=os.path.getctime)
|
||||
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
|
||||
# Extract the step of the checkpoint to continue from there
|
||||
training_difference = os.path.splitext(path)[0]
|
||||
resume_step = int(training_difference.replace("step_", ""))
|
||||
|
||||
# Train model
|
||||
model.train()
|
||||
completed_steps = 0
|
||||
t_start = time.time()
|
||||
loss_tracking = 0
|
||||
for step, batch in enumerate(train_dataloader, start=1):
|
||||
if args.resume_from_checkpoint and step < resume_step:
|
||||
continue # we need to skip steps until we reach the resumed step
|
||||
loss = model(batch, labels=batch, use_cache=False).loss
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||
loss_tracking += avg_loss.item() / args.gradient_accumulation_steps
|
||||
log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()})
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
if step % args.gradient_accumulation_steps != 0:
|
||||
# Prevent backward from doing gradient all_reduce in every step
|
||||
if accelerator.distributed_type == DistributedType.MULTI_GPU:
|
||||
with model.no_sync():
|
||||
accelerator.backward(loss)
|
||||
else:
|
||||
accelerator.backward(loss)
|
||||
else:
|
||||
lr = get_lr()
|
||||
accelerator.backward(loss)
|
||||
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
elapsed_time = time.time() - t_start
|
||||
tflops = compute_tflops(elapsed_time, accelerator, args)
|
||||
log_metrics(
|
||||
step,
|
||||
{
|
||||
"steps": completed_steps,
|
||||
"loss/train": loss_tracking,
|
||||
"lr": lr,
|
||||
"tflops": tflops,
|
||||
"time_per_iteration": elapsed_time,
|
||||
},
|
||||
)
|
||||
t_start = time.time()
|
||||
loss_tracking = 0
|
||||
completed_steps += 1
|
||||
if step % args.save_checkpoint_steps == 0:
|
||||
logger.info("Evaluating and saving model checkpoint")
|
||||
eval_loss, perplexity = evaluate(args)
|
||||
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
||||
accelerator.wait_for_everyone()
|
||||
save_dir = os.path.join(args.save_dir, f"step_{step}")
|
||||
accelerator.save_state(save_dir)
|
||||
if accelerator.is_main_process:
|
||||
hf_repo.push_to_hub(commit_message=f"step {step}")
|
||||
model.train()
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# Evaluate and save the last checkpoint
|
||||
logger.info("Evaluating and saving model after training")
|
||||
eval_loss, perplexity = evaluate(args)
|
||||
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
|
||||
save_dir = os.path.join(args.save_dir, f"step_{step}")
|
||||
accelerator.save_state(save_dir)
|
||||
if accelerator.is_main_process:
|
||||
hf_repo.push_to_hub(commit_message="final model")
|
||||
@@ -1,228 +0,0 @@
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from arguments import HumanEvalArguments
|
||||
from datasets import load_dataset, load_metric
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
|
||||
EOF_STRINGS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"]
|
||||
|
||||
|
||||
class TokenizedDataset(IterableDataset):
|
||||
"""Tokenize and preprocess the dataset
|
||||
Multiple copies of the same prompt are sent sequentially.
|
||||
See compute_code for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, dataset, n_tasks=None, n_copies=1):
|
||||
self.tokenizer = tokenizer
|
||||
self.dataset = dataset
|
||||
self.n_tasks = len(dataset) if n_tasks is None else n_tasks
|
||||
self.n_copies = n_copies
|
||||
|
||||
def __iter__(self):
|
||||
prompts = []
|
||||
for task in range(self.n_tasks):
|
||||
# without strip, the model generate commented codes ...
|
||||
prompts.append(self.tokenizer.eos_token + self.dataset[task]["prompt"].strip())
|
||||
outputs = self.tokenizer(prompts, padding=True, return_tensors="pt")
|
||||
for task in range(self.n_tasks):
|
||||
for _ in range(self.n_copies):
|
||||
yield {
|
||||
"ids": outputs.input_ids[task],
|
||||
"task_id": task,
|
||||
"input_len": outputs.attention_mask[task].sum(),
|
||||
}
|
||||
|
||||
|
||||
class EndOfFunctionCriteria(StoppingCriteria):
|
||||
"""Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
|
||||
|
||||
def __init__(self, start_length, eof_strings, tokenizer):
|
||||
self.start_length = start_length
|
||||
self.eof_strings = eof_strings
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(self, input_ids, scores, **kwargs):
|
||||
"""Returns true if all generated sequences contain any of the end-of-function strings."""
|
||||
decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
|
||||
done = []
|
||||
for decoded_generation in decoded_generations:
|
||||
done.append(any(stop_string in decoded_generation for stop_string in self.eof_strings))
|
||||
return all(done)
|
||||
|
||||
|
||||
def remove_last_block(string):
|
||||
"""Remove the last block of the code containing EOF_STRINGS"""
|
||||
string_list = re.split("(%s)" % "|".join(EOF_STRINGS), string)
|
||||
# last string should be ""
|
||||
return "".join(string_list[:-2])
|
||||
|
||||
|
||||
def complete_code(accelerator, model, tokenizer, dataloader, n_tasks, batch_size=20, **gen_kwargs):
|
||||
"""Generate multiple codes for each task in the dataset. This function leverage accelerator to distribute
|
||||
the processing to multiple GPUs.
|
||||
dataloader, a wrapper around a TokenizeDataset objectm is supposed to send all the prompts from
|
||||
the evalution dataset to the modelm as the following:
|
||||
[p_0_0, p_0_1, ..., p_0_nc-1, p_1_0, ..., p_nt-1_nc-1]
|
||||
where nc is the number of copies of the prompt, and nt is the number of tasks.
|
||||
nc is such that num_sample = nc * batch_size
|
||||
|
||||
Parameters
|
||||
----------
|
||||
accelerator: Accelerator
|
||||
|
||||
model: transformers.PreTrainedModel
|
||||
Code generation model. AutoTokenizer.from_pretrained(model_ckpt), ex model_ckpt = "lvwerra/codeparrot"
|
||||
|
||||
tokenizer: transformers.AutoTokenizer
|
||||
The tokenizer used to train model
|
||||
|
||||
dataloader: DataLoader
|
||||
The dataloader is a wrapper around a TokenizeDataset object. It is designed to be used with multiple GPUs.
|
||||
|
||||
n_tasks: int
|
||||
The number of tasks in the dataset. It is used to determine the length of the output.
|
||||
Should be aligned with the number of tasks in the TokenizeDataset.
|
||||
|
||||
batch_size: int
|
||||
num_return_sequences per copy of the prompt such that num_sample = batch_size * n_copies
|
||||
|
||||
gen_kwargs: dict
|
||||
Keyword arguments for the generation function of the model.
|
||||
|
||||
Returns
|
||||
-------
|
||||
code_gens: list of list of str, of length n_tasks
|
||||
List of generated codes for each task.
|
||||
Each element is a list of generated codes for each task, with length num_samples
|
||||
"""
|
||||
gen_token_dict = defaultdict(list) # dict of list of generated tokens
|
||||
for step, batch in tqdm(enumerate(dataloader)):
|
||||
with torch.no_grad():
|
||||
gen_kwargs["stopping_criteria"][0].start_length = batch["ids"].shape[-1]
|
||||
generated_tokens = accelerator.unwrap_model(model).generate(
|
||||
input_ids=batch["ids"][:, : batch["input_len"]], num_return_sequences=batch_size, **gen_kwargs
|
||||
)
|
||||
# each task is generated batch_size times
|
||||
generated_tasks = batch["task_id"].repeat(batch_size)
|
||||
generated_tokens = accelerator.pad_across_processes(
|
||||
generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
generated_tokens, generated_tasks = accelerator.gather((generated_tokens, generated_tasks))
|
||||
generated_tokens = generated_tokens.cpu().numpy()
|
||||
generated_tasks = generated_tasks.cpu().numpy()
|
||||
|
||||
for task, generated_tokens in zip(generated_tasks, generated_tokens):
|
||||
gen_token_dict[task].append(generated_tokens)
|
||||
|
||||
code_gens = [[] for _ in range(n_tasks)]
|
||||
for task, generated_tokens in gen_token_dict.items():
|
||||
for s in generated_tokens:
|
||||
gen_code = tokenizer.decode(s, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
code_gens[task].append(remove_last_block(gen_code))
|
||||
return code_gens
|
||||
|
||||
|
||||
def main():
|
||||
# Setup configuration
|
||||
parser = HfArgumentParser(HumanEvalArguments)
|
||||
args = parser.parse_args()
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
# enables code execution in code_eval metric
|
||||
os.environ["HF_ALLOW_CODE_EVAL"] = args.HF_ALLOW_CODE_EVAL
|
||||
# make sure tokenizer plays nice with multiprocessing
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
if args.num_workers is None:
|
||||
args.num_workers = multiprocessing.cpu_count()
|
||||
|
||||
# Use dataset load to feed to accelerate
|
||||
accelerator = Accelerator()
|
||||
set_seed(args.seed, device_specific=True)
|
||||
|
||||
# Load model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
|
||||
|
||||
# Generation settings
|
||||
gen_kwargs = {
|
||||
"do_sample": args.do_sample,
|
||||
"temperature": args.temperature,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"top_p": args.top_p,
|
||||
"top_k": args.top_k,
|
||||
"stopping_criteria": StoppingCriteriaList([EndOfFunctionCriteria(0, EOF_STRINGS, tokenizer)]),
|
||||
}
|
||||
|
||||
# Load evaluation dataset and metric
|
||||
human_eval = load_dataset("openai_humaneval")
|
||||
code_eval_metric = load_metric("code_eval")
|
||||
|
||||
n_tasks = args.num_tasks if args.num_tasks is not None else len(human_eval["test"])
|
||||
n_copies = args.n_samples // args.batch_size
|
||||
|
||||
human_eval_tokenized = TokenizedDataset(tokenizer, human_eval["test"], n_copies=n_copies, n_tasks=n_tasks)
|
||||
# do not confuse args.batch_size, which is actually the num_return_sequences
|
||||
human_eval_loader = DataLoader(human_eval_tokenized, batch_size=1)
|
||||
|
||||
# Run a quick test to see if code evaluation is enabled
|
||||
try:
|
||||
_ = code_eval_metric.compute(references=[""], predictions=[[""]])
|
||||
except ValueError as exception:
|
||||
print(
|
||||
'Code evaluation not enabled. Read the warning below carefully and then use `--HF_ALLOW_CODE_EVAL="1"`'
|
||||
" flag to enable code evaluation."
|
||||
)
|
||||
raise exception
|
||||
|
||||
model, human_eval_loader = accelerator.prepare(model, human_eval_loader)
|
||||
|
||||
generations = complete_code(
|
||||
accelerator,
|
||||
model,
|
||||
tokenizer,
|
||||
human_eval_loader,
|
||||
n_tasks=n_tasks,
|
||||
batch_size=args.batch_size,
|
||||
**gen_kwargs,
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
references = []
|
||||
|
||||
for task in tqdm(range(n_tasks)):
|
||||
test_func = human_eval["test"][task]["test"]
|
||||
entry_point = f"check({human_eval['test'][task]['entry_point']})"
|
||||
references.append("\n" + test_func + "\n" + entry_point)
|
||||
|
||||
# Evaluate completions with "code_eval" metric
|
||||
pass_at_k, _ = code_eval_metric.compute(
|
||||
references=references, predictions=generations, num_workers=args.num_workers
|
||||
)
|
||||
print(f"Results: {pass_at_k}")
|
||||
|
||||
# Save results to json file
|
||||
with open(args.output_file, "w") as fp:
|
||||
json.dump(pass_at_k, fp)
|
||||
|
||||
|
||||
# For some reason the folliwng seems to be necessary sometimes for code_eval to work nice with multiprocessing
|
||||
# https://stackoverflow.com/questions/60804599/python-multiprocessing-keeps-spawning-the-whole-script
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,27 +0,0 @@
|
||||
from arguments import InitializationArguments
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
# Configuration
|
||||
parser = HfArgumentParser(InitializationArguments)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load codeparrot tokenizer trained for Python code tokenization
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
|
||||
|
||||
# Config: "scale_attn_by_layer_idx" and "reorder_and_upcast_attn" are Mistral stability tweaks
|
||||
config_kwargs = {
|
||||
"vocab_size": len(tokenizer),
|
||||
"scale_attn_by_inverse_layer_idx": True,
|
||||
"reorder_and_upcast_attn": True,
|
||||
}
|
||||
|
||||
# Load model config (GPT-2 large in this case)
|
||||
config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
|
||||
|
||||
# Initialize new model with config
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
|
||||
# Save model to the hub
|
||||
model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
|
||||
@@ -1,268 +0,0 @@
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
from datasets import Dataset
|
||||
from datasketch import MinHash, MinHashLSH
|
||||
from dpu_utils.utils.iterators import ThreadedIterator
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
NON_ALPHA = re.compile("[^A-Za-z_0-9]")
|
||||
# parameters used in DuplicationIndex
|
||||
MIN_NUM_TOKENS = 10
|
||||
NUM_PERM = 256
|
||||
|
||||
|
||||
def get_min_hash(tokens: List[str]) -> Optional[MinHash]:
|
||||
"""Compute the MinHash of a code snippet."""
|
||||
if len(tokens) < MIN_NUM_TOKENS:
|
||||
return None
|
||||
min_hash = MinHash(num_perm=NUM_PERM)
|
||||
for token in set(tokens):
|
||||
min_hash.update(token.encode())
|
||||
return min_hash
|
||||
|
||||
|
||||
def get_tokens(code: str) -> Set[str]:
|
||||
"""Tokenize a code snippet."""
|
||||
return {t for t in NON_ALPHA.split(code) if len(t.strip()) > 0}
|
||||
|
||||
|
||||
class DuplicationIndex:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
duplication_jaccard_threshold: float = 0.85,
|
||||
):
|
||||
self._duplication_jaccard_threshold = duplication_jaccard_threshold
|
||||
self._num_perm = NUM_PERM
|
||||
self._index = MinHashLSH(threshold=self._duplication_jaccard_threshold, num_perm=self._num_perm)
|
||||
|
||||
self._duplicate_clusters = defaultdict(set)
|
||||
|
||||
def add(self, code_key: Tuple, min_hash: MinHash) -> None:
|
||||
"""Add a key to _index (MinHashLSH)
|
||||
the min_hash is used to query closest matches based on the jaccard_threshold.
|
||||
The new key is either added to a existing cluster of one close match,
|
||||
or a new cluster is created. The clusters created in this way, depend on the order of add.
|
||||
|
||||
Args:
|
||||
code_key (Tuple of (index, repo_name, path)):
|
||||
Theoritically any hasbale key. Here we use a tuple to retrieve the information later.
|
||||
min_hash: MinHash of the code_key.
|
||||
"""
|
||||
close_duplicates = self._index.query(min_hash)
|
||||
if code_key in self._index.keys:
|
||||
print(f"Duplicate key {code_key}")
|
||||
return
|
||||
|
||||
self._index.insert(code_key, min_hash)
|
||||
if len(close_duplicates) > 0:
|
||||
for base_duplicate in close_duplicates:
|
||||
if base_duplicate in self._duplicate_clusters:
|
||||
self._duplicate_clusters[base_duplicate].add(code_key)
|
||||
break
|
||||
else:
|
||||
self._duplicate_clusters[close_duplicates[0]].add(code_key)
|
||||
|
||||
def get_duplicate_clusters(self) -> List[List[Dict]]:
|
||||
"""Export the duplicate clusters.
|
||||
For each cluster, the first element is the base element of the cluster.
|
||||
The base element has an estimation jaccard similarity higher than the threshold with all the other elements.
|
||||
|
||||
Returns:
|
||||
duplicate_clusters (List[List[Dict]]):
|
||||
List of duplicate clusters.
|
||||
"""
|
||||
duplicate_clusters = []
|
||||
for base, duplicates in self._duplicate_clusters.items():
|
||||
cluster = [base] + list(duplicates)
|
||||
# reformat the cluster to be a list of dict
|
||||
cluster = [{"base_index": el[0], "repo_name": el[1], "path": el[2]} for el in cluster]
|
||||
duplicate_clusters.append(cluster)
|
||||
return duplicate_clusters
|
||||
|
||||
def save(self, filepath) -> None:
|
||||
duplicate_clusters = self.get_duplicate_clusters()
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(duplicate_clusters, f)
|
||||
|
||||
|
||||
def _compute_min_hash(element):
|
||||
index, data = element
|
||||
min_hash = get_min_hash([t for t in NON_ALPHA.split(data["content"]) if len(t.strip()) > 0])
|
||||
if min_hash is not None:
|
||||
return (index, data["repo_name"], data["path"]), min_hash
|
||||
|
||||
|
||||
def minhash_iter(dataset_iterator: Type[Dataset]):
|
||||
with mp.Pool() as pool:
|
||||
for data in pool.imap_unordered(
|
||||
_compute_min_hash,
|
||||
ThreadedIterator(dataset_iterator, max_queue_size=10000),
|
||||
chunksize=100,
|
||||
):
|
||||
if data is not None:
|
||||
yield data
|
||||
|
||||
|
||||
def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold: float):
|
||||
"""Find duplicate clusters in the dataset in two steps:
|
||||
1. Compute MinHash for each code snippet. MinHash is a tool for fast jaccard similarity estimation.
|
||||
This step is computed using an asynchronous multiprocessing pool, minhash_iter
|
||||
2. Find duplicate clusters. The computed MinHash is added sequentially to the DuplicationIndex.
|
||||
This step cannot be parallelized. So using asynchronous thread in the previous step helps to speed up the process.
|
||||
"""
|
||||
di = DuplicationIndex(duplication_jaccard_threshold=jaccard_threshold)
|
||||
|
||||
for filename, min_hash in tqdm(ThreadedIterator(minhash_iter(enumerate(dataset_iterator)), max_queue_size=100)):
|
||||
di.add(filename, min_hash)
|
||||
|
||||
# Returns a List[Cluster] where Cluster is List[str] with the filenames.
|
||||
return di.get_duplicate_clusters()
|
||||
|
||||
|
||||
def jaccard_similarity(code1: str, code2: str) -> float:
|
||||
"""Compute the Jaccard similarity of two code snippets."""
|
||||
tokens1 = get_tokens(code1)
|
||||
tokens2 = get_tokens(code2)
|
||||
return len(tokens1 & tokens2) / len(tokens1 | tokens2)
|
||||
|
||||
|
||||
_shared_dataset = None
|
||||
|
||||
|
||||
def _find_cluster_extremes_shared(cluster, jaccard_threshold):
|
||||
"""Find a reduced cluster such that each code in the origin cluster is similar to at least one code in the reduced cluster.
|
||||
Two codes are similar if their Jaccard similarity is above the threshold.
|
||||
|
||||
Args:
|
||||
cluster (List[dict]):
|
||||
cluster is a list of dict, each dict contains the following keys:
|
||||
- base_index
|
||||
- repo_name
|
||||
- path
|
||||
This is a typical output of DuplicationIndex.get_duplicate_clusters()
|
||||
jaccard_threshold (float):
|
||||
threshold for Jaccard similarity.
|
||||
Two codes are similar if their Jaccard similarity is above the threshold.
|
||||
|
||||
Returns:
|
||||
extremes (List[dict]):
|
||||
A reduced representation of the cluster. The field copies is added to each dict.
|
||||
The copies field indicates the number of similar codes in the cluster for a extreme.
|
||||
"""
|
||||
extremes = []
|
||||
for element1 in cluster:
|
||||
code1 = _shared_dataset[element1["base_index"]]["content"]
|
||||
for element2 in extremes:
|
||||
code2 = _shared_dataset[element2["base_index"]]["content"]
|
||||
if jaccard_similarity(code1, code2) >= jaccard_threshold:
|
||||
element2["copies"] += 1
|
||||
break
|
||||
else:
|
||||
element1["copies"] = 1
|
||||
extremes.append(element1)
|
||||
return extremes
|
||||
|
||||
|
||||
def find_extremes(cluster_list, dataset, jaccard_threshold):
|
||||
"""Call the _find_cluster_extremes_shared function in a parallel fashion.
|
||||
|
||||
Args:
|
||||
cluster_list (List[List[Dict]]):
|
||||
each cluster is a list of dicts with the key base_index,
|
||||
referring to the index of the base code in the dataset.
|
||||
dataset (Type[Dataset]):
|
||||
dataset is used to access the content of the code snippets,
|
||||
using the base_index from the cluster_list.
|
||||
dataset is shared between all the processes using a glabal variable (any other way to share the dataset?),
|
||||
otherwise the multi processing is not speeded up.
|
||||
jaccard_threshold (float):
|
||||
the threshold for the jaccard similarity. The default value is 0.85
|
||||
|
||||
Returns:
|
||||
extremes_list (List[Dict]):
|
||||
Each cluster is reduced to extremes.
|
||||
See _find_cluster_extremes_shared for the definition of extremes.
|
||||
"""
|
||||
global _shared_dataset
|
||||
_shared_dataset = dataset
|
||||
extremes_list = []
|
||||
f = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold)
|
||||
with mp.Pool() as pool:
|
||||
for extremes in tqdm(
|
||||
pool.imap_unordered(
|
||||
f,
|
||||
cluster_list,
|
||||
),
|
||||
total=len(cluster_list),
|
||||
):
|
||||
extremes_list.append(extremes)
|
||||
return extremes_list
|
||||
|
||||
|
||||
def deduplicate_dataset(
|
||||
dataset: Type[Dataset], jaccard_threshold: float = 0.85
|
||||
) -> Tuple[Type[Dataset], List[List[Dict]]]:
|
||||
"""Deduplicate the dataset using minhash and jaccard similarity.
|
||||
This function first generate duplicate clusters, then each cluster
|
||||
is reduced to the extremes that are similar to the other elements in the cluster.
|
||||
Codes are called similar if their Jaccard similarity is greater than jaccard_threshold (0.85 default).
|
||||
|
||||
Args:
|
||||
dataset (Type[Dataset]):
|
||||
The dataset to deduplicate.
|
||||
jaccard_threshold (float, default=0.85):
|
||||
jaccard threshold to determine if two codes are similar
|
||||
|
||||
Returns:
|
||||
ds_dedup (Type[Dataset]):
|
||||
The deduplicated dataset.
|
||||
duplicate_clusters (List[List[Dict]]):
|
||||
The list of duplicate clusters.
|
||||
Each cluster is a list of dicts with the following keys:
|
||||
- base_index : int
|
||||
The index of the code in the original dataset.
|
||||
- repo_name : str
|
||||
- path : str
|
||||
- copies : int
|
||||
The number of copies of the code in the cluster. (find_cluster_extremes)
|
||||
- is_extreme : bool
|
||||
Whether the code is an extreme in the cluster.
|
||||
All the codes in the cluster are removed from the dataset except the extremes.
|
||||
|
||||
Example:
|
||||
>>> from datasets import load_dataset
|
||||
>>> from minhash_deduplication import deduplicate_dataset
|
||||
>>> ds = load_dataset("lvwerra/codeparrot-clean", split="train")
|
||||
>>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
|
||||
"""
|
||||
duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold)
|
||||
duplicate_indices = {x["base_index"] for cluster in duplicate_clusters for x in cluster}
|
||||
extreme_dict = {}
|
||||
extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold)
|
||||
for extremes in extremes_clusters:
|
||||
for element in extremes:
|
||||
extreme_dict[element["base_index"]] = element
|
||||
remove_indices = duplicate_indices - set(extreme_dict.keys())
|
||||
ds_filter = dataset.filter(lambda x, idx: idx not in remove_indices, with_indices=True)
|
||||
|
||||
# update duplicate_clusters
|
||||
for cluster in duplicate_clusters:
|
||||
for element in cluster:
|
||||
element["is_extreme"] = element["base_index"] in extreme_dict
|
||||
if element["is_extreme"]:
|
||||
element["copies"] = extreme_dict[element["base_index"]]["copies"]
|
||||
|
||||
print(f"Original dataset size: {len(dataset)}")
|
||||
print(f"Number of duplicate clusters: {len(duplicate_clusters)}")
|
||||
print(f"Files in duplicate cluster: {len(duplicate_indices)}")
|
||||
print(f"Unique files in duplicate cluster: {len(extreme_dict)}")
|
||||
print(f"Filtered dataset size: {len(ds_filter)}")
|
||||
|
||||
return ds_filter, duplicate_clusters
|
||||
@@ -1,215 +0,0 @@
|
||||
import gzip
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from arguments import PreprocessingArguments
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub.utils import insecure_hashlib
|
||||
from minhash_deduplication import deduplicate_dataset
|
||||
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
PATTERN = re.compile(r"\s+")
|
||||
|
||||
|
||||
def get_hash(example):
|
||||
"""Get hash of content field."""
|
||||
return {"hash": insecure_hashlib.md5(re.sub(PATTERN, "", example["content"]).encode("utf-8")).hexdigest()}
|
||||
|
||||
|
||||
def line_stats(example):
|
||||
"""Calculates mean and max line length of file."""
|
||||
line_lengths = [len(line) for line in example["content"].splitlines()]
|
||||
return {"line_mean": np.mean(line_lengths), "line_max": max(line_lengths)}
|
||||
|
||||
|
||||
def alpha_stats(example):
|
||||
"""Calculates mean and max line length of file."""
|
||||
alpha_frac = np.mean([c.isalnum() for c in example["content"]])
|
||||
return {"alpha_frac": alpha_frac}
|
||||
|
||||
|
||||
def check_uniques(example, uniques):
|
||||
"""Check if current hash is still in set of unique hashes and remove if true."""
|
||||
if example["hash"] in uniques:
|
||||
uniques.remove(example["hash"])
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_autogenerated(example, scan_width=5):
|
||||
"""Check if file is autogenerated by looking for keywords in the first few lines of the file."""
|
||||
keywords = ["auto-generated", "autogenerated", "automatically generated"]
|
||||
lines = example["content"].splitlines()
|
||||
for _, line in zip(range(scan_width), lines):
|
||||
for keyword in keywords:
|
||||
if keyword in line.lower():
|
||||
return {"autogenerated": True}
|
||||
else:
|
||||
return {"autogenerated": False}
|
||||
|
||||
|
||||
def is_config_or_test(example, scan_width=5, coeff=0.05):
|
||||
"""Check if file is a configuration file or a unit test by :
|
||||
1- looking for keywords in the first few lines of the file.
|
||||
2- counting number of occurrence of the words 'config' and 'test' with respect to number of lines.
|
||||
"""
|
||||
|
||||
keywords = ["unit tests", "test file", "configuration file"]
|
||||
lines = example["content"].splitlines()
|
||||
count_config = 0
|
||||
count_test = 0
|
||||
# first test
|
||||
for _, line in zip(range(scan_width), lines):
|
||||
for keyword in keywords:
|
||||
if keyword in line.lower():
|
||||
return {"config_or_test": True}
|
||||
# second test
|
||||
nlines = example["content"].count("\n")
|
||||
threshold = int(coeff * nlines)
|
||||
for line in lines:
|
||||
count_config += line.lower().count("config")
|
||||
count_test += line.lower().count("test")
|
||||
if count_config > threshold or count_test > threshold:
|
||||
return {"config_or_test": True}
|
||||
return {"config_or_test": False}
|
||||
|
||||
|
||||
def has_no_keywords(example):
|
||||
"""Check if a python file has none of the keywords for: function, class, for loop, while loop."""
|
||||
keywords = ["def ", "class ", "for ", "while "]
|
||||
lines = example["content"].splitlines()
|
||||
for line in lines:
|
||||
for keyword in keywords:
|
||||
if keyword in line.lower():
|
||||
return {"has_no_keywords": False}
|
||||
return {"has_no_keywords": True}
|
||||
|
||||
|
||||
def has_few_assignments(example, minimum=4):
|
||||
"""Check if file uses symbol '=' less than `minimum` times."""
|
||||
lines = example["content"].splitlines()
|
||||
counter = 0
|
||||
for line in lines:
|
||||
counter += line.lower().count("=")
|
||||
if counter > minimum:
|
||||
return {"has_few_assignments": False}
|
||||
return {"has_few_assignments": True}
|
||||
|
||||
|
||||
def char_token_ratio(example):
|
||||
"""Compute character/token ratio of the file with tokenizer."""
|
||||
input_ids = tokenizer(example["content"], truncation=False)["input_ids"]
|
||||
ratio = len(example["content"]) / len(input_ids)
|
||||
return {"ratio": ratio}
|
||||
|
||||
|
||||
def preprocess(example):
|
||||
"""Chain all preprocessing steps into one function to not fill cache."""
|
||||
results = {}
|
||||
results.update(get_hash(example))
|
||||
results.update(line_stats(example))
|
||||
results.update(alpha_stats(example))
|
||||
results.update(char_token_ratio(example))
|
||||
results.update(is_autogenerated(example))
|
||||
results.update(is_config_or_test(example))
|
||||
results.update(has_no_keywords(example))
|
||||
results.update(has_few_assignments(example))
|
||||
return results
|
||||
|
||||
|
||||
def filter(example, uniques, args):
|
||||
"""Filter dataset with heuristics. Config, test and has_no_keywords files are removed with a given probability."""
|
||||
if not check_uniques(example, uniques):
|
||||
return False
|
||||
elif example["autogenerated"]:
|
||||
return False
|
||||
elif example["line_max"] > args.line_max:
|
||||
return False
|
||||
elif example["line_mean"] > args.line_mean:
|
||||
return False
|
||||
elif example["alpha_frac"] < args.alpha_frac:
|
||||
return False
|
||||
elif example["ratio"] < args.min_token_ratio:
|
||||
return False
|
||||
elif example["config_or_test"] and np.random.rand() <= args.filter_proba:
|
||||
return False
|
||||
elif example["has_no_keywords"] and np.random.rand() <= args.filter_proba:
|
||||
return False
|
||||
elif example["has_few_assignments"]:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def compress_file(file_path):
|
||||
"""Compress a file with g-zip."""
|
||||
with open(file_path, "rb") as f_in:
|
||||
with gzip.open(str(file_path) + ".gz", "wb", compresslevel=6) as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
os.unlink(file_path)
|
||||
|
||||
|
||||
# Settings
|
||||
parser = HfArgumentParser(PreprocessingArguments)
|
||||
args = parser.parse_args()
|
||||
if args.num_workers is None:
|
||||
args.num_workers = multiprocessing.cpu_count()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
|
||||
|
||||
# Load dataset
|
||||
t_start = time.time()
|
||||
ds = load_dataset(args.dataset_name, split="train")
|
||||
print(f"Time to load dataset: {time.time()-t_start:.2f}")
|
||||
|
||||
# Run preprocessing
|
||||
t_start = time.time()
|
||||
ds = ds.map(preprocess, num_proc=args.num_workers)
|
||||
print(f"Time to preprocess dataset: {time.time()-t_start:.2f}")
|
||||
|
||||
# Deduplicate hashes
|
||||
uniques = set(ds.unique("hash"))
|
||||
frac = len(uniques) / len(ds)
|
||||
print(f"Fraction of duplicates: {1-frac:.2%}")
|
||||
|
||||
# Deduplicate data and apply heuristics
|
||||
t_start = time.time()
|
||||
ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
|
||||
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
|
||||
print(f"Size of filtered dataset: {len(ds_filter)}")
|
||||
|
||||
# Deduplicate with minhash and jaccard similarity
|
||||
if args.near_deduplication:
|
||||
t_start = time.time()
|
||||
ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold)
|
||||
print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}")
|
||||
print(f"Size of deduplicate dataset: {len(ds_filter)}")
|
||||
|
||||
# Save data in batches of samples_per_file
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# save duplicate_clusters in the output_dir as artifacts
|
||||
# not sure it is the right place the save it
|
||||
if args.near_deduplication:
|
||||
with open(output_dir / "duplicate_clusters.json", "w") as f:
|
||||
json.dump(duplicate_clusters, f)
|
||||
|
||||
data_dir = output_dir / "data"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
|
||||
t_start = time.time()
|
||||
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
|
||||
file_path = str(data_dir / f"file-{file_number+1:012}.json")
|
||||
end_index = min(len(ds_filter), index + args.samples_per_file)
|
||||
ds_filter.select(list(range(index, end_index))).to_json(file_path)
|
||||
compress_file(file_path)
|
||||
print(f"Time to save dataset: {time.time()-t_start:.2f}")
|
||||
@@ -1,49 +0,0 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
from arguments import PretokenizationArguments
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
def tokenize(example):
|
||||
output = {}
|
||||
output["input_ids"] = tokenizer(example["content"], truncation=False)["input_ids"]
|
||||
output["ratio_char_token"] = len(example["content"]) / len(output["input_ids"])
|
||||
return output
|
||||
|
||||
|
||||
parser = HfArgumentParser(PretokenizationArguments)
|
||||
args = parser.parse_args()
|
||||
if args.num_workers is None:
|
||||
args.num_workers = multiprocessing.cpu_count()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
|
||||
|
||||
t_start = time.time()
|
||||
ds = load_dataset(args.dataset_name, split="train")
|
||||
print(f"Dataset loaded in {time.time()-t_start:.2f}s")
|
||||
|
||||
t_start = time.time()
|
||||
ds = ds.map(
|
||||
tokenize,
|
||||
num_proc=args.num_workers,
|
||||
remove_columns=[
|
||||
"repo_name",
|
||||
"path",
|
||||
"copies",
|
||||
"size",
|
||||
"content",
|
||||
"license",
|
||||
"hash",
|
||||
"line_mean",
|
||||
"line_max",
|
||||
"alpha_frac",
|
||||
"autogenerated",
|
||||
],
|
||||
)
|
||||
print(f"Dataset tokenized in {time.time()-t_start:.2f}s")
|
||||
|
||||
t_start = time.time()
|
||||
ds.push_to_hub(args.tokenized_data_repo)
|
||||
print(f"Data pushed to the hub in {time.time()-t_start:.2f}s")
|
||||
@@ -1,29 +0,0 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from datasets import Dataset
|
||||
from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters
|
||||
|
||||
|
||||
def get_dataset():
|
||||
data_dict = {
|
||||
"repo_name": ["test_repo1", "test_repo2", "test_repo3"],
|
||||
"path": ["test_1.py", "test_2.py", "unit_test.py"],
|
||||
"content": ["a " * 20, "a " * 30, "b " * 7],
|
||||
}
|
||||
dataset = Dataset.from_dict(data_dict)
|
||||
return dataset
|
||||
|
||||
|
||||
class MakeDuplicateClustersTest(TestCase):
|
||||
def test_make_duplicate_clusters(self):
|
||||
ds = get_dataset()
|
||||
duplicate_clusters = make_duplicate_clusters(ds, 0.85)
|
||||
self.assertEqual(len(duplicate_clusters[0]), 2)
|
||||
|
||||
def test_deduplicate_dataset(self):
|
||||
ds = get_dataset()
|
||||
ds_filter, duplicate_clusters = deduplicate_dataset(ds)
|
||||
self.assertEqual(len(ds_filter), 2)
|
||||
print(duplicate_clusters)
|
||||
self.assertEqual(duplicate_clusters[0][0]["copies"], 2)
|
||||
self.assertEqual(duplicate_clusters[0][0]["is_extreme"], True)
|
||||
@@ -1,99 +0,0 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from arguments import EvaluationArguments
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
|
||||
|
||||
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
def __init__(self, tokenizer, dataset, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6):
|
||||
self.tokenizer = tokenizer
|
||||
self.concat_token_id = tokenizer.bos_token_id
|
||||
self.dataset = dataset
|
||||
self.seq_length = seq_length
|
||||
self.input_characters = seq_length * chars_per_token * num_of_sequences
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
buffer, buffer_len = [], 0
|
||||
while True:
|
||||
if buffer_len >= self.input_characters:
|
||||
break
|
||||
try:
|
||||
buffer.append(next(iterator)["content"])
|
||||
buffer_len += len(buffer[-1])
|
||||
except StopIteration:
|
||||
more_examples = False
|
||||
break
|
||||
tokenized_inputs = tokenizer(buffer, truncation=False)["input_ids"]
|
||||
all_token_ids = []
|
||||
for tokenized_input in tokenized_inputs:
|
||||
all_token_ids.extend(tokenized_input + [self.concat_token_id])
|
||||
for i in range(0, len(all_token_ids), self.seq_length):
|
||||
input_ids = all_token_ids[i : i + self.seq_length]
|
||||
if len(input_ids) == self.seq_length:
|
||||
yield torch.tensor(input_ids)
|
||||
|
||||
|
||||
def create_dataloader(args):
|
||||
ds_kwargs = {"streaming": True}
|
||||
valid_data = load_dataset(args.dataset_name, split="train", **ds_kwargs)
|
||||
valid_dataset = ConstantLengthDataset(tokenizer, valid_data, seq_length=args.seq_length)
|
||||
eval_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size)
|
||||
return eval_dataloader
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
model.eval()
|
||||
losses = []
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(batch, labels=batch)
|
||||
loss = outputs.loss.repeat(args.batch_size)
|
||||
losses.append(accelerator.gather(loss))
|
||||
|
||||
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
|
||||
break
|
||||
loss = torch.mean(torch.cat(losses))
|
||||
try:
|
||||
perplexity = torch.exp(loss)
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
return loss.item(), perplexity.item()
|
||||
|
||||
|
||||
# Setup Accelerator
|
||||
accelerator = Accelerator()
|
||||
|
||||
# Parse configuration
|
||||
parser = HfArgumentParser(EvaluationArguments)
|
||||
args = parser.parse_args()
|
||||
set_seed(args.seed)
|
||||
|
||||
# Logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
|
||||
|
||||
# Load dataset and dataloader
|
||||
eval_dataloader = create_dataloader(args)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
|
||||
|
||||
# Evaluate and save the last checkpoint
|
||||
logger.info("Evaluating and saving model after training")
|
||||
eval_loss, perplexity = evaluate(args)
|
||||
logger.info(f"loss/eval: {eval_loss}, perplexity: {perplexity}")
|
||||
@@ -1,240 +0,0 @@
|
||||
absl-py==1.0.0
|
||||
aiohttp==3.10.11
|
||||
aiosignal==1.2.0
|
||||
alembic==1.7.7
|
||||
appdirs==1.4.4
|
||||
APScheduler==3.9.1
|
||||
arrow==1.2.2
|
||||
asttokens==2.0.5
|
||||
astunparse==1.6.3
|
||||
async-timeout==4.0.2
|
||||
attrs==21.4.0
|
||||
audioread==2.1.9
|
||||
autopage==0.5.0
|
||||
backcall==0.2.0
|
||||
backoff==1.11.1
|
||||
backports.zoneinfo==0.2.1
|
||||
binaryornot==0.4.4
|
||||
black==24.3.0
|
||||
boto3==1.16.34
|
||||
botocore==1.19.63
|
||||
Brotli==1.0.9
|
||||
cachetools==5.0.0
|
||||
certifi==2024.7.4
|
||||
cffi==1.15.0
|
||||
chardet==4.0.0
|
||||
charset-normalizer==2.0.12
|
||||
chex==0.1.1
|
||||
click==8.0.4
|
||||
cliff==3.10.1
|
||||
clldutils==3.11.1
|
||||
cloudpickle==2.0.0
|
||||
cmaes==0.8.2
|
||||
cmd2==2.4.0
|
||||
codecarbon==1.2.0
|
||||
colorlog==6.6.0
|
||||
cookiecutter==2.1.1
|
||||
cryptography==44.0.1
|
||||
csvw==2.0.0
|
||||
cycler==0.11.0
|
||||
Cython==0.29.28
|
||||
dash==2.15.0
|
||||
dash-bootstrap-components==1.0.3
|
||||
dash-core-components==2.0.0
|
||||
dash-html-components==2.0.0
|
||||
dash-table==5.0.0
|
||||
datasets==2.0.0
|
||||
decorator==5.1.1
|
||||
Deprecated==1.2.13
|
||||
dill==0.3.4
|
||||
dlinfo==1.2.1
|
||||
dm-tree==0.1.6
|
||||
docker==4.4.4
|
||||
execnet==1.9.0
|
||||
executing==0.8.3
|
||||
faiss-cpu==1.7.2
|
||||
fasteners==0.17.3
|
||||
filelock==3.6.0
|
||||
fire==0.4.0
|
||||
flake8==4.0.1
|
||||
Flask==2.3.2
|
||||
Flask-Compress==1.11
|
||||
flatbuffers==2.0
|
||||
flax==0.4.0
|
||||
fonttools==4.43.0
|
||||
frozenlist==1.3.0
|
||||
fsspec==2022.2.0
|
||||
fugashi==1.1.2
|
||||
gast==0.5.3
|
||||
gitdb==4.0.9
|
||||
GitPython==3.1.41
|
||||
glfw==2.5.1
|
||||
google-auth==2.6.2
|
||||
google-auth-oauthlib==0.4.6
|
||||
google-pasta==0.2.0
|
||||
greenlet==1.1.2
|
||||
grpcio==1.53.2
|
||||
gym==0.23.1
|
||||
gym-notices==0.0.6
|
||||
h5py==3.6.0
|
||||
huggingface-hub==0.4.0
|
||||
hypothesis==6.39.4
|
||||
idna==3.7
|
||||
imageio==2.16.1
|
||||
importlib-metadata==4.11.3
|
||||
importlib-resources==5.4.0
|
||||
iniconfig==1.1.1
|
||||
ipadic==1.0.0
|
||||
ipython==8.10.0
|
||||
isodate==0.6.1
|
||||
isort==5.10.1
|
||||
itsdangerous==2.1.1
|
||||
jax==0.3.4
|
||||
jaxlib==0.3.2
|
||||
jedi==0.18.1
|
||||
Jinja2==3.1.6
|
||||
jinja2-time==0.2.0
|
||||
jmespath==0.10.0
|
||||
joblib==1.2.0
|
||||
jsonschema==4.4.0
|
||||
keras==2.13.1
|
||||
Keras-Preprocessing==1.1.2
|
||||
kiwisolver==1.4.0
|
||||
kubernetes==12.0.1
|
||||
libclang==13.0.0
|
||||
librosa==0.9.1
|
||||
llvmlite==0.38.0
|
||||
Mako==1.2.2
|
||||
Markdown==3.3.6
|
||||
MarkupSafe==1.1.1
|
||||
matplotlib==3.5.1
|
||||
matplotlib-inline==0.1.3
|
||||
mccabe==0.6.1
|
||||
msgpack==1.0.3
|
||||
mujoco-py==2.1.2.14
|
||||
multidict==6.0.2
|
||||
multiprocess==0.70.12.2
|
||||
mypy-extensions==0.4.3
|
||||
nltk==3.9
|
||||
numba==0.55.1
|
||||
numpy==1.22.3
|
||||
oauthlib==3.2.2
|
||||
onnx>=1.15.0
|
||||
onnxconverter-common==1.9.0
|
||||
opt-einsum==3.3.0
|
||||
optax==0.1.1
|
||||
optuna==2.10.0
|
||||
packaging==21.3
|
||||
pandas==1.4.1
|
||||
parameterized==0.8.1
|
||||
parso==0.8.3
|
||||
pathspec==0.9.0
|
||||
pbr==5.8.1
|
||||
pexpect==4.8.0
|
||||
phonemizer==3.0.1
|
||||
pickleshare==0.7.5
|
||||
Pillow==10.3.0
|
||||
Pint==0.16.1
|
||||
plac==1.3.4
|
||||
platformdirs==2.5.1
|
||||
plotly==5.6.0
|
||||
pluggy==1.0.0
|
||||
pooch==1.6.0
|
||||
portalocker==2.0.0
|
||||
poyo==0.5.0
|
||||
prettytable==3.2.0
|
||||
prompt-toolkit==3.0.28
|
||||
protobuf==3.19.5
|
||||
psutil==5.9.0
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.2
|
||||
py==1.11.0
|
||||
py-cpuinfo==8.0.0
|
||||
pyarrow==15.0.0
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
pycodestyle==2.8.0
|
||||
pycparser==2.21
|
||||
pyctcdecode==0.3.0
|
||||
pyflakes==2.4.0
|
||||
Pygments==2.15.0
|
||||
pygtrie==2.4.2
|
||||
pynvml==11.4.1
|
||||
pyOpenSSL==22.0.0
|
||||
pyparsing==3.0.7
|
||||
pyperclip==1.8.2
|
||||
pypng==0.0.21
|
||||
pyrsistent==0.18.1
|
||||
pytest==7.1.1
|
||||
pytest-forked==1.4.0
|
||||
pytest-timeout==2.1.0
|
||||
pytest-xdist==2.5.0
|
||||
python-dateutil==2.8.2
|
||||
python-slugify==6.1.1
|
||||
pytz==2022.1
|
||||
pytz-deprecation-shim==0.1.0.post0
|
||||
PyYAML==6.0
|
||||
ray>2.6.3
|
||||
redis==4.5.4
|
||||
regex==2022.3.15
|
||||
requests==2.32.0
|
||||
requests-oauthlib==1.3.1
|
||||
resampy==0.2.2
|
||||
responses==0.18.0
|
||||
rfc3986==1.5.0
|
||||
rouge-score==0.0.4
|
||||
rsa==4.8
|
||||
s3transfer==0.3.7
|
||||
sacrebleu==1.5.1
|
||||
sacremoses==0.0.49
|
||||
scikit-learn==1.5.0
|
||||
scipy==1.8.0
|
||||
segments==2.2.0
|
||||
sentencepiece==0.1.96
|
||||
sigopt==8.2.0
|
||||
six==1.16.0
|
||||
smmap==5.0.0
|
||||
sortedcontainers==2.4.0
|
||||
SoundFile==0.10.3.post1
|
||||
SQLAlchemy==1.4.32
|
||||
stack-data==0.2.0
|
||||
stevedore==3.5.0
|
||||
tabulate==0.8.9
|
||||
tenacity==8.0.1
|
||||
tensorboard==2.8.0
|
||||
tensorboard-data-server==0.6.1
|
||||
tensorboard-plugin-wit==1.8.1
|
||||
tensorboardX==2.5
|
||||
tensorflow==2.12.1
|
||||
tensorflow-io-gcs-filesystem==0.24.0
|
||||
termcolor==1.1.0
|
||||
text-unidecode==1.3
|
||||
tf-estimator-nightly==2.8.0.dev2021122109
|
||||
tf2onnx==1.9.3
|
||||
threadpoolctl==3.1.0
|
||||
timeout-decorator==0.5.0
|
||||
timm==0.5.4
|
||||
tokenizers==0.11.6
|
||||
tomli==2.0.1
|
||||
toolz==0.11.2
|
||||
torch==2.2.0
|
||||
torchaudio==0.11.0
|
||||
torchvision==0.12.0
|
||||
tqdm==4.66.3
|
||||
traitlets==5.1.1
|
||||
-e git+git@github.com:edbeeching/transformers.git@77b90113ca0a0e4058b046796c874bdc98f1da61#egg=transformers
|
||||
typing-extensions==4.1.1
|
||||
tzdata==2022.1
|
||||
tzlocal==4.1
|
||||
unidic==1.1.0
|
||||
unidic-lite==1.0.8
|
||||
uritemplate==4.1.1
|
||||
urllib3==1.26.19
|
||||
wasabi==0.9.0
|
||||
wcwidth==0.2.5
|
||||
websocket-client==1.3.1
|
||||
Werkzeug==3.0.6
|
||||
wrapt==1.14.0
|
||||
xxhash==3.0.0
|
||||
yarl==1.7.2
|
||||
zipp==3.19.1
|
||||
@@ -1,173 +0,0 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from mujoco_py import GlfwContext
|
||||
|
||||
from transformers import DecisionTransformerModel
|
||||
|
||||
|
||||
GlfwContext(offscreen=True) # Create a window to init GLFW.
|
||||
|
||||
|
||||
def get_action(model, states, actions, rewards, returns_to_go, timesteps):
|
||||
# we don't care about the past rewards in this model
|
||||
|
||||
states = states.reshape(1, -1, model.config.state_dim)
|
||||
actions = actions.reshape(1, -1, model.config.act_dim)
|
||||
returns_to_go = returns_to_go.reshape(1, -1, 1)
|
||||
timesteps = timesteps.reshape(1, -1)
|
||||
|
||||
if model.config.max_length is not None:
|
||||
states = states[:, -model.config.max_length :]
|
||||
actions = actions[:, -model.config.max_length :]
|
||||
returns_to_go = returns_to_go[:, -model.config.max_length :]
|
||||
timesteps = timesteps[:, -model.config.max_length :]
|
||||
|
||||
# pad all tokens to sequence length
|
||||
attention_mask = torch.cat(
|
||||
[torch.zeros(model.config.max_length - states.shape[1]), torch.ones(states.shape[1])]
|
||||
)
|
||||
attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
|
||||
states = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(states.shape[0], model.config.max_length - states.shape[1], model.config.state_dim),
|
||||
device=states.device,
|
||||
),
|
||||
states,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.float32)
|
||||
actions = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(actions.shape[0], model.config.max_length - actions.shape[1], model.config.act_dim),
|
||||
device=actions.device,
|
||||
),
|
||||
actions,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.float32)
|
||||
returns_to_go = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(returns_to_go.shape[0], model.config.max_length - returns_to_go.shape[1], 1),
|
||||
device=returns_to_go.device,
|
||||
),
|
||||
returns_to_go,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.float32)
|
||||
timesteps = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(timesteps.shape[0], model.config.max_length - timesteps.shape[1]), device=timesteps.device
|
||||
),
|
||||
timesteps,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.long)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
_, action_preds, _ = model(
|
||||
states=states,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
returns_to_go=returns_to_go,
|
||||
timesteps=timesteps,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
return action_preds[0, -1]
|
||||
|
||||
|
||||
# build the environment
|
||||
|
||||
env = gym.make("Hopper-v3")
|
||||
state_dim = env.observation_space.shape[0]
|
||||
act_dim = env.action_space.shape[0]
|
||||
max_ep_len = 1000
|
||||
device = "cuda"
|
||||
scale = 1000.0 # normalization for rewards/returns
|
||||
TARGET_RETURN = 3600 / scale # evaluation conditioning targets, 3600 is reasonable from the paper LINK
|
||||
state_mean = np.array(
|
||||
[
|
||||
1.311279,
|
||||
-0.08469521,
|
||||
-0.5382719,
|
||||
-0.07201576,
|
||||
0.04932366,
|
||||
2.1066856,
|
||||
-0.15017354,
|
||||
0.00878345,
|
||||
-0.2848186,
|
||||
-0.18540096,
|
||||
-0.28461286,
|
||||
]
|
||||
)
|
||||
state_std = np.array(
|
||||
[
|
||||
0.17790751,
|
||||
0.05444621,
|
||||
0.21297139,
|
||||
0.14530419,
|
||||
0.6124444,
|
||||
0.85174465,
|
||||
1.4515252,
|
||||
0.6751696,
|
||||
1.536239,
|
||||
1.6160746,
|
||||
5.6072536,
|
||||
]
|
||||
)
|
||||
state_mean = torch.from_numpy(state_mean).to(device=device)
|
||||
state_std = torch.from_numpy(state_std).to(device=device)
|
||||
|
||||
# Create the decision transformer model
|
||||
model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
for ep in range(10):
|
||||
episode_return, episode_length = 0, 0
|
||||
state = env.reset()
|
||||
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
|
||||
states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
|
||||
actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
|
||||
rewards = torch.zeros(0, device=device, dtype=torch.float32)
|
||||
|
||||
timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
|
||||
for t in range(max_ep_len):
|
||||
env.render()
|
||||
# add padding
|
||||
actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
|
||||
rewards = torch.cat([rewards, torch.zeros(1, device=device)])
|
||||
|
||||
action = get_action(
|
||||
model,
|
||||
(states.to(dtype=torch.float32) - state_mean) / state_std,
|
||||
actions.to(dtype=torch.float32),
|
||||
rewards.to(dtype=torch.float32),
|
||||
target_return.to(dtype=torch.float32),
|
||||
timesteps.to(dtype=torch.long),
|
||||
)
|
||||
actions[-1] = action
|
||||
action = action.detach().cpu().numpy()
|
||||
|
||||
state, reward, done, _ = env.step(action)
|
||||
|
||||
cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
|
||||
states = torch.cat([states, cur_state], dim=0)
|
||||
rewards[-1] = reward
|
||||
|
||||
pred_return = target_return[0, -1] - (reward / scale)
|
||||
target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
|
||||
timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)
|
||||
|
||||
episode_return += reward
|
||||
episode_length += 1
|
||||
|
||||
if done:
|
||||
break
|
||||
@@ -1,54 +0,0 @@
|
||||
# DeeBERT: Early Exiting for *BERT
|
||||
|
||||
This is the code base for the paper [DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference](https://www.aclweb.org/anthology/2020.acl-main.204/), modified from its [original code base](https://github.com/castorini/deebert).
|
||||
|
||||
The original code base also has information for downloading sample models that we have trained in advance.
|
||||
|
||||
## Usage
|
||||
|
||||
There are three scripts in the folder which can be run directly.
|
||||
|
||||
In each script, there are several things to modify before running:
|
||||
|
||||
* `PATH_TO_DATA`: path to the GLUE dataset.
|
||||
* `--output_dir`: path for saving fine-tuned models. Default: `./saved_models`.
|
||||
* `--plot_data_dir`: path for saving evaluation results. Default: `./results`. Results are printed to stdout and also saved to `npy` files in this directory to facilitate plotting figures and further analyses.
|
||||
* `MODEL_TYPE`: bert or roberta
|
||||
* `MODEL_SIZE`: base or large
|
||||
* `DATASET`: SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
#### train_deebert.sh
|
||||
|
||||
This is for fine-tuning DeeBERT models.
|
||||
|
||||
#### eval_deebert.sh
|
||||
|
||||
This is for evaluating each exit layer for fine-tuned DeeBERT models.
|
||||
|
||||
#### entropy_eval.sh
|
||||
|
||||
This is for evaluating fine-tuned DeeBERT models, given a number of different early exit entropy thresholds.
|
||||
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
Please cite our paper if you find the resource useful:
|
||||
```bibtex
|
||||
@inproceedings{xin-etal-2020-deebert,
|
||||
title = "{D}ee{BERT}: Dynamic Early Exiting for Accelerating {BERT} Inference",
|
||||
author = "Xin, Ji and
|
||||
Tang, Raphael and
|
||||
Lee, Jaejun and
|
||||
Yu, Yaoliang and
|
||||
Lin, Jimmy",
|
||||
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
|
||||
month = jul,
|
||||
year = "2020",
|
||||
address = "Online",
|
||||
publisher = "Association for Computational Linguistics",
|
||||
url = "https://www.aclweb.org/anthology/2020.acl-main.204",
|
||||
pages = "2246--2251",
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/bin/bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
PATH_TO_DATA=/h/xinji/projects/GLUE
|
||||
|
||||
MODEL_TYPE=bert # bert or roberta
|
||||
MODEL_SIZE=base # base or large
|
||||
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
|
||||
if [ $MODEL_TYPE = 'bert' ]
|
||||
then
|
||||
MODEL_NAME=${MODEL_NAME}-uncased
|
||||
fi
|
||||
|
||||
ENTROPIES="0 0.1 0.2 0.3 0.4 0.5 0.6 0.7"
|
||||
|
||||
for ENTROPY in $ENTROPIES; do
|
||||
python -u run_glue_deebert.py \
|
||||
--model_type $MODEL_TYPE \
|
||||
--model_name_or_path ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--task_name $DATASET \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $PATH_TO_DATA/$DATASET \
|
||||
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--plot_data_dir ./results/ \
|
||||
--max_seq_length 128 \
|
||||
--early_exit_entropy $ENTROPY \
|
||||
--eval_highway \
|
||||
--overwrite_cache \
|
||||
--per_gpu_eval_batch_size=1
|
||||
done
|
||||
@@ -1,30 +0,0 @@
|
||||
#!/bin/bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
PATH_TO_DATA=/h/xinji/projects/GLUE
|
||||
|
||||
MODEL_TYPE=bert # bert or roberta
|
||||
MODEL_SIZE=base # base or large
|
||||
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
|
||||
if [ $MODEL_TYPE = 'bert' ]
|
||||
then
|
||||
MODEL_NAME=${MODEL_NAME}-uncased
|
||||
fi
|
||||
|
||||
|
||||
python -u run_glue_deebert.py \
|
||||
--model_type $MODEL_TYPE \
|
||||
--model_name_or_path ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--task_name $DATASET \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $PATH_TO_DATA/$DATASET \
|
||||
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--plot_data_dir ./results/ \
|
||||
--max_seq_length 128 \
|
||||
--eval_each_highway \
|
||||
--eval_highway \
|
||||
--overwrite_cache \
|
||||
--per_gpu_eval_batch_size=1
|
||||
@@ -1 +0,0 @@
|
||||
transformers == 4.38.0
|
||||
@@ -1,735 +0,0 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import transformers
|
||||
from src.modeling_highway_bert import DeeBertForSequenceClassification
|
||||
from src.modeling_highway_roberta import DeeRobertaForSequenceClassification
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import glue_compute_metrics as compute_metrics
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
from transformers import glue_output_modes as output_modes
|
||||
from transformers import glue_processors as processors
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, DeeBertForSequenceClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, DeeRobertaForSequenceClassification, RobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def get_wanted_result(result):
|
||||
if "spearmanr" in result:
|
||||
print_result = result["spearmanr"]
|
||||
elif "f1" in result:
|
||||
print_result = result["f1"]
|
||||
elif "mcc" in result:
|
||||
print_result = result["mcc"]
|
||||
elif "acc" in result:
|
||||
print_result = result["acc"]
|
||||
else:
|
||||
raise ValueError("Primary metric unclear in the results")
|
||||
return print_result
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, train_highway=False):
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
if train_highway:
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if ("highway" in n) and (not any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in model.named_parameters() if ("highway" in n) and (any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
else:
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if ("highway" not in n) and (not any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if ("highway" not in n) and (any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
inputs["train_highway"] = train_highway
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix="", output_layer=-1, eval_highway=False):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1:
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
exit_layer_counter = {(i + 1): 0 for i in range(model.num_layers)}
|
||||
st = time.time()
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
if output_layer >= 0:
|
||||
inputs["output_layer"] = output_layer
|
||||
outputs = model(**inputs)
|
||||
if eval_highway:
|
||||
exit_layer_counter[outputs[-1]] += 1
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
eval_time = time.time() - st
|
||||
logger.info("Eval time: {}".format(eval_time))
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif args.output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
|
||||
if eval_highway:
|
||||
logger.info("Exit layer counter: {}".format(exit_layer_counter))
|
||||
actual_cost = sum([l * c for l, c in exit_layer_counter.items()])
|
||||
full_cost = len(eval_dataloader) * model.num_layers
|
||||
logger.info("Expected saving: {}".format(actual_cost / full_cost))
|
||||
if args.early_exit_entropy >= 0:
|
||||
save_fname = (
|
||||
args.plot_data_dir
|
||||
+ "/"
|
||||
+ args.model_name_or_path[2:]
|
||||
+ "/entropy_{}.npy".format(args.early_exit_entropy)
|
||||
)
|
||||
if not os.path.exists(os.path.dirname(save_fname)):
|
||||
os.makedirs(os.path.dirname(save_fname))
|
||||
print_result = get_wanted_result(result)
|
||||
np.save(save_fname, np.array([exit_layer_counter, eval_time, actual_cost / full_cost, print_result]))
|
||||
logger.info("Entropy={}\tResult={:.2f}".format(args.early_exit_entropy, 100 * print_result))
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
processor = processors[task]()
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta"]:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = (
|
||||
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
label_list=label_list,
|
||||
max_length=args.max_seq_length,
|
||||
output_mode=output_mode,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
|
||||
if features[0].token_type_ids is None:
|
||||
# For RoBERTa (a potential bug!)
|
||||
all_token_type_ids = torch.tensor([[0] * args.max_seq_length for f in features], dtype=torch.long)
|
||||
else:
|
||||
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
if output_mode == "classification":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot_data_dir",
|
||||
default="./plotting/",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The directory to store data for plotting figures.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
parser.add_argument("--eval_each_highway", action="store_true", help="Set this flag to evaluate each highway.")
|
||||
parser.add_argument(
|
||||
"--eval_after_first_stage",
|
||||
action="store_true",
|
||||
help="Set this flag to evaluate after training only bert (not highway).",
|
||||
)
|
||||
parser.add_argument("--eval_highway", action="store_true", help="Set this flag if it's evaluating highway models")
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--early_exit_entropy", default=-1, type=float, help="Entropy threshold for early exit.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help=(
|
||||
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
|
||||
"See details at https://nvidia.github.io/apex/amp.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name]()
|
||||
args.output_mode = output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.model_type == "bert":
|
||||
model.bert.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
model.bert.init_highway_pooler()
|
||||
elif args.model_type == "roberta":
|
||||
model.roberta.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
model.roberta.init_highway_pooler()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
if args.eval_after_first_stage:
|
||||
result = evaluate(args, model, tokenizer, prefix="")
|
||||
print_result = get_wanted_result(result)
|
||||
|
||||
train(args, train_dataset, model, tokenizer, train_highway=True)
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = [
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
]
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
if args.model_type == "bert":
|
||||
model.bert.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
elif args.model_type == "roberta":
|
||||
model.roberta.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix, eval_highway=args.eval_highway)
|
||||
print_result = get_wanted_result(result)
|
||||
logger.info("Result: {}".format(print_result))
|
||||
if args.eval_each_highway:
|
||||
last_layer_results = print_result
|
||||
each_layer_results = []
|
||||
for i in range(model.num_layers):
|
||||
logger.info("\n")
|
||||
_result = evaluate(
|
||||
args, model, tokenizer, prefix=prefix, output_layer=i, eval_highway=args.eval_highway
|
||||
)
|
||||
if i + 1 < model.num_layers:
|
||||
each_layer_results.append(get_wanted_result(_result))
|
||||
each_layer_results.append(last_layer_results)
|
||||
save_fname = args.plot_data_dir + "/" + args.model_name_or_path[2:] + "/each_layer.npy"
|
||||
if not os.path.exists(os.path.dirname(save_fname)):
|
||||
os.makedirs(os.path.dirname(save_fname))
|
||||
np.save(save_fname, np.array(each_layer_results))
|
||||
info_str = "Score of each layer:"
|
||||
for i in range(model.num_layers):
|
||||
info_str += " {:.2f}".format(100 * each_layer_results[i])
|
||||
logger.info(info_str)
|
||||
result = {k + "_{}".format(global_step): v for k, v in result.items()}
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,397 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
BERT_START_DOCSTRING,
|
||||
BertEmbeddings,
|
||||
BertLayer,
|
||||
BertPooler,
|
||||
BertPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
def entropy(x):
|
||||
"""Calculate entropy of a pre-softmax logit Tensor"""
|
||||
exp_x = torch.exp(x)
|
||||
A = torch.sum(exp_x, dim=1) # sum of exp(x_i)
|
||||
B = torch.sum(x * exp_x, dim=1) # sum of x_i * exp(x_i)
|
||||
return torch.log(A) - B / A
|
||||
|
||||
|
||||
class DeeBertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.highway = nn.ModuleList([BertHighway(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
self.early_exit_entropy = [-1 for _ in range(config.num_hidden_layers)]
|
||||
|
||||
def set_early_exit_entropy(self, x):
|
||||
if isinstance(x, (float, int)):
|
||||
for i in range(len(self.early_exit_entropy)):
|
||||
self.early_exit_entropy[i] = x
|
||||
else:
|
||||
self.early_exit_entropy = x
|
||||
|
||||
def init_highway_pooler(self, pooler):
|
||||
loaded_model = pooler.state_dict()
|
||||
for highway in self.highway:
|
||||
for name, param in highway.pooler.state_dict().items():
|
||||
param.copy_(loaded_model[name])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
all_highway_exits = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
current_outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
current_outputs = current_outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
current_outputs = current_outputs + (all_attentions,)
|
||||
|
||||
highway_exit = self.highway[i](current_outputs)
|
||||
# logits, pooled_output
|
||||
|
||||
if not self.training:
|
||||
highway_logits = highway_exit[0]
|
||||
highway_entropy = entropy(highway_logits)
|
||||
highway_exit = highway_exit + (highway_entropy,) # logits, hidden_states(?), entropy
|
||||
all_highway_exits = all_highway_exits + (highway_exit,)
|
||||
|
||||
if highway_entropy < self.early_exit_entropy[i]:
|
||||
new_output = (highway_logits,) + current_outputs[1:] + (all_highway_exits,)
|
||||
raise HighwayException(new_output, i + 1)
|
||||
else:
|
||||
all_highway_exits = all_highway_exits + (highway_exit,)
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
|
||||
outputs = outputs + (all_highway_exits,)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions), all highway exits
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The Bert Model transformer with early exiting (DeeBERT). ",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class DeeBertModel(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.encoder = DeeBertEncoder(config)
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_highway_pooler(self):
|
||||
self.encoder.init_highway_pooler(self.pooler)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during pre-training.
|
||||
|
||||
This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
|
||||
Tuple of each early exit's results (total length: number of layers)
|
||||
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
|
||||
"""
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(input_shape, device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_attention_mask.dim() == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
if encoder_attention_mask.dim() == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (
|
||||
sequence_output,
|
||||
pooled_output,
|
||||
) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions), highway exits
|
||||
|
||||
|
||||
class HighwayException(Exception):
|
||||
def __init__(self, message, exit_layer):
|
||||
self.message = message
|
||||
self.exit_layer = exit_layer # start from 1!
|
||||
|
||||
|
||||
class BertHighway(nn.Module):
|
||||
"""A module to provide a shortcut
|
||||
from (the output of one non-final BertLayer in BertEncoder) to (cross-entropy computation in BertForSequenceClassification)
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.pooler = BertPooler(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
def forward(self, encoder_outputs):
|
||||
# Pooler
|
||||
pooler_input = encoder_outputs[0]
|
||||
pooler_output = self.pooler(pooler_input)
|
||||
# "return" pooler_output
|
||||
|
||||
# BertModel
|
||||
bmodel_output = (pooler_input, pooler_output) + encoder_outputs[1:]
|
||||
# "return" bmodel_output
|
||||
|
||||
# Dropout and classification
|
||||
pooled_output = bmodel_output[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
return logits, pooled_output
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bert Model (with early exiting - DeeBERT) with a classifier on top,
|
||||
also takes care of multi-layer training. """,
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class DeeBertForSequenceClassification(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.bert = DeeBertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_layer=-1,
|
||||
train_highway=False,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
|
||||
Tuple of each early exit's results (total length: number of layers)
|
||||
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
|
||||
"""
|
||||
|
||||
exit_layer = self.num_layers
|
||||
try:
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
# sequence_output, pooled_output, (hidden_states), (attentions), highway exits
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
except HighwayException as e:
|
||||
outputs = e.message
|
||||
exit_layer = e.exit_layer
|
||||
logits = outputs[0]
|
||||
|
||||
if not self.training:
|
||||
original_entropy = entropy(logits)
|
||||
highway_entropy = []
|
||||
highway_logits_all = []
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
# work with highway exits
|
||||
highway_losses = []
|
||||
for highway_exit in outputs[-1]:
|
||||
highway_logits = highway_exit[0]
|
||||
if not self.training:
|
||||
highway_logits_all.append(highway_logits)
|
||||
highway_entropy.append(highway_exit[2])
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
highway_losses.append(highway_loss)
|
||||
|
||||
if train_highway:
|
||||
outputs = (sum(highway_losses[:-1]),) + outputs
|
||||
# exclude the final highway, of course
|
||||
else:
|
||||
outputs = (loss,) + outputs
|
||||
if not self.training:
|
||||
outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
|
||||
if output_layer >= 0:
|
||||
outputs = (
|
||||
(outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
|
||||
) # use the highway of the last layer
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions), (highway_exits)
|
||||
@@ -1,154 +0,0 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers import RobertaConfig
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from transformers.models.roberta.modeling_roberta import (
|
||||
ROBERTA_INPUTS_DOCSTRING,
|
||||
ROBERTA_START_DOCSTRING,
|
||||
RobertaEmbeddings,
|
||||
)
|
||||
|
||||
from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayException, entropy
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The RoBERTa Model transformer with early exiting (DeeRoBERTa). ",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaModel(DeeBertModel):
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.embeddings = RobertaEmbeddings(config)
|
||||
self.init_weights()
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""RoBERTa Model (with early exiting - DeeRoBERTa) with a classifier on top,
|
||||
also takes care of multi-layer training. """,
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaForSequenceClassification(BertPreTrainedModel):
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.roberta = DeeRobertaModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_layer=-1,
|
||||
train_highway=False,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
|
||||
Tuple of each early exit's results (total length: number of layers)
|
||||
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
|
||||
"""
|
||||
|
||||
exit_layer = self.num_layers
|
||||
try:
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
except HighwayException as e:
|
||||
outputs = e.message
|
||||
exit_layer = e.exit_layer
|
||||
logits = outputs[0]
|
||||
|
||||
if not self.training:
|
||||
original_entropy = entropy(logits)
|
||||
highway_entropy = []
|
||||
highway_logits_all = []
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
# work with highway exits
|
||||
highway_losses = []
|
||||
for highway_exit in outputs[-1]:
|
||||
highway_logits = highway_exit[0]
|
||||
if not self.training:
|
||||
highway_logits_all.append(highway_logits)
|
||||
highway_entropy.append(highway_exit[2])
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
highway_losses.append(highway_loss)
|
||||
|
||||
if train_highway:
|
||||
outputs = (sum(highway_losses[:-1]),) + outputs
|
||||
# exclude the final highway, of course
|
||||
else:
|
||||
outputs = (loss,) + outputs
|
||||
if not self.training:
|
||||
outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
|
||||
if output_layer >= 0:
|
||||
outputs = (
|
||||
(outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
|
||||
) # use the highway of the last layer
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions), entropy
|
||||
@@ -1,104 +0,0 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_glue_deebert
|
||||
|
||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, require_torch_non_multi_gpu, slow
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_setup_file():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f")
|
||||
args = parser.parse_args()
|
||||
return args.f
|
||||
|
||||
|
||||
class DeeBertTests(TestCasePlus):
|
||||
def setup(self) -> None:
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
def run_and_check(self, args):
|
||||
n_gpu = get_gpu_count()
|
||||
|
||||
if n_gpu > 1:
|
||||
pass
|
||||
# XXX: doesn't quite work with n_gpu > 1 https://github.com/huggingface/transformers/issues/10560
|
||||
# script = f"{self.examples_dir_str}/research_projects/deebert/run_glue_deebert.py"
|
||||
# distributed_args = f"-m torch.distributed.launch --nproc_per_node={n_gpu} {script}".split()
|
||||
# cmd = [sys.executable] + distributed_args + args
|
||||
# execute_subprocess_async(cmd, env=self.get_env())
|
||||
# XXX: test the results - need to save them first into .json file
|
||||
else:
|
||||
args.insert(0, "run_glue_deebert.py")
|
||||
with patch.object(sys, "argv", args):
|
||||
result = run_glue_deebert.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.666)
|
||||
|
||||
@slow
|
||||
@require_torch_non_multi_gpu
|
||||
def test_glue_deebert_train(self):
|
||||
train_args = """
|
||||
--model_type roberta
|
||||
--model_name_or_path FacebookAI/roberta-base
|
||||
--task_name MRPC
|
||||
--do_train
|
||||
--do_eval
|
||||
--do_lower_case
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--max_seq_length 128
|
||||
--per_gpu_eval_batch_size=1
|
||||
--per_gpu_train_batch_size=8
|
||||
--learning_rate 2e-4
|
||||
--num_train_epochs 3
|
||||
--overwrite_output_dir
|
||||
--seed 42
|
||||
--output_dir ./examples/deebert/saved_models/FacebookAI/roberta-base/MRPC/two_stage
|
||||
--plot_data_dir ./examples/deebert/results/
|
||||
--save_steps 0
|
||||
--overwrite_cache
|
||||
--eval_after_first_stage
|
||||
""".split()
|
||||
self.run_and_check(train_args)
|
||||
|
||||
eval_args = """
|
||||
--model_type roberta
|
||||
--model_name_or_path ./examples/deebert/saved_models/FacebookAI/roberta-base/MRPC/two_stage
|
||||
--task_name MRPC
|
||||
--do_eval
|
||||
--do_lower_case
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--output_dir ./examples/deebert/saved_models/FacebookAI/roberta-base/MRPC/two_stage
|
||||
--plot_data_dir ./examples/deebert/results/
|
||||
--max_seq_length 128
|
||||
--eval_each_highway
|
||||
--eval_highway
|
||||
--overwrite_cache
|
||||
--per_gpu_eval_batch_size=1
|
||||
""".split()
|
||||
self.run_and_check(eval_args)
|
||||
|
||||
entropy_eval_args = """
|
||||
--model_type roberta
|
||||
--model_name_or_path ./examples/deebert/saved_models/FacebookAI/roberta-base/MRPC/two_stage
|
||||
--task_name MRPC
|
||||
--do_eval
|
||||
--do_lower_case
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--output_dir ./examples/deebert/saved_models/FacebookAI/roberta-base/MRPC/two_stage
|
||||
--plot_data_dir ./examples/deebert/results/
|
||||
--max_seq_length 128
|
||||
--early_exit_entropy 0.1
|
||||
--eval_highway
|
||||
--overwrite_cache
|
||||
--per_gpu_eval_batch_size=1
|
||||
""".split()
|
||||
self.run_and_check(entropy_eval_args)
|
||||
@@ -1,38 +0,0 @@
|
||||
#!/bin/bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
PATH_TO_DATA=/h/xinji/projects/GLUE
|
||||
|
||||
MODEL_TYPE=bert # bert or roberta
|
||||
MODEL_SIZE=base # base or large
|
||||
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
|
||||
EPOCHS=10
|
||||
if [ $MODEL_TYPE = 'bert' ]
|
||||
then
|
||||
EPOCHS=3
|
||||
MODEL_NAME=${MODEL_NAME}-uncased
|
||||
fi
|
||||
|
||||
|
||||
python -u run_glue_deebert.py \
|
||||
--model_type $MODEL_TYPE \
|
||||
--model_name_or_path $MODEL_NAME \
|
||||
--task_name $DATASET \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $PATH_TO_DATA/$DATASET \
|
||||
--max_seq_length 128 \
|
||||
--per_gpu_eval_batch_size=1 \
|
||||
--per_gpu_train_batch_size=8 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs $EPOCHS \
|
||||
--overwrite_output_dir \
|
||||
--seed 42 \
|
||||
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--plot_data_dir ./results/ \
|
||||
--save_steps 0 \
|
||||
--overwrite_cache \
|
||||
--eval_after_first_stage
|
||||
@@ -1,193 +0,0 @@
|
||||
# Distil*
|
||||
|
||||
Author: @VictorSanh
|
||||
|
||||
This folder contains the original code used to train Distil* as well as examples showcasing how to use DistilBERT, DistilRoBERTa and DistilGPT2.
|
||||
|
||||
**January 20, 2020 - Bug fixing** We have recently discovered and fixed [a bug](https://github.com/huggingface/transformers/commit/48cbf267c988b56c71a2380f748a3e6092ccaed3) in the evaluation of our `run_*.py` scripts that caused the reported metrics to be over-estimated on average. We have updated all the metrics with the latest runs.
|
||||
|
||||
**December 6, 2019 - Update** We release **DistilmBERT**: 92% of `bert-base-multilingual-cased` on XNLI. The model supports 104 different languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
|
||||
|
||||
**November 19, 2019 - Update** We release German **DistilBERT**: 98.8% of `bert-base-german-dbmdz-cased` on NER tasks.
|
||||
|
||||
**October 23, 2019 - Update** We release **DistilRoBERTa**: 95% of `RoBERTa-base`'s performance on GLUE, twice as fast as RoBERTa while being 35% smaller.
|
||||
|
||||
**October 3, 2019 - Update** We release our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108) explaining our approach on **DistilBERT**. It includes updated results and further experiments. We applied the same method to GPT2 and release the weights of **DistilGPT2**. DistilGPT2 is two times faster and 33% smaller than GPT2. **The paper supersedes our [previous blogpost](https://medium.com/huggingface/distilbert-8cf3380435b5) with a different distillation loss and better performances. Please use the paper as a reference when comparing/reporting results on DistilBERT.**
|
||||
|
||||
**September 19, 2019 - Update:** We fixed bugs in the code and released an updated version of the weights trained with a modification of the distillation loss. DistilBERT now reaches 99% of `BERT-base`'s performance on GLUE, and 86.9 F1 score on SQuAD v1.1 dev set (compared to 88.5 for `BERT-base`). We will publish a formal write-up of our approach in the near future!
|
||||
|
||||
|
||||
## What is Distil*
|
||||
|
||||
Distil* is a class of compressed models that started with DistilBERT. DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on Bert architecture. It has 40% less parameters than `bert-base-uncased`, runs 60% faster while preserving 97% of BERT's performances as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique to compress a large model called the teacher into a smaller model called the student. By distillating Bert, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option to put large-scaled trained Transformer model into production.
|
||||
|
||||
We have applied the same method to other Transformer architectures and released the weights:
|
||||
- GPT2: on the [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) benchmark, GPT2 reaches a perplexity on the test set of 16.3 compared to 21.1 for **DistilGPT2** (after fine-tuning on the train set).
|
||||
- RoBERTa: **DistilRoBERTa** reaches 95% of `RoBERTa-base`'s performance on GLUE while being twice faster and 35% smaller.
|
||||
- German BERT: **German DistilBERT** reaches 99% of `bert-base-german-dbmdz-cased`'s performance on German NER (CoNLL-2003).
|
||||
- Multilingual BERT: **DistilmBERT** reaches 92% of Multilingual BERT's performance on XNLI while being twice faster and 25% smaller. The model supports 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
|
||||
|
||||
For more information on DistilBERT, please refer to our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108).
|
||||
|
||||
Here are the results on the dev sets of GLUE:
|
||||
|
||||
| Model | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2| STS-B| WNLI |
|
||||
| :---: | :---: | :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---: |
|
||||
| BERT-base-uncased | **79.5** | 56.3 | 84.7 | 88.6 | 91.8 | 89.6 | 69.3 | 92.7 | 89.0 | 53.5 |
|
||||
| DistilBERT-base-uncased | **77.0** | 51.3 | 82.1 | 87.5 | 89.2 | 88.5 | 59.9 | 91.3 | 86.9 | 56.3 |
|
||||
| BERT-base-cased | **78.2** | 58.2 | 83.9 | 87.8 | 91.0 | 89.2 | 66.1 | 91.7 | 89.2 | 46.5 |
|
||||
| DistilBERT-base-cased | **75.9** | 47.2 | 81.5 | 85.6 | 88.2 | 87.8 | 60.6 | 90.4 | 85.5 | 56.3 |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| RoBERTa-base (reported) | **83.2**/**86.4**<sup>2</sup> | 63.6 | 87.6 | 90.2 | 92.8 | 91.9 | 78.7 | 94.8 | 91.2 | 57.7<sup>3</sup> |
|
||||
| DistilRoBERTa<sup>1</sup> | **79.0**/**82.3**<sup>2</sup> | 59.3 | 84.0 | 86.6 | 90.8 | 89.4 | 67.9 | 92.5 | 88.3 | 52.1 |
|
||||
|
||||
<sup>1</sup> We did not use the MNLI checkpoint for fine-tuning but directly perform transfer learning on the pre-trained DistilRoBERTa.
|
||||
|
||||
<sup>2</sup> Macro-score computed without WNLI.
|
||||
|
||||
<sup>3</sup> We compute this score ourselves for completeness.
|
||||
|
||||
Here are the results on the *test* sets for 6 of the languages available in XNLI. The results are computed in the zero shot setting (trained on the English portion and evaluated on the target language portion):
|
||||
|
||||
| Model | English | Spanish | Chinese | German | Arabic | Urdu |
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
|
||||
| mBERT base cased (computed) | 82.1 | 74.6 | 69.1 | 72.3 | 66.4 | 58.5 |
|
||||
| mBERT base uncased (reported)| 81.4 | 74.3 | 63.8 | 70.5 | 62.1 | 58.3 |
|
||||
| DistilmBERT | 78.2 | 69.1 | 64.0 | 66.3 | 59.1 | 54.7 |
|
||||
|
||||
## Setup
|
||||
|
||||
This part of the library has only be tested with Python3.6+. There are few specific dependencies to install before launching a distillation, you can install them with the command `pip install -r requirements.txt`.
|
||||
|
||||
**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0).
|
||||
|
||||
|
||||
## How to use DistilBERT
|
||||
|
||||
Transformers includes five pre-trained Distil* models, currently only provided for English and German (we are investigating the possibility to train and release a multilingual version of DistilBERT):
|
||||
|
||||
- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain Bert (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of Bert. The model has 6 layers, 768 dimension and 12 heads, totalizing 66M parameters.
|
||||
- `distilbert-base-uncased-distilled-squad`: A finetuned version of `distilbert-base-uncased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches a F1 score of 86.9 on the dev set (for comparison, Bert `bert-base-uncased` version reaches a 88.5 F1 score).
|
||||
- `distilbert-base-cased`: DistilBERT English language model pretrained on the same data used to pretrain Bert (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-cased` version of Bert. The model has 6 layers, 768 dimension and 12 heads, totalizing 65M parameters.
|
||||
- `distilbert-base-cased-distilled-squad`: A finetuned version of `distilbert-base-cased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches a F1 score of 87.1 on the dev set (for comparison, Bert `bert-base-cased` version reaches a 88.7 F1 score).
|
||||
- `distilbert-base-german-cased`: DistilBERT German language model pretrained on 1/2 of the data used to pretrain Bert using distillation with the supervision of the `bert-base-german-dbmdz-cased` version of German DBMDZ Bert. For NER tasks the model reaches a F1 score of 83.49 on the CoNLL-2003 test set (for comparison, `bert-base-german-dbmdz-cased` reaches a 84.52 F1 score), and a F1 score of 85.23 on the GermEval 2014 test set (`bert-base-german-dbmdz-cased` reaches a 86.89 F1 score).
|
||||
- `distilgpt2`: DistilGPT2 English language model pretrained with the supervision of `gpt2` (the smallest version of GPT2) on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset. The model has 6 layers, 768 dimension and 12 heads, totalizing 82M parameters (compared to 124M parameters for GPT2). On average, DistilGPT2 is two times faster than GPT2.
|
||||
- `distilroberta-base`: DistilRoBERTa English language model pretrained with the supervision of `roberta-base` solely on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset (it is ~4 times less training data than the teacher RoBERTa). The model has 6 layers, 768 dimension and 12 heads, totalizing 82M parameters (compared to 125M parameters for RoBERTa-base). On average DistilRoBERTa is twice as fast as Roberta-base.
|
||||
- `distilbert-base-multilingual-cased`: DistilmBERT multilingual model pretrained with the supervision of `bert-base-multilingual-cased` on the concatenation of Wikipedia in 104 different languages. The model supports the 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages). The model has 6 layers, 768 dimension and 12 heads, totalizing 134M parameters (compared to 177M parameters for mBERT-base). On average DistilmBERT is twice as fast as mBERT-base.
|
||||
|
||||
Using DistilBERT is very similar to using BERT. DistilBERT share the same tokenizer as BERT's `bert-base-uncased` even though we provide a link to this tokenizer under the `DistilBertTokenizer` name to have a consistent naming between the library models.
|
||||
|
||||
```python
|
||||
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
||||
model = DistilBertModel.from_pretrained('distilbert-base-cased')
|
||||
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
```
|
||||
|
||||
Similarly, using the other Distil* models simply consists in calling the base classes with a different pretrained checkpoint:
|
||||
- DistilBERT uncased: `model = DistilBertModel.from_pretrained('distilbert-base-uncased')`
|
||||
- DistilGPT2: `model = GPT2Model.from_pretrained('distilgpt2')`
|
||||
- DistilRoBERTa: `model = RobertaModel.from_pretrained('distilroberta-base')`
|
||||
- DistilmBERT: `model = DistilBertModel.from_pretrained('distilbert-base-multilingual-cased')`
|
||||
|
||||
|
||||
## How to train Distil*
|
||||
|
||||
In the following, we will explain how you can train DistilBERT.
|
||||
|
||||
### A. Preparing the data
|
||||
|
||||
The weights we release are trained using a concatenation of Toronto Book Corpus and English Wikipedia (same training data as the English version of BERT).
|
||||
|
||||
To avoid processing the data several time, we do it once and for all before the training. From now on, will suppose that you have a text file `dump.txt` which contains one sequence per line (a sequence being composed of one of several coherent sentences).
|
||||
|
||||
First, we will binarize the data, i.e. tokenize the data and convert each token in an index in our model's vocabulary.
|
||||
|
||||
```bash
|
||||
python scripts/binarized_data.py \
|
||||
--file_path data/dump.txt \
|
||||
--tokenizer_type bert \
|
||||
--tokenizer_name bert-base-uncased \
|
||||
--dump_file data/binarized_text
|
||||
```
|
||||
|
||||
Our implementation of masked language modeling loss follows [XLM](https://github.com/facebookresearch/XLM)'s one and smooths the probability of masking with a factor that put more emphasis on rare words. Thus we count the occurrences of each tokens in the data:
|
||||
|
||||
```bash
|
||||
python scripts/token_counts.py \
|
||||
--data_file data/binarized_text.bert-base-uncased.pickle \
|
||||
--token_counts_dump data/token_counts.bert-base-uncased.pickle \
|
||||
--vocab_size 30522
|
||||
```
|
||||
|
||||
### B. Training
|
||||
|
||||
Training with distillation is really simple once you have pre-processed the data:
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--student_type distilbert \
|
||||
--student_config training_configs/distilbert-base-uncased.json \
|
||||
--teacher_type bert \
|
||||
--teacher_name bert-base-uncased \
|
||||
--alpha_ce 5.0 --alpha_mlm 2.0 --alpha_cos 1.0 --alpha_clm 0.0 --mlm \
|
||||
--freeze_pos_embs \
|
||||
--dump_path serialization_dir/my_first_training \
|
||||
--data_file data/binarized_text.bert-base-uncased.pickle \
|
||||
--token_counts data/token_counts.bert-base-uncased.pickle \
|
||||
--force # overwrites the `dump_path` if it already exists.
|
||||
```
|
||||
|
||||
By default, this will launch a training on a single GPU (even if more are available on the cluster). Other parameters are available in the command line, please look in `train.py` or run `python train.py --help` to list them.
|
||||
|
||||
We highly encourage you to use distributed training for training DistilBERT as the training corpus is quite large. Here's an example that runs a distributed training on a single node having 4 GPUs:
|
||||
|
||||
```bash
|
||||
export NODE_RANK=0
|
||||
export N_NODES=1
|
||||
|
||||
export N_GPU_NODE=4
|
||||
export WORLD_SIZE=4
|
||||
export MASTER_PORT=<AN_OPEN_PORT>
|
||||
export MASTER_ADDR=<I.P.>
|
||||
|
||||
pkill -f 'python -u train.py'
|
||||
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node=$N_GPU_NODE \
|
||||
--nnodes=$N_NODES \
|
||||
--node_rank $NODE_RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
train.py \
|
||||
--force \
|
||||
--n_gpu $WORLD_SIZE \
|
||||
--student_type distilbert \
|
||||
--student_config training_configs/distilbert-base-uncased.json \
|
||||
--teacher_type bert \
|
||||
--teacher_name bert-base-uncased \
|
||||
--alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --mlm \
|
||||
--freeze_pos_embs \
|
||||
--dump_path serialization_dir/my_first_training \
|
||||
--data_file data/binarized_text.bert-base-uncased.pickle \
|
||||
--token_counts data/token_counts.bert-base-uncased.pickle
|
||||
```
|
||||
|
||||
**Tips:** Starting distilled training with good initialization of the model weights is crucial to reach decent performance. In our experiments, we initialized our model from a few layers of the teacher (Bert) itself! Please refer to `scripts/extract.py` and `scripts/extract_distilbert.py` to create a valid initialization checkpoint and use `--student_pretrained_weights` argument to use this initialization for the distilled training!
|
||||
|
||||
Happy distillation!
|
||||
|
||||
## Citation
|
||||
|
||||
If you find the resource useful, you should cite the following paper:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{sanh2019distilbert,
|
||||
title={DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter},
|
||||
author={Sanh, Victor and Debut, Lysandre and Chaumond, Julien and Wolf, Thomas},
|
||||
booktitle={NeurIPS EMC^2 Workshop},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
@@ -1,601 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The distiller to distil the student.
|
||||
Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
||||
from lm_seqs_dataset import LmSeqsDataset
|
||||
from torch import nn
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
from utils import logger
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
class Distiller:
|
||||
def __init__(
|
||||
self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
|
||||
):
|
||||
logger.info("Initializing Distiller")
|
||||
self.params = params
|
||||
self.dump_path = params.dump_path
|
||||
self.multi_gpu = params.multi_gpu
|
||||
self.fp16 = params.fp16
|
||||
|
||||
self.student = student
|
||||
self.teacher = teacher
|
||||
|
||||
self.student_config = student.config
|
||||
self.vocab_size = student.config.vocab_size
|
||||
|
||||
if params.n_gpu <= 1:
|
||||
sampler = RandomSampler(dataset)
|
||||
else:
|
||||
sampler = DistributedSampler(dataset)
|
||||
|
||||
if params.group_by_size:
|
||||
groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
|
||||
sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
|
||||
else:
|
||||
sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
|
||||
|
||||
self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)
|
||||
|
||||
self.temperature = params.temperature
|
||||
assert self.temperature > 0.0
|
||||
|
||||
self.alpha_ce = params.alpha_ce
|
||||
self.alpha_mlm = params.alpha_mlm
|
||||
self.alpha_clm = params.alpha_clm
|
||||
self.alpha_mse = params.alpha_mse
|
||||
self.alpha_cos = params.alpha_cos
|
||||
|
||||
self.mlm = params.mlm
|
||||
if self.mlm:
|
||||
logger.info("Using MLM loss for LM step.")
|
||||
self.mlm_mask_prop = params.mlm_mask_prop
|
||||
assert 0.0 <= self.mlm_mask_prop <= 1.0
|
||||
assert params.word_mask + params.word_keep + params.word_rand == 1.0
|
||||
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
|
||||
self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
|
||||
self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
|
||||
if self.fp16:
|
||||
self.pred_probs = self.pred_probs.half()
|
||||
self.token_probs = self.token_probs.half()
|
||||
else:
|
||||
logger.info("Using CLM loss for LM step.")
|
||||
|
||||
self.epoch = 0
|
||||
self.n_iter = 0
|
||||
self.n_total_iter = 0
|
||||
self.n_sequences_epoch = 0
|
||||
self.total_loss_epoch = 0
|
||||
self.last_loss = 0
|
||||
self.last_loss_ce = 0
|
||||
self.last_loss_mlm = 0
|
||||
self.last_loss_clm = 0
|
||||
if self.alpha_mse > 0.0:
|
||||
self.last_loss_mse = 0
|
||||
if self.alpha_cos > 0.0:
|
||||
self.last_loss_cos = 0
|
||||
self.last_log = 0
|
||||
|
||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
||||
if self.alpha_mse > 0.0:
|
||||
self.mse_loss_fct = nn.MSELoss(reduction="sum")
|
||||
if self.alpha_cos > 0.0:
|
||||
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
|
||||
|
||||
logger.info("--- Initializing model optimizer")
|
||||
assert params.gradient_accumulation_steps >= 1
|
||||
self.num_steps_epoch = len(self.dataloader)
|
||||
num_train_optimization_steps = (
|
||||
int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
|
||||
)
|
||||
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
|
||||
],
|
||||
"weight_decay": params.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
logger.info(
|
||||
"------ Number of trainable parameters (student): %i"
|
||||
% sum([p.numel() for p in self.student.parameters() if p.requires_grad])
|
||||
)
|
||||
logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
|
||||
self.optimizer = AdamW(
|
||||
optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
|
||||
)
|
||||
|
||||
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
|
||||
)
|
||||
|
||||
if self.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
|
||||
self.student, self.optimizer = amp.initialize(
|
||||
self.student, self.optimizer, opt_level=self.params.fp16_opt_level
|
||||
)
|
||||
self.teacher = self.teacher.half()
|
||||
|
||||
if self.multi_gpu:
|
||||
if self.fp16:
|
||||
from apex.parallel import DistributedDataParallel
|
||||
|
||||
logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
|
||||
self.student = DistributedDataParallel(self.student)
|
||||
else:
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
|
||||
self.student = DistributedDataParallel(
|
||||
self.student,
|
||||
device_ids=[params.local_rank],
|
||||
output_device=params.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
self.is_master = params.is_master
|
||||
if self.is_master:
|
||||
logger.info("--- Initializing Tensorboard")
|
||||
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
|
||||
self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
|
||||
self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)
|
||||
|
||||
def prepare_batch_mlm(self, batch):
|
||||
"""
|
||||
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.
|
||||
|
||||
Input:
|
||||
------
|
||||
batch: `Tuple`
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
|
||||
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
|
||||
|
||||
Output:
|
||||
-------
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||
mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
|
||||
"""
|
||||
token_ids, lengths = batch
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
assert token_ids.size(0) == lengths.size(0)
|
||||
|
||||
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
|
||||
|
||||
bs, max_seq_len = token_ids.size()
|
||||
mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
||||
|
||||
x_prob = self.token_probs[token_ids.flatten()]
|
||||
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
|
||||
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
|
||||
pred_mask = torch.zeros(
|
||||
bs * max_seq_len, dtype=torch.bool, device=token_ids.device
|
||||
) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
|
||||
pred_mask[tgt_ids] = 1
|
||||
pred_mask = pred_mask.view(bs, max_seq_len)
|
||||
|
||||
pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0
|
||||
|
||||
# mask a number of words == 0 [8] (faster with fp16)
|
||||
if self.fp16:
|
||||
n1 = pred_mask.sum().item()
|
||||
if n1 > 8:
|
||||
pred_mask = pred_mask.view(-1)
|
||||
n2 = max(n1 % 8, 8 * (n1 // 8))
|
||||
if n2 != n1:
|
||||
pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
|
||||
pred_mask = pred_mask.view(bs, max_seq_len)
|
||||
assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
|
||||
|
||||
_token_ids_real = token_ids[pred_mask]
|
||||
_token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
|
||||
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
|
||||
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
|
||||
_token_ids = (
|
||||
_token_ids_mask * (probs == 0).long()
|
||||
+ _token_ids_real * (probs == 1).long()
|
||||
+ _token_ids_rand * (probs == 2).long()
|
||||
)
|
||||
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
|
||||
|
||||
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
|
||||
# sanity checks
|
||||
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
|
||||
|
||||
return token_ids, attn_mask, mlm_labels
|
||||
|
||||
def prepare_batch_clm(self, batch):
|
||||
"""
|
||||
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
|
||||
|
||||
Input:
|
||||
------
|
||||
batch: `Tuple`
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
|
||||
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
|
||||
|
||||
Output:
|
||||
-------
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||
clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
|
||||
"""
|
||||
token_ids, lengths = batch
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
assert token_ids.size(0) == lengths.size(0)
|
||||
|
||||
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
|
||||
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
||||
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
|
||||
# sanity checks
|
||||
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
|
||||
|
||||
return token_ids, attn_mask, clm_labels
|
||||
|
||||
def round_batch(self, x: torch.tensor, lengths: torch.tensor):
|
||||
"""
|
||||
For float16 only.
|
||||
Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
|
||||
|
||||
Input:
|
||||
------
|
||||
x: `torch.tensor(bs, seq_length)` - The token ids.
|
||||
lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch.
|
||||
|
||||
Output:
|
||||
-------
|
||||
x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
|
||||
lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths.
|
||||
"""
|
||||
if not self.fp16 or len(lengths) < 8:
|
||||
return x, lengths
|
||||
|
||||
# number of sentences == 0 [8]
|
||||
bs1 = len(lengths)
|
||||
bs2 = 8 * (bs1 // 8)
|
||||
assert bs2 > 0 and bs2 % 8 == 0
|
||||
if bs1 != bs2:
|
||||
idx = torch.randperm(bs1)[:bs2]
|
||||
lengths = lengths[idx]
|
||||
slen = lengths.max().item()
|
||||
x = x[idx, :slen]
|
||||
else:
|
||||
idx = None
|
||||
|
||||
# sequence length == 0 [8]
|
||||
ml1 = x.size(1)
|
||||
if ml1 % 8 != 0:
|
||||
pad = 8 - (ml1 % 8)
|
||||
ml2 = ml1 + pad
|
||||
if self.mlm:
|
||||
pad_id = self.params.special_tok_ids["pad_token"]
|
||||
else:
|
||||
pad_id = self.params.special_tok_ids["unk_token"]
|
||||
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
|
||||
x = torch.cat([x, padding_tensor], 1)
|
||||
assert x.size() == (bs2, ml2)
|
||||
|
||||
assert x.size(0) % 8 == 0
|
||||
assert x.size(1) % 8 == 0
|
||||
return x, lengths
|
||||
|
||||
def train(self):
|
||||
"""
|
||||
The real training loop.
|
||||
"""
|
||||
if self.is_master:
|
||||
logger.info("Starting training")
|
||||
self.last_log = time.time()
|
||||
self.student.train()
|
||||
self.teacher.eval()
|
||||
|
||||
for _ in range(self.params.n_epoch):
|
||||
if self.is_master:
|
||||
logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
|
||||
if self.multi_gpu:
|
||||
torch.distributed.barrier()
|
||||
|
||||
iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
|
||||
for batch in iter_bar:
|
||||
if self.params.n_gpu > 0:
|
||||
batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)
|
||||
|
||||
if self.mlm:
|
||||
token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
|
||||
else:
|
||||
token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
|
||||
self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
|
||||
|
||||
iter_bar.update()
|
||||
iter_bar.set_postfix(
|
||||
{"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
|
||||
)
|
||||
iter_bar.close()
|
||||
|
||||
if self.is_master:
|
||||
logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
|
||||
self.end_epoch()
|
||||
|
||||
if self.is_master:
|
||||
logger.info("Save very last checkpoint as `pytorch_model.bin`.")
|
||||
self.save_checkpoint(checkpoint_name="pytorch_model.bin")
|
||||
logger.info("Training is finished")
|
||||
|
||||
def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
|
||||
"""
|
||||
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
|
||||
and possibly a parameter update (depending on the gradient accumulation).
|
||||
|
||||
Input:
|
||||
------
|
||||
input_ids: `torch.tensor(bs, seq_length)` - The token ids.
|
||||
attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
|
||||
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
|
||||
"""
|
||||
if self.mlm:
|
||||
student_outputs = self.student(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
) # (bs, seq_length, voc_size)
|
||||
with torch.no_grad():
|
||||
teacher_outputs = self.teacher(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
) # (bs, seq_length, voc_size)
|
||||
else:
|
||||
student_outputs = self.student(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
||||
with torch.no_grad():
|
||||
teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
||||
s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"]
|
||||
t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"]
|
||||
assert s_logits.size() == t_logits.size()
|
||||
|
||||
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
||||
# https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
|
||||
if self.params.restrict_ce_to_mask:
|
||||
mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
|
||||
else:
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
|
||||
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
assert t_logits_slct.size() == s_logits_slct.size()
|
||||
|
||||
loss_ce = (
|
||||
self.ce_loss_fct(
|
||||
nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||
nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||
)
|
||||
* (self.temperature) ** 2
|
||||
)
|
||||
loss = self.alpha_ce * loss_ce
|
||||
|
||||
if self.alpha_mlm > 0.0:
|
||||
loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
|
||||
loss += self.alpha_mlm * loss_mlm
|
||||
if self.alpha_clm > 0.0:
|
||||
shift_logits = s_logits[..., :-1, :].contiguous()
|
||||
shift_labels = lm_labels[..., 1:].contiguous()
|
||||
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
loss += self.alpha_clm * loss_clm
|
||||
|
||||
if self.alpha_mse > 0.0:
|
||||
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
|
||||
0
|
||||
) # Reproducing batchmean reduction
|
||||
loss += self.alpha_mse * loss_mse
|
||||
if self.alpha_cos > 0.0:
|
||||
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
|
||||
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
|
||||
assert s_hidden_states.size() == t_hidden_states.size()
|
||||
dim = s_hidden_states.size(-1)
|
||||
|
||||
s_hidden_states_slct = torch.masked_select(s_hidden_states, mask) # (bs * seq_length * dim)
|
||||
s_hidden_states_slct = s_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
|
||||
t_hidden_states_slct = torch.masked_select(t_hidden_states, mask) # (bs * seq_length * dim)
|
||||
t_hidden_states_slct = t_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
|
||||
|
||||
target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
|
||||
loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
|
||||
loss += self.alpha_cos * loss_cos
|
||||
|
||||
self.total_loss_epoch += loss.item()
|
||||
self.last_loss = loss.item()
|
||||
self.last_loss_ce = loss_ce.item()
|
||||
if self.alpha_mlm > 0.0:
|
||||
self.last_loss_mlm = loss_mlm.item()
|
||||
if self.alpha_clm > 0.0:
|
||||
self.last_loss_clm = loss_clm.item()
|
||||
if self.alpha_mse > 0.0:
|
||||
self.last_loss_mse = loss_mse.item()
|
||||
if self.alpha_cos > 0.0:
|
||||
self.last_loss_cos = loss_cos.item()
|
||||
|
||||
self.optimize(loss)
|
||||
|
||||
self.n_sequences_epoch += input_ids.size(0)
|
||||
|
||||
def optimize(self, loss):
|
||||
"""
|
||||
Normalization on the loss (gradient accumulation or distributed training), followed by
|
||||
backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
|
||||
Also update the metrics for tensorboard.
|
||||
"""
|
||||
# Check for NaN
|
||||
if (loss != loss).data.any():
|
||||
logger.error("NaN detected")
|
||||
exit()
|
||||
|
||||
if self.multi_gpu:
|
||||
loss = loss.mean()
|
||||
if self.params.gradient_accumulation_steps > 1:
|
||||
loss = loss / self.params.gradient_accumulation_steps
|
||||
|
||||
if self.fp16:
|
||||
from apex import amp
|
||||
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
self.iter()
|
||||
if self.n_iter % self.params.gradient_accumulation_steps == 0:
|
||||
if self.fp16:
|
||||
nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
|
||||
else:
|
||||
nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.scheduler.step()
|
||||
|
||||
def iter(self):
|
||||
"""
|
||||
Update global counts, write to tensorboard and save checkpoint.
|
||||
"""
|
||||
self.n_iter += 1
|
||||
self.n_total_iter += 1
|
||||
|
||||
if self.n_total_iter % self.params.log_interval == 0:
|
||||
self.log_tensorboard()
|
||||
self.last_log = time.time()
|
||||
if self.n_total_iter % self.params.checkpoint_interval == 0:
|
||||
self.save_checkpoint()
|
||||
|
||||
def log_tensorboard(self):
|
||||
"""
|
||||
Log into tensorboard. Only by the master process.
|
||||
"""
|
||||
if not self.is_master:
|
||||
return
|
||||
|
||||
for param_name, param in self.student.named_parameters():
|
||||
self.tensorboard.add_scalar(
|
||||
tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
|
||||
)
|
||||
if param.grad is None:
|
||||
continue
|
||||
self.tensorboard.add_scalar(
|
||||
tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
|
||||
)
|
||||
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/cum_avg_loss_epoch",
|
||||
scalar_value=self.total_loss_epoch / self.n_iter,
|
||||
global_step=self.n_total_iter,
|
||||
)
|
||||
self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_mlm > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_clm > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_mse > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_cos > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
|
||||
)
|
||||
|
||||
self.tensorboard.add_scalar(
|
||||
tag="global/memory_usage",
|
||||
scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
|
||||
global_step=self.n_total_iter,
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
|
||||
)
|
||||
|
||||
def end_epoch(self):
|
||||
"""
|
||||
Finally arrived at the end of epoch (full pass on dataset).
|
||||
Do some tensorboard logging and checkpoint saving.
|
||||
"""
|
||||
logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")
|
||||
|
||||
if self.is_master:
|
||||
self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
|
||||
self.tensorboard.add_scalar(
|
||||
tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
|
||||
)
|
||||
|
||||
self.epoch += 1
|
||||
self.n_sequences_epoch = 0
|
||||
self.n_iter = 0
|
||||
self.total_loss_epoch = 0
|
||||
|
||||
def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
|
||||
"""
|
||||
Save the current state. Only by the master process.
|
||||
"""
|
||||
if not self.is_master:
|
||||
return
|
||||
mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
|
||||
mdl_to_save.config.save_pretrained(self.dump_path)
|
||||
state_dict = mdl_to_save.state_dict()
|
||||
torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
|
||||
@@ -1,108 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Adapted from PyTorch Vision (https://github.com/pytorch/vision/blob/master/references/detection/group_by_aspect_ratio.py)"""
|
||||
|
||||
import bisect
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler, Sampler
|
||||
|
||||
from utils import logger
|
||||
|
||||
|
||||
def _quantize(x, bins):
|
||||
bins = copy.deepcopy(bins)
|
||||
bins = sorted(bins)
|
||||
quantized = [bisect.bisect_right(bins, y) for y in x]
|
||||
return quantized
|
||||
|
||||
|
||||
def create_lengths_groups(lengths, k=0):
|
||||
bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
|
||||
groups = _quantize(lengths, bins)
|
||||
# count number of elements per group
|
||||
counts = np.unique(groups, return_counts=True)[1]
|
||||
fbins = [0] + bins + [np.inf]
|
||||
logger.info("Using {} as bins for aspect lengths quantization".format(fbins))
|
||||
logger.info("Count of instances per bin: {}".format(counts))
|
||||
return groups
|
||||
|
||||
|
||||
class GroupedBatchSampler(BatchSampler):
|
||||
"""
|
||||
Wraps another sampler to yield a mini-batch of indices.
|
||||
It enforces that the batch only contain elements from the same group.
|
||||
It also tries to provide mini-batches which follows an ordering which is
|
||||
as close as possible to the ordering from the original sampler.
|
||||
Arguments:
|
||||
sampler (Sampler): Base sampler.
|
||||
group_ids (list[int]): If the sampler produces indices in range [0, N),
|
||||
`group_ids` must be a list of `N` ints which contains the group id of each sample.
|
||||
The group ids must be a continuous set of integers starting from
|
||||
0, i.e. they must be in the range [0, num_groups).
|
||||
batch_size (int): Size of mini-batch.
|
||||
"""
|
||||
|
||||
def __init__(self, sampler, group_ids, batch_size):
|
||||
if not isinstance(sampler, Sampler):
|
||||
raise TypeError(
|
||||
"sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
||||
)
|
||||
self.sampler = sampler
|
||||
self.group_ids = group_ids
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __iter__(self):
|
||||
buffer_per_group = defaultdict(list)
|
||||
samples_per_group = defaultdict(list)
|
||||
|
||||
num_batches = 0
|
||||
for idx in self.sampler:
|
||||
group_id = self.group_ids[idx]
|
||||
buffer_per_group[group_id].append(idx)
|
||||
samples_per_group[group_id].append(idx)
|
||||
if len(buffer_per_group[group_id]) == self.batch_size:
|
||||
yield buffer_per_group[group_id] # TODO
|
||||
num_batches += 1
|
||||
del buffer_per_group[group_id]
|
||||
assert len(buffer_per_group[group_id]) < self.batch_size
|
||||
|
||||
# now we have run out of elements that satisfy
|
||||
# the group criteria, let's return the remaining
|
||||
# elements so that the size of the sampler is
|
||||
# deterministic
|
||||
expected_num_batches = len(self)
|
||||
num_remaining = expected_num_batches - num_batches
|
||||
if num_remaining > 0:
|
||||
# for the remaining batches, group the batches by similar lengths
|
||||
batch_idx = []
|
||||
for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
|
||||
batch_idx.extend(idxs)
|
||||
if len(batch_idx) >= self.batch_size:
|
||||
yield batch_idx[: self.batch_size]
|
||||
batch_idx = batch_idx[self.batch_size :]
|
||||
num_remaining -= 1
|
||||
if len(batch_idx) > 0:
|
||||
yield batch_idx
|
||||
num_remaining -= 1
|
||||
assert num_remaining == 0
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return the number of mini-batches rather than the number of samples.
|
||||
"""
|
||||
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
||||
@@ -1,167 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dataset to distilled models
|
||||
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from utils import logger
|
||||
|
||||
|
||||
class LmSeqsDataset(Dataset):
|
||||
"""Custom Dataset wrapping language modeling sequences.
|
||||
|
||||
Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.
|
||||
|
||||
Input:
|
||||
------
|
||||
params: `NameSpace` parameters
|
||||
data: `List[np.array[int]]
|
||||
"""
|
||||
|
||||
def __init__(self, params, data):
|
||||
self.params = params
|
||||
|
||||
self.token_ids = np.array(data)
|
||||
self.lengths = np.array([len(t) for t in data])
|
||||
|
||||
self.check()
|
||||
self.remove_long_sequences()
|
||||
self.remove_empty_sequences()
|
||||
self.remove_unknown_sequences()
|
||||
self.check()
|
||||
self.print_statistics()
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.token_ids[index], self.lengths[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.lengths)
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
Some sanity checks
|
||||
"""
|
||||
assert len(self.token_ids) == len(self.lengths)
|
||||
assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))
|
||||
|
||||
def remove_long_sequences(self):
|
||||
"""
|
||||
Sequences that are too long are split by chunk of max_model_input_size.
|
||||
"""
|
||||
max_len = self.params.max_model_input_size
|
||||
indices = self.lengths > max_len
|
||||
logger.info(f"Splitting {sum(indices)} too long sequences.")
|
||||
|
||||
def divide_chunks(l, n):
|
||||
return [l[i : i + n] for i in range(0, len(l), n)]
|
||||
|
||||
new_tok_ids = []
|
||||
new_lengths = []
|
||||
if self.params.mlm:
|
||||
cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
|
||||
else:
|
||||
cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]
|
||||
|
||||
for seq_, len_ in zip(self.token_ids, self.lengths):
|
||||
assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
|
||||
if len_ <= max_len:
|
||||
new_tok_ids.append(seq_)
|
||||
new_lengths.append(len_)
|
||||
else:
|
||||
sub_seqs = []
|
||||
for sub_s in divide_chunks(seq_, max_len - 2):
|
||||
if sub_s[0] != cls_id:
|
||||
sub_s = np.insert(sub_s, 0, cls_id)
|
||||
if sub_s[-1] != sep_id:
|
||||
sub_s = np.insert(sub_s, len(sub_s), sep_id)
|
||||
assert len(sub_s) <= max_len
|
||||
assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s
|
||||
sub_seqs.append(sub_s)
|
||||
|
||||
new_tok_ids.extend(sub_seqs)
|
||||
new_lengths.extend([len(l) for l in sub_seqs])
|
||||
|
||||
self.token_ids = np.array(new_tok_ids)
|
||||
self.lengths = np.array(new_lengths)
|
||||
|
||||
def remove_empty_sequences(self):
|
||||
"""
|
||||
Too short sequences are simply removed. This could be tuned.
|
||||
"""
|
||||
init_size = len(self)
|
||||
indices = self.lengths > 11
|
||||
self.token_ids = self.token_ids[indices]
|
||||
self.lengths = self.lengths[indices]
|
||||
new_size = len(self)
|
||||
logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
|
||||
|
||||
def remove_unknown_sequences(self):
|
||||
"""
|
||||
Remove sequences with a (too) high level of unknown tokens.
|
||||
"""
|
||||
if "unk_token" not in self.params.special_tok_ids:
|
||||
return
|
||||
else:
|
||||
unk_token_id = self.params.special_tok_ids["unk_token"]
|
||||
init_size = len(self)
|
||||
unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
|
||||
indices = (unk_occs / self.lengths) < 0.5
|
||||
self.token_ids = self.token_ids[indices]
|
||||
self.lengths = self.lengths[indices]
|
||||
new_size = len(self)
|
||||
logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).")
|
||||
|
||||
def print_statistics(self):
|
||||
"""
|
||||
Print some statistics on the corpus. Only the master process.
|
||||
"""
|
||||
if not self.params.is_master:
|
||||
return
|
||||
logger.info(f"{len(self)} sequences")
|
||||
# data_len = sum(self.lengths)
|
||||
# nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
|
||||
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
|
||||
|
||||
# unk_idx = self.params.special_tok_ids['unk_token']
|
||||
# nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
|
||||
# logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')
|
||||
|
||||
def batch_sequences(self, batch):
|
||||
"""
|
||||
Do the padding and transform into torch.tensor.
|
||||
"""
|
||||
token_ids = [t[0] for t in batch]
|
||||
lengths = [t[1] for t in batch]
|
||||
assert len(token_ids) == len(lengths)
|
||||
|
||||
# Max for paddings
|
||||
max_seq_len_ = max(lengths)
|
||||
|
||||
# Pad token ids
|
||||
if self.params.mlm:
|
||||
pad_idx = self.params.special_tok_ids["pad_token"]
|
||||
else:
|
||||
pad_idx = self.params.special_tok_ids["unk_token"]
|
||||
tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
|
||||
assert len(tk_) == len(token_ids)
|
||||
assert all(len(t) == max_seq_len_ for t in tk_)
|
||||
|
||||
tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
|
||||
lg_t = torch.tensor(lengths) # (bs)
|
||||
return tk_t, lg_t
|
||||
@@ -1,7 +0,0 @@
|
||||
transformers
|
||||
|
||||
gitpython==3.1.41
|
||||
tensorboard>=1.14.0
|
||||
tensorboardX==1.8
|
||||
psutil==5.6.6
|
||||
scipy>=1.4.1
|
||||
@@ -1,877 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This is the exact same script as `examples/question-answering/run_squad.py` (as of 2020, January 8th) with an additional and optional step of distillation."""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import timeit
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertForQuestionAnswering,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForQuestionAnswering,
|
||||
DistilBertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForQuestionAnswering,
|
||||
RobertaTokenizer,
|
||||
XLMConfig,
|
||||
XLMForQuestionAnswering,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
squad_convert_examples_to_features,
|
||||
)
|
||||
from transformers.data.metrics.squad_metrics import (
|
||||
compute_predictions_log_probs,
|
||||
compute_predictions_logits,
|
||||
squad_evaluate,
|
||||
)
|
||||
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def to_list(tensor):
|
||||
return tensor.detach().cpu().tolist()
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 1
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
try:
|
||||
# set global_step to global_step of last saved checkpoint from model path
|
||||
checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
|
||||
global_step = int(checkpoint_suffix)
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", global_step)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
except ValueError:
|
||||
logger.info(" Starting fine-tuning.")
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||
)
|
||||
# Added here for reproducibility
|
||||
set_seed(args)
|
||||
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
model.train()
|
||||
if teacher is not None:
|
||||
teacher.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"start_positions": batch[3],
|
||||
"end_positions": batch[4],
|
||||
}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||
if args.version_2_with_negative:
|
||||
inputs.update({"is_impossible": batch[7]})
|
||||
outputs = model(**inputs)
|
||||
loss, start_logits_stu, end_logits_stu = outputs
|
||||
|
||||
# Distillation loss
|
||||
if teacher is not None:
|
||||
if "token_type_ids" not in inputs:
|
||||
inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
|
||||
with torch.no_grad():
|
||||
start_logits_tea, end_logits_tea = teacher(
|
||||
input_ids=inputs["input_ids"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
assert start_logits_tea.size() == start_logits_stu.size()
|
||||
assert end_logits_tea.size() == end_logits_stu.size()
|
||||
|
||||
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
loss_start = loss_fct(
|
||||
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
) * (args.temperature**2)
|
||||
loss_end = loss_fct(
|
||||
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
) * (args.temperature**2)
|
||||
loss_ce = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
# Log metrics
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Only evaluate when single GPU otherwise metrics may not average well
|
||||
if args.local_rank == -1 and args.evaluate_during_training:
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
|
||||
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(dataset)
|
||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu evaluate
|
||||
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
|
||||
all_results = []
|
||||
start_time = timeit.default_timer()
|
||||
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM don't use segment_ids
|
||||
example_indices = batch[3]
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
|
||||
output = [to_list(output[i]) for output in outputs]
|
||||
|
||||
# Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
|
||||
# models only use two.
|
||||
if len(output) >= 5:
|
||||
start_logits = output[0]
|
||||
start_top_index = output[1]
|
||||
end_logits = output[2]
|
||||
end_top_index = output[3]
|
||||
cls_logits = output[4]
|
||||
|
||||
result = SquadResult(
|
||||
unique_id,
|
||||
start_logits,
|
||||
end_logits,
|
||||
start_top_index=start_top_index,
|
||||
end_top_index=end_top_index,
|
||||
cls_logits=cls_logits,
|
||||
)
|
||||
|
||||
else:
|
||||
start_logits, end_logits = output
|
||||
result = SquadResult(unique_id, start_logits, end_logits)
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
evalTime = timeit.default_timer() - start_time
|
||||
logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
|
||||
|
||||
# Compute predictions
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
|
||||
if args.version_2_with_negative:
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
else:
|
||||
output_null_log_odds_file = None
|
||||
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
predictions = compute_predictions_log_probs(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
model.config.start_n_top,
|
||||
model.config.end_n_top,
|
||||
args.version_2_with_negative,
|
||||
tokenizer,
|
||||
args.verbose_logging,
|
||||
)
|
||||
else:
|
||||
predictions = compute_predictions_logits(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
args.do_lower_case,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
args.verbose_logging,
|
||||
args.version_2_with_negative,
|
||||
args.null_score_diff_threshold,
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
# Compute the F1 and exact scores.
|
||||
results = squad_evaluate(examples, predictions)
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
torch.distributed.barrier()
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
input_file = args.predict_file if evaluate else args.train_file
|
||||
cached_features_file = os.path.join(
|
||||
os.path.dirname(input_file),
|
||||
"cached_distillation_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features_and_dataset = torch.load(cached_features_file)
|
||||
|
||||
try:
|
||||
features, dataset, examples = (
|
||||
features_and_dataset["features"],
|
||||
features_and_dataset["dataset"],
|
||||
features_and_dataset["examples"],
|
||||
)
|
||||
except KeyError:
|
||||
raise DeprecationWarning(
|
||||
"You seem to be loading features from an older version of this script please delete the "
|
||||
"file %s in order for it to be created again" % cached_features_file
|
||||
)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_file)
|
||||
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
||||
if evaluate:
|
||||
examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
|
||||
else:
|
||||
examples = processor.get_train_examples(args.data_dir, filename=args.train_file)
|
||||
|
||||
features, dataset = squad_convert_examples_to_features(
|
||||
examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate,
|
||||
return_dataset="pt",
|
||||
threads=args.threads,
|
||||
)
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
torch.distributed.barrier()
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.",
|
||||
)
|
||||
|
||||
# Distillation parameters (optional)
|
||||
parser.add_argument(
|
||||
"--teacher_type",
|
||||
default=None,
|
||||
type=str,
|
||||
help=(
|
||||
"Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
|
||||
" distillation."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--teacher_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_ce", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_squad", default=0.5, type=float, help="True SQuAD loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The input data dir. Should contain the .json files for the task."
|
||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The input training file. If a data dir is specified, will look for the file there"
|
||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predict_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The input evaluation file. If a data dir is specified, will look for the file there"
|
||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--version_2_with_negative",
|
||||
action="store_true",
|
||||
help="If true, the SQuAD examples contain some that do not have an answer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--null_score_diff_threshold",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=384,
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--doc_stride",
|
||||
default=128,
|
||||
type=int,
|
||||
help="When splitting up a long document into chunks, how much stride to take between chunks.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_query_length",
|
||||
default=64,
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument(
|
||||
"--n_best_size",
|
||||
default=20,
|
||||
type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_answer_length",
|
||||
default=30,
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose_logging",
|
||||
action="store_true",
|
||||
help=(
|
||||
"If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help=(
|
||||
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
|
||||
"See details at https://nvidia.github.io/apex/amp.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
|
||||
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
# Make sure only the first process in distributed training will download model & vocab
|
||||
torch.distributed.barrier()
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.teacher_type is not None:
|
||||
assert args.teacher_name_or_path is not None
|
||||
assert args.alpha_ce > 0.0
|
||||
assert args.alpha_ce + args.alpha_squad > 0.0
|
||||
assert args.teacher_type != "distilbert", "We constraint teachers not to be of type DistilBERT."
|
||||
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
|
||||
teacher_config = teacher_config_class.from_pretrained(
|
||||
args.teacher_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
)
|
||||
teacher = teacher_model_class.from_pretrained(
|
||||
args.teacher_name_or_path, config=teacher_config, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
)
|
||||
teacher.to(args.device)
|
||||
else:
|
||||
teacher = None
|
||||
|
||||
if args.local_rank == 0:
|
||||
# Make sure only the first process in distributed training will download model & vocab
|
||||
torch.distributed.barrier()
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
|
||||
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
|
||||
# remove the need for this code, but it is still valid.
|
||||
if args.fp16:
|
||||
try:
|
||||
import apex
|
||||
|
||||
apex.amp.register_half_function(torch, "einsum")
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
# Save the trained model and the tokenizer
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
if args.do_train:
|
||||
logger.info("Loading checkpoints saved during training for evaluation")
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = [
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
]
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,97 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocessing script before distillation.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import BertTokenizer, GPT2Tokenizer, RobertaTokenizer
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)."
|
||||
)
|
||||
parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
|
||||
parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
|
||||
parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
|
||||
parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
|
||||
if args.tokenizer_type == "bert":
|
||||
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
|
||||
bos = tokenizer.special_tokens_map["cls_token"] # `[CLS]`
|
||||
sep = tokenizer.special_tokens_map["sep_token"] # `[SEP]`
|
||||
elif args.tokenizer_type == "roberta":
|
||||
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
|
||||
bos = tokenizer.special_tokens_map["cls_token"] # `<s>`
|
||||
sep = tokenizer.special_tokens_map["sep_token"] # `</s>`
|
||||
elif args.tokenizer_type == "gpt2":
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
|
||||
bos = tokenizer.special_tokens_map["bos_token"] # `<|endoftext|>`
|
||||
sep = tokenizer.special_tokens_map["eos_token"] # `<|endoftext|>`
|
||||
|
||||
logger.info(f"Loading text from {args.file_path}")
|
||||
with open(args.file_path, "r", encoding="utf8") as fp:
|
||||
data = fp.readlines()
|
||||
|
||||
logger.info("Start encoding")
|
||||
logger.info(f"{len(data)} examples to process.")
|
||||
|
||||
rslt = []
|
||||
iter = 0
|
||||
interval = 10000
|
||||
start = time.time()
|
||||
for text in data:
|
||||
text = f"{bos} {text.strip()} {sep}"
|
||||
token_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
rslt.append(token_ids)
|
||||
|
||||
iter += 1
|
||||
if iter % interval == 0:
|
||||
end = time.time()
|
||||
logger.info(f"{iter} examples processed. - {(end-start):.2f}s/{interval}expl")
|
||||
start = time.time()
|
||||
logger.info("Finished binarization")
|
||||
logger.info(f"{len(data)} examples processed.")
|
||||
|
||||
dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
|
||||
vocab_size = tokenizer.vocab_size
|
||||
if vocab_size < (1 << 16):
|
||||
rslt_ = [np.uint16(d) for d in rslt]
|
||||
else:
|
||||
rslt_ = [np.int32(d) for d in rslt]
|
||||
random.shuffle(rslt_)
|
||||
logger.info(f"Dump to {dp_file}")
|
||||
with open(dp_file, "wb") as handle:
|
||||
pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,106 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocessing script before training the distilled model.
|
||||
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import GPT2LMHeadModel, RobertaForMaskedLM
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned"
|
||||
" Distillation"
|
||||
)
|
||||
)
|
||||
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
|
||||
parser.add_argument("--model_name", default="roberta-large", type=str)
|
||||
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
|
||||
parser.add_argument("--vocab_transform", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "roberta":
|
||||
model = RobertaForMaskedLM.from_pretrained(args.model_name)
|
||||
prefix = "roberta"
|
||||
elif args.model_type == "gpt2":
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name)
|
||||
prefix = "transformer"
|
||||
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
|
||||
# Embeddings #
|
||||
if args.model_type == "gpt2":
|
||||
for param_name in ["wte.weight", "wpe.weight"]:
|
||||
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
|
||||
else:
|
||||
for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
|
||||
param_name = f"{prefix}.embeddings.{w}.weight"
|
||||
compressed_sd[param_name] = state_dict[param_name]
|
||||
for w in ["weight", "bias"]:
|
||||
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
|
||||
compressed_sd[param_name] = state_dict[param_name]
|
||||
|
||||
# Transformer Blocks #
|
||||
std_idx = 0
|
||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||
if args.model_type == "gpt2":
|
||||
for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
|
||||
f"{prefix}.h.{teacher_idx}.{layer}.{w}"
|
||||
]
|
||||
compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
|
||||
else:
|
||||
for layer in [
|
||||
"attention.self.query",
|
||||
"attention.self.key",
|
||||
"attention.self.value",
|
||||
"attention.output.dense",
|
||||
"attention.output.LayerNorm",
|
||||
"intermediate.dense",
|
||||
"output.dense",
|
||||
"output.LayerNorm",
|
||||
]:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
# Language Modeling Head ###s
|
||||
if args.model_type == "roberta":
|
||||
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
|
||||
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
|
||||
if args.vocab_transform:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
|
||||
compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
|
||||
elif args.model_type == "gpt2":
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
|
||||
compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]
|
||||
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
|
||||
torch.save(compressed_sd, args.dump_checkpoint)
|
||||
@@ -1,96 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocessing script before training DistilBERT.
|
||||
Specific to BERT -> DistilBERT.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned"
|
||||
" Distillation"
|
||||
)
|
||||
)
|
||||
parser.add_argument("--model_type", default="bert", choices=["bert"])
|
||||
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
|
||||
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
|
||||
parser.add_argument("--vocab_transform", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "bert":
|
||||
model = BertForMaskedLM.from_pretrained(args.model_name)
|
||||
prefix = "bert"
|
||||
else:
|
||||
raise ValueError('args.model_type should be "bert".')
|
||||
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
|
||||
for w in ["word_embeddings", "position_embeddings"]:
|
||||
compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
|
||||
|
||||
std_idx = 0
|
||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
|
||||
]
|
||||
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
|
||||
]
|
||||
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"]
|
||||
compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"]
|
||||
if args.vocab_transform:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
|
||||
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
|
||||
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
|
||||
torch.save(compressed_sd, args.dump_checkpoint)
|
||||
@@ -1,57 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocessing script before training the distilled model.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import pickle
|
||||
from collections import Counter
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
|
||||
)
|
||||
parser.add_argument("--vocab_size", default=30522, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(f"Loading data from {args.data_file}")
|
||||
with open(args.data_file, "rb") as fp:
|
||||
data = pickle.load(fp)
|
||||
|
||||
logger.info("Counting occurrences for MLM.")
|
||||
counter = Counter()
|
||||
for tk_ids in data:
|
||||
counter.update(tk_ids)
|
||||
counts = [0] * args.vocab_size
|
||||
for k, v in counter.items():
|
||||
counts[k] = v
|
||||
|
||||
logger.info(f"Dump to {args.token_counts_dump}")
|
||||
with open(args.token_counts_dump, "wb") as handle:
|
||||
pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
@@ -1,325 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Training the distilled model.
|
||||
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from distiller import Distiller
|
||||
from lm_seqs_dataset import LmSeqsDataset
|
||||
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
BertForMaskedLM,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertTokenizer,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForMaskedLM,
|
||||
RobertaTokenizer,
|
||||
)
|
||||
from utils import git_log, init_gpu_params, logger, set_seed
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||
}
|
||||
|
||||
|
||||
def sanity_checks(args):
|
||||
"""
|
||||
A bunch of args sanity checks to perform even starting...
|
||||
"""
|
||||
assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0)
|
||||
assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0)
|
||||
if args.mlm:
|
||||
assert os.path.isfile(args.token_counts)
|
||||
assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"])
|
||||
else:
|
||||
assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"])
|
||||
|
||||
assert args.teacher_type == args.student_type or (
|
||||
args.student_type == "distilbert" and args.teacher_type == "bert"
|
||||
)
|
||||
assert os.path.isfile(args.student_config)
|
||||
if args.student_pretrained_weights is not None:
|
||||
assert os.path.isfile(args.student_pretrained_weights)
|
||||
|
||||
if args.freeze_token_type_embds:
|
||||
assert args.student_type in ["roberta"]
|
||||
|
||||
assert args.alpha_ce >= 0.0
|
||||
assert args.alpha_mlm >= 0.0
|
||||
assert args.alpha_clm >= 0.0
|
||||
assert args.alpha_mse >= 0.0
|
||||
assert args.alpha_cos >= 0.0
|
||||
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0
|
||||
|
||||
|
||||
def freeze_pos_embeddings(student, args):
|
||||
if args.student_type == "roberta":
|
||||
student.roberta.embeddings.position_embeddings.weight.requires_grad = False
|
||||
elif args.student_type == "gpt2":
|
||||
student.transformer.wpe.weight.requires_grad = False
|
||||
|
||||
|
||||
def freeze_token_type_embeddings(student, args):
|
||||
if args.student_type == "roberta":
|
||||
student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Training")
|
||||
parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")
|
||||
|
||||
parser.add_argument(
|
||||
"--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--student_type",
|
||||
type=str,
|
||||
choices=["distilbert", "roberta", "gpt2"],
|
||||
required=True,
|
||||
help="The student type (DistilBERT, RoBERTa).",
|
||||
)
|
||||
parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.")
|
||||
parser.add_argument(
|
||||
"--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa)."
|
||||
)
|
||||
parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.")
|
||||
|
||||
parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax temperature.")
|
||||
parser.add_argument(
|
||||
"--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_mlm",
|
||||
default=0.0,
|
||||
type=float,
|
||||
help="Linear weight for the MLM loss. Must be >=0. Should be used in conjunction with `mlm` flag.",
|
||||
)
|
||||
parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
|
||||
parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
|
||||
parser.add_argument(
|
||||
"--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlm_mask_prop",
|
||||
default=0.15,
|
||||
type=float,
|
||||
help="Proportion of tokens for which we need to make a prediction.",
|
||||
)
|
||||
parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.")
|
||||
parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.")
|
||||
parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.")
|
||||
parser.add_argument(
|
||||
"--mlm_smoothing",
|
||||
default=0.7,
|
||||
type=float,
|
||||
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
|
||||
)
|
||||
parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")
|
||||
|
||||
parser.add_argument(
|
||||
"--restrict_ce_to_mask",
|
||||
action="store_true",
|
||||
help="If true, compute the distillation loss only the [MLM] prediction distribution.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_pos_embs",
|
||||
action="store_true",
|
||||
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_token_type_embds",
|
||||
action="store_true",
|
||||
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
|
||||
)
|
||||
|
||||
parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.")
|
||||
parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
|
||||
parser.add_argument(
|
||||
"--group_by_size",
|
||||
action="store_false",
|
||||
help="If true, group sequences that have similar length into the same batch. Default is true.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Gradient accumulation for larger training batches.",
|
||||
)
|
||||
parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help=(
|
||||
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
|
||||
"See details at https://nvidia.github.io/apex/amp.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
|
||||
parser.add_argument("--seed", type=int, default=56, help="Random seed")
|
||||
|
||||
parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
|
||||
parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
|
||||
args = parser.parse_args()
|
||||
sanity_checks(args)
|
||||
|
||||
# ARGS #
|
||||
init_gpu_params(args)
|
||||
set_seed(args)
|
||||
if args.is_master:
|
||||
if os.path.exists(args.dump_path):
|
||||
if not args.force:
|
||||
raise ValueError(
|
||||
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite"
|
||||
" itUse `--force` if you want to overwrite it"
|
||||
)
|
||||
else:
|
||||
shutil.rmtree(args.dump_path)
|
||||
|
||||
if not os.path.exists(args.dump_path):
|
||||
os.makedirs(args.dump_path)
|
||||
logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
|
||||
|
||||
# SAVE PARAMS #
|
||||
logger.info(f"Param: {args}")
|
||||
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
git_log(args.dump_path)
|
||||
|
||||
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
|
||||
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
|
||||
|
||||
# TOKENIZER #
|
||||
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
|
||||
special_tok_ids = {}
|
||||
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
|
||||
idx = tokenizer.all_special_tokens.index(tok_symbol)
|
||||
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
|
||||
logger.info(f"Special tokens {special_tok_ids}")
|
||||
args.special_tok_ids = special_tok_ids
|
||||
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
|
||||
|
||||
# DATA LOADER #
|
||||
logger.info(f"Loading data from {args.data_file}")
|
||||
with open(args.data_file, "rb") as fp:
|
||||
data = pickle.load(fp)
|
||||
|
||||
if args.mlm:
|
||||
logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
|
||||
with open(args.token_counts, "rb") as fp:
|
||||
counts = pickle.load(fp)
|
||||
|
||||
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
|
||||
for idx in special_tok_ids.values():
|
||||
token_probs[idx] = 0.0 # do not predict special tokens
|
||||
token_probs = torch.from_numpy(token_probs)
|
||||
else:
|
||||
token_probs = None
|
||||
|
||||
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
|
||||
logger.info("Data loader created.")
|
||||
|
||||
# STUDENT #
|
||||
logger.info(f"Loading student config from {args.student_config}")
|
||||
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
|
||||
stu_architecture_config.output_hidden_states = True
|
||||
|
||||
if args.student_pretrained_weights is not None:
|
||||
logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
|
||||
student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config)
|
||||
else:
|
||||
student = student_model_class(stu_architecture_config)
|
||||
|
||||
if args.n_gpu > 0:
|
||||
student.to(f"cuda:{args.local_rank}")
|
||||
logger.info("Student loaded.")
|
||||
|
||||
# TEACHER #
|
||||
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
|
||||
if args.n_gpu > 0:
|
||||
teacher.to(f"cuda:{args.local_rank}")
|
||||
logger.info(f"Teacher loaded from {args.teacher_name}.")
|
||||
|
||||
# FREEZING #
|
||||
if args.freeze_pos_embs:
|
||||
freeze_pos_embeddings(student, args)
|
||||
if args.freeze_token_type_embds:
|
||||
freeze_token_type_embeddings(student, args)
|
||||
|
||||
# SANITY CHECKS #
|
||||
assert student.config.vocab_size == teacher.config.vocab_size
|
||||
assert student.config.hidden_size == teacher.config.hidden_size
|
||||
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
|
||||
if args.mlm:
|
||||
assert token_probs.size(0) == stu_architecture_config.vocab_size
|
||||
|
||||
# DISTILLER #
|
||||
torch.cuda.empty_cache()
|
||||
distiller = Distiller(
|
||||
params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
|
||||
)
|
||||
distiller.train()
|
||||
logger.info("Let's go get some drinks.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"activation": "gelu",
|
||||
"attention_dropout": 0.1,
|
||||
"dim": 768,
|
||||
"dropout": 0.1,
|
||||
"hidden_dim": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"max_position_embeddings": 512,
|
||||
"n_heads": 12,
|
||||
"n_layers": 6,
|
||||
"sinusoidal_pos_embds": true,
|
||||
"tie_weights_": true,
|
||||
"vocab_size": 28996
|
||||
}
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"activation": "gelu",
|
||||
"attention_dropout": 0.1,
|
||||
"dim": 768,
|
||||
"dropout": 0.1,
|
||||
"hidden_dim": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"max_position_embeddings": 512,
|
||||
"n_heads": 12,
|
||||
"n_layers": 6,
|
||||
"sinusoidal_pos_embds": true,
|
||||
"tie_weights_": true,
|
||||
"vocab_size": 119547
|
||||
}
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"activation": "gelu",
|
||||
"attention_dropout": 0.1,
|
||||
"dim": 768,
|
||||
"dropout": 0.1,
|
||||
"hidden_dim": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"max_position_embeddings": 512,
|
||||
"n_heads": 12,
|
||||
"n_layers": 6,
|
||||
"sinusoidal_pos_embds": true,
|
||||
"tie_weights_": true,
|
||||
"vocab_size": 30522
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
{
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_epsilon": 0.00001,
|
||||
"n_embd": 768,
|
||||
"n_head": 12,
|
||||
"n_layer": 6,
|
||||
"n_positions": 1024,
|
||||
"vocab_size": 50257
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
{
|
||||
"vocab_size": 50265,
|
||||
"hidden_size": 768,
|
||||
"num_hidden_layers": 6,
|
||||
"num_attention_heads": 12,
|
||||
"intermediate_size": 3072,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"max_position_embeddings": 514,
|
||||
"type_vocab_size": 1,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_eps": 0.00001
|
||||
}
|
||||
@@ -1,134 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utils to train DistilBERT
|
||||
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
|
||||
import git
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def git_log(folder_path: str):
|
||||
"""
|
||||
Log commit info.
|
||||
"""
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
repo_infos = {
|
||||
"repo_id": str(repo),
|
||||
"repo_sha": str(repo.head.object.hexsha),
|
||||
"repo_branch": str(repo.active_branch),
|
||||
}
|
||||
|
||||
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
|
||||
json.dump(repo_infos, f, indent=4)
|
||||
|
||||
|
||||
def init_gpu_params(params):
|
||||
"""
|
||||
Handle single and multi-GPU / multi-node.
|
||||
"""
|
||||
if params.n_gpu <= 0:
|
||||
params.local_rank = 0
|
||||
params.master_port = -1
|
||||
params.is_master = True
|
||||
params.multi_gpu = False
|
||||
return
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
logger.info("Initializing GPUs")
|
||||
if params.n_gpu > 1:
|
||||
assert params.local_rank != -1
|
||||
|
||||
params.world_size = int(os.environ["WORLD_SIZE"])
|
||||
params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
|
||||
params.global_rank = int(os.environ["RANK"])
|
||||
|
||||
# number of nodes / node ID
|
||||
params.n_nodes = params.world_size // params.n_gpu_per_node
|
||||
params.node_id = params.global_rank // params.n_gpu_per_node
|
||||
params.multi_gpu = True
|
||||
|
||||
assert params.n_nodes == int(os.environ["N_NODES"])
|
||||
assert params.node_id == int(os.environ["NODE_RANK"])
|
||||
|
||||
# local job (single GPU)
|
||||
else:
|
||||
assert params.local_rank == -1
|
||||
|
||||
params.n_nodes = 1
|
||||
params.node_id = 0
|
||||
params.local_rank = 0
|
||||
params.global_rank = 0
|
||||
params.world_size = 1
|
||||
params.n_gpu_per_node = 1
|
||||
params.multi_gpu = False
|
||||
|
||||
# sanity checks
|
||||
assert params.n_nodes >= 1
|
||||
assert 0 <= params.node_id < params.n_nodes
|
||||
assert 0 <= params.local_rank <= params.global_rank < params.world_size
|
||||
assert params.world_size == params.n_nodes * params.n_gpu_per_node
|
||||
|
||||
# define whether this is the master process / if we are in multi-node distributed mode
|
||||
params.is_master = params.node_id == 0 and params.local_rank == 0
|
||||
params.multi_node = params.n_nodes > 1
|
||||
|
||||
# summary
|
||||
PREFIX = f"--- Global rank: {params.global_rank} - "
|
||||
logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes)
|
||||
logger.info(PREFIX + "Node ID : %i" % params.node_id)
|
||||
logger.info(PREFIX + "Local rank : %i" % params.local_rank)
|
||||
logger.info(PREFIX + "World size : %i" % params.world_size)
|
||||
logger.info(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node)
|
||||
logger.info(PREFIX + "Master : %s" % str(params.is_master))
|
||||
logger.info(PREFIX + "Multi-node : %s" % str(params.multi_node))
|
||||
logger.info(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu))
|
||||
logger.info(PREFIX + "Hostname : %s" % socket.gethostname())
|
||||
|
||||
# set GPU device
|
||||
torch.cuda.set_device(params.local_rank)
|
||||
|
||||
# initialize multi-GPU
|
||||
if params.multi_gpu:
|
||||
logger.info("Initializing PyTorch distributed")
|
||||
torch.distributed.init_process_group(
|
||||
init_method="env://",
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
"""
|
||||
Set the random seed.
|
||||
"""
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
@@ -1,88 +0,0 @@
|
||||
<p align="center"> <img src="http://sayef.tech:8082/uploads/FSNER-LOGO-2.png" alt="FSNER LOGO"> </p>
|
||||
|
||||
<p align="center">
|
||||
Implemented by <a href="https://huggingface.co/sayef"> sayef </a>.
|
||||
</p>
|
||||
|
||||
## Overview
|
||||
|
||||
The FSNER model was proposed in [Example-Based Named Entity Recognition](https://arxiv.org/abs/2008.10570) by Morteza Ziyadi, Yuting Sun, Abhishek Goswami, Jade Huang, Weizhu Chen. To identify entity spans in a new domain, it uses a train-free few-shot learning approach inspired by question-answering.
|
||||
|
||||
|
||||
|
||||
## Abstract
|
||||
----
|
||||
> We present a novel approach to named entity recognition (NER) in the presence of scarce data that we call example-based NER. Our train-free few-shot learning approach takes inspiration from question-answering to identify entity spans in a new and unseen domain. In comparison with the current state-of-the-art, the proposed method performs significantly better, especially when using a low number of support examples.
|
||||
|
||||
|
||||
|
||||
## Model Training Details
|
||||
-----
|
||||
|
||||
| identifier | epochs | datasets |
|
||||
| ---------- |:----------:| :-----:|
|
||||
| [sayef/fsner-bert-base-uncased](https://huggingface.co/sayef/fsner-bert-base-uncased) | 10 | ontonotes5, conll2003, wnut2017, and fin (Alvarado et al.). |
|
||||
|
||||
|
||||
## Installation and Example Usage
|
||||
------
|
||||
|
||||
You can use the FSNER model in 3 ways:
|
||||
|
||||
1. Install directly from PyPI: `pip install fsner` and import the model as shown in the code example below
|
||||
|
||||
or
|
||||
|
||||
2. Install from source: `python setup.py install` and import the model as shown in the code example below
|
||||
|
||||
or
|
||||
|
||||
3. Clone repo and change directory to `src` and import the model as shown in the code example below
|
||||
|
||||
|
||||
|
||||
```python
|
||||
from fsner import FSNERModel, FSNERTokenizerUtils
|
||||
|
||||
model = FSNERModel("sayef/fsner-bert-base-uncased")
|
||||
|
||||
tokenizer = FSNERTokenizerUtils("sayef/fsner-bert-base-uncased")
|
||||
|
||||
# size of query and supports must be the same. If you want to find all the entitites in one particular query, just repeat the same query n times where n is equal to the number of supports (or entities).
|
||||
|
||||
|
||||
query = [
|
||||
'KWE 4000 can reach with a maximum speed from up to 450 P/min an accuracy from 50 mg',
|
||||
'I would like to order a computer from eBay.',
|
||||
]
|
||||
|
||||
# each list in supports are the examples of one entity type
|
||||
# wrap entities around with [E] and [/E] in the examples
|
||||
|
||||
supports = [
|
||||
[
|
||||
'Horizontal flow wrapper [E] Pack 403 [/E] features the new retrofit-kit „paper-ON-form“',
|
||||
'[E] Paloma Pick-and-Place-Roboter [/E] arranges the bakery products for the downstream tray-forming equipment',
|
||||
'Finally, the new [E] Kliklok ACE [/E] carton former forms cartons and trays without the use of glue',
|
||||
'We set up our pilot plant with the right [E] FibreForm® [/E] configuration to make prototypes for your marketing tests and package validation',
|
||||
'The [E] CAR-T5 [/E] is a reliable, purely mechanically driven cartoning machine for versatile application fields'
|
||||
],
|
||||
[
|
||||
"[E] Walmart [/E] is a leading e-commerce company",
|
||||
"I recently ordered a book from [E] Amazon [/E]",
|
||||
"I ordered this from [E] ShopClues [/E]",
|
||||
"[E] Flipkart [/E] started it's journey from zero"
|
||||
]
|
||||
]
|
||||
|
||||
device = 'cpu'
|
||||
|
||||
W_query = tokenizer.tokenize(query).to(device)
|
||||
W_supports = tokenizer.tokenize(supports).to(device)
|
||||
|
||||
start_prob, end_prob = model(W_query, W_supports)
|
||||
|
||||
output = tokenizer.extract_entity_from_scores(query, W_query, start_prob, end_prob, thresh=0.50)
|
||||
|
||||
print(output)
|
||||
```
|
||||
@@ -1,7 +0,0 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=57.4.0",
|
||||
"wheel>=0.37.0",
|
||||
"transformers>=4.9.2"
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
@@ -1 +0,0 @@
|
||||
transformers>=4.9.2
|
||||
@@ -1,27 +0,0 @@
|
||||
import setuptools
|
||||
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
setuptools.setup(
|
||||
name="fsner",
|
||||
version="0.0.1",
|
||||
author="msi sayef",
|
||||
author_email="msi.sayef@gmail.com",
|
||||
description="Few-shot Named Entity Recognition",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/huggingface/transformers/tree/main/examples/research_projects/fsner",
|
||||
project_urls={
|
||||
"Bug Tracker": "https://github.com/huggingface/transformers/issues",
|
||||
},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
package_dir={"": "src"},
|
||||
packages=setuptools.find_packages(where="src"),
|
||||
python_requires=">=3.6",
|
||||
install_requires=["torch>=1.9.0", "transformers>=4.9.2"],
|
||||
)
|
||||
@@ -1,5 +0,0 @@
|
||||
from .model import FSNERModel
|
||||
from .tokenizer_utils import FSNERTokenizerUtils
|
||||
|
||||
|
||||
__all__ = ["FSNERModel", "FSNERTokenizerUtils"]
|
||||
@@ -1,80 +0,0 @@
|
||||
import torch
|
||||
|
||||
from transformers import AutoModel
|
||||
|
||||
|
||||
class FSNERModel(torch.nn.Module):
|
||||
"""
|
||||
The FSNER model implements a few-shot named entity recognition method from the paper `Example-Based Named Entity Recognition <https://arxiv.org/abs/2008.10570>`__ by
|
||||
Morteza Ziyadi, Yuting Sun, Abhishek Goswami, Jade Huang, Weizhu Chen. To identify entity spans in a new domain, it
|
||||
uses a train-free few-shot learning approach inspired by question-answering.
|
||||
"""
|
||||
|
||||
def __init__(self, pretrained_model_name_or_path="sayef/fsner-bert-base-uncased"):
|
||||
super(FSNERModel, self).__init__()
|
||||
|
||||
self.bert = AutoModel.from_pretrained(pretrained_model_name_or_path, return_dict=True)
|
||||
self.cos = torch.nn.CosineSimilarity(3, 1e-08)
|
||||
self.softmax = torch.nn.Softmax(dim=1)
|
||||
|
||||
def BERT(self, **inputs):
|
||||
return self.bert(**inputs).last_hidden_state
|
||||
|
||||
def VectorSum(self, token_embeddings):
|
||||
return token_embeddings.sum(2, keepdim=True)
|
||||
|
||||
def Atten(self, q_rep, S_rep, T=1):
|
||||
return self.softmax(T * self.cos(q_rep, S_rep))
|
||||
|
||||
def forward(self, W_query, W_supports):
|
||||
"""
|
||||
Find scores of each token being start and end token for an entity.
|
||||
Args:
|
||||
W_query (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of query sequence tokens in the vocabulary.
|
||||
W_supports (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of support sequence tokens in the vocabulary.
|
||||
Returns:
|
||||
p_start (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Scores of each token as
|
||||
being start token of an entity
|
||||
p_end (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Scores of each token as
|
||||
being end token of an entity
|
||||
"""
|
||||
|
||||
support_sizes = W_supports["sizes"].tolist()
|
||||
start_token_id = W_supports["start_token_id"].item()
|
||||
end_token_id = W_supports["end_token_id"].item()
|
||||
|
||||
del W_supports["sizes"]
|
||||
del W_supports["start_token_id"]
|
||||
del W_supports["end_token_id"]
|
||||
|
||||
q = self.BERT(**W_query)
|
||||
S = self.BERT(**W_supports)
|
||||
|
||||
p_starts = None
|
||||
p_ends = None
|
||||
|
||||
start_token_masks = W_supports["input_ids"] == start_token_id
|
||||
end_token_masks = W_supports["input_ids"] == end_token_id
|
||||
|
||||
for i, size in enumerate(support_sizes):
|
||||
if i == 0:
|
||||
s = 0
|
||||
else:
|
||||
s = support_sizes[i - 1]
|
||||
|
||||
s_start = S[s : s + size][start_token_masks[s : s + size]]
|
||||
s_end = S[s : s + size][end_token_masks[s : s + size]]
|
||||
|
||||
p_start = torch.matmul(q[i], s_start.T).sum(1).softmax(0)
|
||||
p_end = torch.matmul(q[i], s_end.T).sum(1).softmax(0)
|
||||
|
||||
if p_starts is not None:
|
||||
p_starts = torch.vstack((p_starts, p_start))
|
||||
p_ends = torch.vstack((p_ends, p_end))
|
||||
else:
|
||||
p_starts = p_start
|
||||
p_ends = p_end
|
||||
|
||||
return p_starts, p_ends
|
||||
@@ -1,102 +0,0 @@
|
||||
import torch
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
class FSNERTokenizerUtils:
|
||||
def __init__(self, pretrained_model_name_or_path):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
|
||||
|
||||
def tokenize(self, x):
|
||||
"""
|
||||
Wrapper function for tokenizing query and supports
|
||||
Args:
|
||||
x (`List[str] or List[List[str]]`):
|
||||
List of strings for query or list of lists of strings for supports.
|
||||
Returns:
|
||||
`transformers.tokenization_utils_base.BatchEncoding` dict with additional keys and values for start_token_id, end_token_id and sizes of example lists for each entity type
|
||||
"""
|
||||
|
||||
if isinstance(x, list) and all(isinstance(_x, list) for _x in x):
|
||||
d = None
|
||||
for l in x:
|
||||
t = self.tokenizer(
|
||||
l,
|
||||
padding="max_length",
|
||||
max_length=384,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
t["sizes"] = torch.tensor([len(l)])
|
||||
if d is not None:
|
||||
for k in d.keys():
|
||||
d[k] = torch.cat((d[k], t[k]), 0)
|
||||
else:
|
||||
d = t
|
||||
|
||||
d["start_token_id"] = torch.tensor(self.tokenizer.convert_tokens_to_ids("[E]"))
|
||||
d["end_token_id"] = torch.tensor(self.tokenizer.convert_tokens_to_ids("[/E]"))
|
||||
|
||||
elif isinstance(x, list) and all(isinstance(_x, str) for _x in x):
|
||||
d = self.tokenizer(
|
||||
x,
|
||||
padding="max_length",
|
||||
max_length=384,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception(
|
||||
"Type of parameter x was not recognized! Only `list of strings` for query or `list of lists of"
|
||||
" strings` for supports are supported."
|
||||
)
|
||||
|
||||
return d
|
||||
|
||||
def extract_entity_from_scores(self, query, W_query, p_start, p_end, thresh=0.70):
|
||||
"""
|
||||
Extracts entities from query and scores given a threshold.
|
||||
Args:
|
||||
query (`List[str]`):
|
||||
List of query strings.
|
||||
W_query (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of query sequence tokens in the vocabulary.
|
||||
p_start (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
||||
Scores of each token as being start token of an entity
|
||||
p_end (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
||||
Scores of each token as being end token of an entity
|
||||
thresh (`float`):
|
||||
Score threshold value
|
||||
Returns:
|
||||
A list of lists of tuples(decoded entity, score)
|
||||
"""
|
||||
|
||||
final_outputs = []
|
||||
for idx in range(len(W_query["input_ids"])):
|
||||
start_indexes = end_indexes = range(p_start.shape[1])
|
||||
|
||||
output = []
|
||||
for start_id in start_indexes:
|
||||
for end_id in end_indexes:
|
||||
if start_id < end_id:
|
||||
output.append(
|
||||
(
|
||||
start_id,
|
||||
end_id,
|
||||
p_start[idx][start_id].item(),
|
||||
p_end[idx][end_id].item(),
|
||||
)
|
||||
)
|
||||
|
||||
output.sort(key=lambda tup: (tup[2] * tup[3]), reverse=True)
|
||||
temp = []
|
||||
for k in range(len(output)):
|
||||
if output[k][2] * output[k][3] >= thresh:
|
||||
c_start_pos, c_end_pos = output[k][0], output[k][1]
|
||||
decoded = self.tokenizer.decode(W_query["input_ids"][idx][c_start_pos:c_end_pos])
|
||||
temp.append((decoded, output[k][2] * output[k][3]))
|
||||
|
||||
final_outputs.append(temp)
|
||||
|
||||
return final_outputs
|
||||
@@ -1,100 +0,0 @@
|
||||
|
||||
# Information Gain Filtration(IGF)
|
||||
|
||||
Authors @Tuko @mraunak
|
||||
|
||||
This folder contains the code how to implement IGF for finetuning on GPT-2.
|
||||
|
||||
## What is IGF?
|
||||
|
||||
Here we present a general fine-tuning method that we call information gain filtration for improving the overall training efficiency and final
|
||||
performance of language model fine-tuning(see paper below). The method is an alternative fine-tuning method that trains
|
||||
a secondary model (e.g., a simple convolutional network) to predict the amount of information
|
||||
gained over a given pre-trained model. The secondary model is lightweight and trained to
|
||||
predict the Information Gain measure. Information Gain is defined as the change in a loss
|
||||
function for a model before and after an SGD update with a sample (Equation X in the paper).
|
||||
A small subset of the training set named the “objective” set, is used to measure information
|
||||
gain on the pre-trained model, and consequently to train the secondary model. After
|
||||
training, the model is used for filtering samples for the fine-tuning process. Therefore,
|
||||
a high information gain value would suggest a sample is informative, whereas a low value
|
||||
would suggest a non-informative sample that should be filtered out. Thus, a thresholding
|
||||
strategy is defined to select informative samples. With such a strategy, samples are filtered
|
||||
and once enough samples are selected to form a mini-batch and a usual fine-tuning/optimization
|
||||
step is applied. The filtration process is repeated until the fine-tuning process is over.
|
||||
|
||||
Paper [Selecting Informative Contexts Improves Language Model Finetuning](https://arxiv.org/abs/2005.00175)
|
||||
|
||||
# Results
|
||||
|
||||
Several experiments were conducted to show the robustness of the IGF method versus the
|
||||
standard fine-tuning process. For example, we achieve a median perplexity of 54.0 on the
|
||||
Books dataset compared to 57.3 for standard fine-tuning on GPT-2 Small. The code was
|
||||
implemented using the Transformers library and Pytorch. While the method may seem more
|
||||
expensive, we saw enough evidence that it may lead to a performance benefit in the final models.
|
||||
|
||||

|
||||
|
||||
Figure 1: Comparing IGF to Standard Fine-tuning:
|
||||
IGF with constant (p < 10−3 , t-test) and shifting(p < 10−6 , t-test) thresholding significantly outperform standard fine-tuning. The left-hand figure shows
|
||||
test-set perplexity after each fine-tuning batch, averaged over 50 runs (error bars denote ± one standard error). The right-hand figure shows the perplexity of each
|
||||
method after 60 batches. IGF with shifting thresholding (red) clearly improves over standard batched fine-tuning with Adam
|
||||
|
||||
## How to use this project?
|
||||
|
||||
To fine-tune a transformer model with IGF on a language modeling task, use the following script:
|
||||
|
||||
- `model_name_or_path`: Path to pretrained model or model identifier from huggingface.co/models
|
||||
- `data_file`: A jbl file containing tokenized data which can be split as objective dataset,
|
||||
train_dataset and test_dataset
|
||||
- `igf_data_file`: A jbl file containing the context and information gain pairs to train secondary learner.
|
||||
- `context_len`: The maximum total input sequence length after tokenization. Sequences longer
|
||||
than this will be truncated, sequences shorter will be padded.
|
||||
- `size_objective_set`: Number of articles that are long enough to be used as our objective set"
|
||||
- `min_len`: The minimum length of the article to be used as objective set
|
||||
- `trim`: Truncate the example if it exceeds context length
|
||||
- `eval_freq`: Secondary model evaluation can be triggered at eval_freq
|
||||
- `max_steps`: To calculate training epochs
|
||||
- `number`: The number of examples split to be used as objective_set/test_data
|
||||
- `secondary_learner_batch_size`: The batch size of training data for secondary learner
|
||||
- `secondary_learner_max_epochs`: The number of epochs to train secondary learner
|
||||
- `recopy_model`: Reset the model to the original pretrained GPT-2 weights after each iteration
|
||||
- `eval_interval`: Decay the selectivity of our secondary learner filter from"
|
||||
1 standard deviation above average to 1 below average after eval_interval(10) batches"
|
||||
|
||||
|
||||
```python
|
||||
python run_clm_igf.py\
|
||||
--model_name_or_path "openai-community/gpt2" \
|
||||
--data_file="data/tokenized_stories_train_wikitext103" \
|
||||
--igf_data_file="data/IGF_values" \
|
||||
--context_len 32 \
|
||||
--size_objective_set 100 \
|
||||
--min_len 1026 \
|
||||
--trim True \
|
||||
--eval_freq 100 \
|
||||
--max_steps 1000 \
|
||||
--secondary_learner_batch_size 128 \
|
||||
--secondary_learner_max_epochs 15 \
|
||||
--number 100 \
|
||||
--recopy_model \
|
||||
--eval_interval 10 \
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
If you find the resource useful, please cite the following paper
|
||||
|
||||
```bibtex
|
||||
@inproceedings{antonello-etal-2021-selecting,
|
||||
title = "Selecting Informative Contexts Improves Language Model Fine-tuning",
|
||||
author = "Antonello, Richard and Beckage, Nicole and Turek, Javier and Huth, Alexander",
|
||||
booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
|
||||
month = aug,
|
||||
year = "2021",
|
||||
address = "Online",
|
||||
publisher = "Association for Computational Linguistics",
|
||||
url = "https://aclanthology.org/2021.acl-long.87",
|
||||
doi = "10.18653/v1/2021.acl-long.87",
|
||||
pages = "1072--1085",
|
||||
}
|
||||
```
|
||||
@@ -1,416 +0,0 @@
|
||||
# Copyright 2022 - Intel Corp. All rights reserved.
|
||||
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Backage
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
For reproducible training
|
||||
|
||||
Args:
|
||||
seed: A seed for reproducible training
|
||||
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def compute_perplexity(model, test_data, context_len):
|
||||
"""
|
||||
Computes perplexity of the transformer model on data in test_data
|
||||
|
||||
Args:
|
||||
model: Pre-trained GPT2 model
|
||||
test_data: Data on which perplexity calculation is required
|
||||
context_len: The maximum total input sequence length after tokenization. Sequences longer
|
||||
than this will be truncated, sequences shorter will be padded
|
||||
|
||||
Returns:
|
||||
Perplexity on input test data
|
||||
|
||||
"""
|
||||
|
||||
model.eval()
|
||||
device = next(model.parameters()).device
|
||||
eval_batch_size = 1
|
||||
context = torch.zeros((eval_batch_size, context_len), dtype=torch.long, device=device)
|
||||
eval_dataloader = DataLoader(test_data, shuffle=False, batch_size=eval_batch_size)
|
||||
eval_loss = torch.zeros(1, device=device)
|
||||
nb_eval_examples = 0
|
||||
for batch in eval_dataloader:
|
||||
batch.to(device)
|
||||
# pad
|
||||
context.zero_()
|
||||
for i in range(eval_batch_size):
|
||||
context[i, :] = batch[i]
|
||||
outputs = model(context, labels=context)
|
||||
eval_loss += outputs[0].sum().item()
|
||||
nb_eval_examples += batch.size(0)
|
||||
eval_loss = eval_loss / nb_eval_examples
|
||||
perplexity = torch.exp(eval_loss)
|
||||
model.train()
|
||||
return perplexity
|
||||
|
||||
|
||||
def load_gpt2(model_name="openai-community/gpt2"):
|
||||
"""
|
||||
load original openai-community/gpt2 and save off for quicker loading
|
||||
|
||||
Args:
|
||||
model_name: GPT-2
|
||||
|
||||
Returns:
|
||||
GPT-2 model
|
||||
|
||||
"""
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained(model_name, output_hidden_states=True)
|
||||
torch.save(model.state_dict(), model_name + "local.pt")
|
||||
return model
|
||||
|
||||
|
||||
def recopy_gpt2(orig_model, device, max_steps):
|
||||
"""
|
||||
Reset the model to the original pretrained GPT-2 weights after each iteration
|
||||
|
||||
Args:
|
||||
orig_model: Original pretrained GPT-2 model imported from Transformers library
|
||||
device: CPU/GPU
|
||||
max_steps: number of training steps
|
||||
|
||||
Returns:
|
||||
Original PreTrained GPT-2 model,
|
||||
lm_optimizer: Adam optimizer with Decoupled weight decay
|
||||
lm_scheduler: linear scheduler with the appropriate schedule
|
||||
|
||||
"""
|
||||
model = copy.deepcopy(orig_model)
|
||||
model.to(device)
|
||||
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
lm_optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
|
||||
lm_scheduler = get_linear_schedule_with_warmup(lm_optimizer, 0, max_steps)
|
||||
torch.cuda.empty_cache()
|
||||
return model, lm_optimizer, lm_scheduler
|
||||
|
||||
|
||||
def intermittent_save(contexts, real_perps, past_perps, filename):
|
||||
"""
|
||||
save the perplexity differences to filename
|
||||
|
||||
Args:
|
||||
contexts: Example on which the perplexity is calculated
|
||||
real_perps: Perplexity after back-propagating on the selected context
|
||||
past_perps: Perplexity of model before training on the context
|
||||
filename: File to store perplexity differences
|
||||
|
||||
Returns:
|
||||
file with perplexity differences
|
||||
|
||||
"""
|
||||
# save the perplexity differences to filename
|
||||
avg = np.array(real_perps).mean()
|
||||
std = np.array(real_perps).std()
|
||||
perp_diff = (real_perps - avg) / std
|
||||
data_final = list(zip(contexts, perp_diff, past_perps))
|
||||
joblib.dump(data_final, filename)
|
||||
|
||||
|
||||
def collect_objective_set(
|
||||
model,
|
||||
orig_perp,
|
||||
context_len,
|
||||
train_data,
|
||||
objective_set,
|
||||
max_steps,
|
||||
device,
|
||||
filename="dev.jbl",
|
||||
recopy_model=recopy_gpt2,
|
||||
):
|
||||
"""
|
||||
Collect individual IGF values from pre-trained transformer model
|
||||
max_steps samples of training data to train secondary model
|
||||
|
||||
Args:
|
||||
model: Pre-trained GPT2 model
|
||||
orig_perp: Perplexity of original pretrained GPT-2 model
|
||||
context_len: The maximum total input sequence length after tokenization. Sequences longer
|
||||
than this will be truncated, sequences shorter will be padded
|
||||
train_data: Data to train model
|
||||
objective_set: Contexts used to create (X,IG(X)) pairs which is the training data for secondary learner
|
||||
max_steps: To calculate training epochs of model
|
||||
device: GPU/CPU
|
||||
filename: To store intermediate perplexity differences
|
||||
recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
|
||||
|
||||
Returns:
|
||||
file stored intermediate perplexity differences in intermediate stages
|
||||
|
||||
"""
|
||||
|
||||
# initialize variables to record relevant information
|
||||
contexts = []
|
||||
real_perps = []
|
||||
past_perps = []
|
||||
|
||||
# Initialize the transformer model
|
||||
orig_model = copy.deepcopy(model)
|
||||
orig_model.to(device="cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Compute perplexity of initial transformer model for comparison
|
||||
model.train()
|
||||
model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)
|
||||
|
||||
for step in tqdm(range(max_steps)):
|
||||
context = torch.zeros((1, context_len), dtype=torch.long, device=device)
|
||||
story = random.choice(train_data)
|
||||
start = random.randint(0, len(story[0]) - context_len - 1)
|
||||
context[0, :] = story[0][start : start + context_len]
|
||||
lm_optimizer.zero_grad()
|
||||
outputs = model(context, labels=context)
|
||||
lm_loss = outputs[0]
|
||||
past_perp = compute_perplexity(model, context, context_len)
|
||||
model.train()
|
||||
lm_loss.backward()
|
||||
# Do LM backprop
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
|
||||
lm_optimizer.step()
|
||||
lm_scheduler.step() # Update learning rate schedule
|
||||
|
||||
# Compute perplexity after back-propagating on the selected context
|
||||
real_perp = compute_perplexity(model, objective_set, context_len)
|
||||
|
||||
# Periodically save the stored (X, IG(X)) pairs
|
||||
if step % 1000 == 0 and step > 1:
|
||||
intermittent_save(contexts, real_perps, past_perps, filename)
|
||||
|
||||
# Reset the pretrained model to the original pretrained GPT-2 weights after each iteration
|
||||
model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)
|
||||
|
||||
past_perps.append(past_perp.item())
|
||||
real_perps.append(orig_perp - real_perp.item())
|
||||
contexts.append(np.array(context.cpu()))
|
||||
|
||||
intermittent_save(contexts, real_perps, past_perps, filename)
|
||||
|
||||
|
||||
def generate_datasets(
|
||||
context_len, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
|
||||
):
|
||||
"""
|
||||
Generate objective set and training set
|
||||
|
||||
Args:
|
||||
context_len: The maximum total input sequence length after tokenization. Sequences longer
|
||||
than this will be truncated, sequences shorter will be padded
|
||||
file: Tokenized data split into training set and objective set
|
||||
number: size of objective dataset
|
||||
min_len: minimum length of a context in objective set
|
||||
trim: If True truncate the context if it exceeds context length
|
||||
|
||||
Returns:
|
||||
Generated objective set and training data
|
||||
|
||||
|
||||
"""
|
||||
# Generate objective set and training set
|
||||
# Designate the first number (100) articles that are long enough to be used
|
||||
# as our objective set, rest (that are long enough) are training data for
|
||||
# secondary learner
|
||||
|
||||
data = joblib.load(file)
|
||||
print("data loaded")
|
||||
objective_set = []
|
||||
if trim:
|
||||
for i, example in enumerate(data):
|
||||
if len(example[0]) > min_len:
|
||||
start = random.randint(0, len(example[0]) - context_len - 1)
|
||||
objective_set.append(example[0, start : start + context_len])
|
||||
if len(objective_set) >= number:
|
||||
break
|
||||
train_data = []
|
||||
for j in range(i + 1, len(data)):
|
||||
if len(data[j][0]) > min_len:
|
||||
train_data.append(data[j])
|
||||
else:
|
||||
objective_set = data[0:number]
|
||||
train_data = data[number:]
|
||||
|
||||
joblib.dump(objective_set, "objective_set.jbl")
|
||||
print("objective set saved")
|
||||
return train_data, objective_set
|
||||
|
||||
|
||||
def train_secondary_learner(
|
||||
secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
|
||||
):
|
||||
"""
|
||||
Train the secondary learner (igf_model)
|
||||
|
||||
Args:
|
||||
secondary_learner: secondary learner
|
||||
train_dataset: data to train secondary learner
|
||||
max_epochs: number of epochs to train secondary learner
|
||||
batch_size: batch size of training data of secondary learner
|
||||
eval_freq: secondary model evaluation can be triggered at eval_freq
|
||||
igf_model_path: path to store trained secondary learner
|
||||
|
||||
Returns:
|
||||
Trained secondary learner
|
||||
|
||||
"""
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
# We will use the first 512 pairs from our dataset as a test set for
|
||||
# our secondary learner and the rest to train
|
||||
test_dataset = train_dataset[:512]
|
||||
train_dataset = train_dataset[512:]
|
||||
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
|
||||
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
|
||||
|
||||
# secondary learner model set up
|
||||
loss = nn.MSELoss()
|
||||
test_loss = nn.MSELoss(reduction="sum")
|
||||
secondary_learner.to(device)
|
||||
q_optimizer = torch.optim.Adam(secondary_learner.parameters(), lr=0.00001)
|
||||
secondary_learner.train()
|
||||
|
||||
# TODO in original code this is written as number of actual batches seen
|
||||
# not number of items seen but other places it is number of items instead.
|
||||
# improve consistency! changed this to epochs for clarity
|
||||
best_test_loss = float("inf")
|
||||
# Iterate through batches until we've used max_steps batches
|
||||
for epoch in range(int(max_epochs)):
|
||||
tr_q_loss = 0.0
|
||||
secondary_learner.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
context = batch[0].to(device)
|
||||
real_q = batch[1].to(device)
|
||||
predicted_q = secondary_learner(context)
|
||||
q_optimizer.zero_grad()
|
||||
q_loss = loss(predicted_q, real_q.float())
|
||||
q_loss.backward()
|
||||
q_optimizer.step()
|
||||
tr_q_loss += q_loss.item()
|
||||
|
||||
# model trains fairly quickly so we won't wait for a full epoch
|
||||
# eval is triggered at eval_freq and end of epochs
|
||||
if (step % eval_freq == 0 and step > 0) or ((step + 1) == len(train_dataloader)):
|
||||
tr_loss = tr_q_loss / (step + 1)
|
||||
|
||||
secondary_learner.eval()
|
||||
q_loss2 = 0.0
|
||||
sum_q2 = 0.0
|
||||
predicted = []
|
||||
actual = []
|
||||
# Compute performance of the secondary learner after this batch
|
||||
for step2, batch2 in enumerate(test_dataloader):
|
||||
features2 = batch2[0].to(device)
|
||||
real_q2 = batch2[1].to(device)
|
||||
predicted_q2 = secondary_learner(features2)
|
||||
q_loss2 += test_loss(predicted_q2, real_q2).item()
|
||||
sum_q2 += torch.sum(predicted_q2).item()
|
||||
for ei, i in enumerate(predicted_q2.cpu().detach().numpy()):
|
||||
predicted.append(i.item())
|
||||
for ei, i in enumerate(real_q2.cpu().detach().numpy()):
|
||||
actual.append(i.item())
|
||||
|
||||
q_loss2 /= len(test_dataset)
|
||||
print(
|
||||
"Epoch: ",
|
||||
epoch,
|
||||
"step: ",
|
||||
step,
|
||||
"Avg. q:",
|
||||
sum_q2 / len(test_dataset),
|
||||
"Train Loss: ",
|
||||
tr_loss,
|
||||
"Test Loss: ",
|
||||
q_loss2,
|
||||
)
|
||||
if q_loss2 < best_test_loss:
|
||||
joblib.dump((predicted, actual), "pred_vs_actual.jbl")
|
||||
torch.save(secondary_learner.state_dict(), igf_model_path)
|
||||
best_test_loss = q_loss2
|
||||
|
||||
secondary_learner.train()
|
||||
return secondary_learner
|
||||
|
||||
|
||||
class SecondaryLearner(nn.Module):
|
||||
"""
|
||||
Our secondary learner
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
"""
|
||||
We use a simple convolutional network as our secondary learner
|
||||
|
||||
Args:
|
||||
model: Pre-trained GPT2 model
|
||||
"""
|
||||
# embeddings are from the pretrained model
|
||||
super(SecondaryLearner, self).__init__()
|
||||
self.embeddings = model.transformer.wte
|
||||
self.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
|
||||
self.conv = nn.Conv1d(self.embeddings.weight.size(1), 256, 3, padding=1)
|
||||
self.fc = nn.Sequential(nn.Linear(256, 32), nn.Dropout(p=0.1), nn.Linear(32, 32), nn.Linear(32, 1))
|
||||
|
||||
def forward(self, context):
|
||||
"""
|
||||
Forward pass through the secondary learner
|
||||
|
||||
Args:
|
||||
context: Context input to the secondary learner
|
||||
|
||||
Returns:
|
||||
tensor after squeeze operation
|
||||
|
||||
"""
|
||||
pooled = torch.max(self.conv(self.embeddings(context).squeeze(1).transpose(1, 2)), 2)[0]
|
||||
qs = self.fc(pooled)
|
||||
return qs.squeeze(1)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, state_path, model):
|
||||
"""
|
||||
Load the secondary learner
|
||||
|
||||
Args:
|
||||
state_path: Path to save secondary learner
|
||||
model: Pretrained GPT-2
|
||||
|
||||
Returns:
|
||||
secondary learner
|
||||
"""
|
||||
|
||||
secondary_learner = cls(model) # this calls __init__
|
||||
state_dict = torch.load(state_path)
|
||||
secondary_learner.load_state_dict(state_dict)
|
||||
secondary_learner.embeddings = model.transformer.wte
|
||||
secondary_learner.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
|
||||
return secondary_learner
|
||||
@@ -1,6 +0,0 @@
|
||||
matplotlib
|
||||
numpy>=1.17.2
|
||||
joblib>=0.13.2
|
||||
scipy
|
||||
torch>=1.10.1
|
||||
transformers>=3.5
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 34 KiB |
@@ -1,450 +0,0 @@
|
||||
# Copyright 2022 - Intel Corp. All rights reserved.
|
||||
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage
|
||||
|
||||
"""
|
||||
Implementation of a new method for fine-tuning transformer models that we call
|
||||
Information Gain Filtration 'IGF' on WikiText data set and compared the results
|
||||
with the standard fine-tuning method
|
||||
|
||||
Steps followed in the code:
|
||||
|
||||
1) Generate a objective dataset of pairs (X, IG(X)). IG(X)--Informativeness of context 'X'.
|
||||
Our IG (information gain) model is learning to predict the ‘informativeness’ of a particular
|
||||
context. Informativeness is the change in metric between the model’s accuracy on an
|
||||
objective set before and after seeing that context. For casual language modeling, the
|
||||
metric is perplexity.
|
||||
|
||||
2) A secondary learner is trained to infer a function approximation for IG using the dataset
|
||||
created in (1).
|
||||
|
||||
3) The learner created in (2) is used to inform the fine-tuning process and filter out low informative samples.
|
||||
|
||||
Last, a plot is generated to compare the performance of IGF to standard fine-tuning without any filtering
|
||||
|
||||
"""
|
||||
|
||||
# Prerequisite libraries:
|
||||
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import torch
|
||||
from igf.igf import (
|
||||
SecondaryLearner,
|
||||
collect_objective_set,
|
||||
compute_perplexity,
|
||||
generate_datasets,
|
||||
load_gpt2,
|
||||
recopy_gpt2,
|
||||
set_seed,
|
||||
train_secondary_learner,
|
||||
)
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
|
||||
def generate_n_pairs(
|
||||
context_len=32,
|
||||
max_steps=10,
|
||||
size_objective_set=100,
|
||||
min_len=1026,
|
||||
trim=True,
|
||||
data_file="data/tokenized_stories_train_wikitext103.jbl",
|
||||
igf_data_file="igf_context_pairs.jbl",
|
||||
):
|
||||
"""
|
||||
Collecting *n* pairs for training the secondary learner
|
||||
Args:
|
||||
context_len: The maximum total input sequence length after tokenization. Sequences longer
|
||||
than this will be truncated, sequences shorter will be padded
|
||||
max_steps: To calculate training epochs of secondary learner
|
||||
size_objective_set: size of objective data set used to create (X,IG(X)) pairs which is the training data for secondary learner
|
||||
min_len: The minimum length of the article to be used as objective set
|
||||
trim: If True truncate the context if it exceeds context length
|
||||
data_file: Tokenized data set split for training and evaluation of model
|
||||
igf_data_file: file to store (I,IG(X)) paired data set to train secondary learner
|
||||
|
||||
Returns:
|
||||
Data stored in igf_data_file
|
||||
|
||||
"""
|
||||
# generates same data everytime
|
||||
set_seed(3)
|
||||
# generate train_data and objective_set
|
||||
train_data, objective_set = generate_datasets(
|
||||
context_len, data_file, number=size_objective_set, min_len=1026, trim=True
|
||||
)
|
||||
# keeps model same across runs
|
||||
set_seed(4)
|
||||
# model, lm_optimizer, lm_scheduler = recopy_gpt2(model, device, max_steps) # store original model weights
|
||||
# can we train on GPU?
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# load pretrained model
|
||||
model = load_gpt2("openai-community/gpt2").to(device)
|
||||
print("computing perplexity on objective set")
|
||||
orig_perp = compute_perplexity(model, objective_set, context_len).item()
|
||||
print("perplexity on objective set:", orig_perp)
|
||||
|
||||
# collect igf pairs and save to file demo.jbl
|
||||
collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)
|
||||
|
||||
# clean up, delete model and data we don't need anymore
|
||||
del model, train_data, objective_set
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def training_secondary_learner(
|
||||
secondary_learner_train_data,
|
||||
secondary_learner_max_epochs=15,
|
||||
secondary_learner_batch_size=128,
|
||||
eval_freq=100,
|
||||
igf_model_path="igf_model.pt",
|
||||
):
|
||||
"""
|
||||
Train the secondary learner
|
||||
|
||||
Args:
|
||||
secondary_learner_train_data: Data set with (X,IG(X)) pairs to train secondary learner where IG(X) - measure of informativeness and X- context
|
||||
secondary_learner_max_epochs: Number of epochs to train secondary learner
|
||||
secondary_learner_batch_size: Batch size to train secondary learner
|
||||
eval_freq (object): secondary model evaluation can be triggered at eval_freq
|
||||
igf_model_path: path to store trained secondary learner
|
||||
|
||||
Returns:
|
||||
Trained secondary learner
|
||||
"""
|
||||
|
||||
set_seed(42)
|
||||
|
||||
# Load pre-trained model
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
|
||||
|
||||
# Initialize secondary learner to use embedding weights of model
|
||||
secondary_learner = SecondaryLearner(model)
|
||||
|
||||
# Train secondary learner
|
||||
secondary_learner = train_secondary_learner(
|
||||
secondary_learner,
|
||||
secondary_learner_train_data,
|
||||
max_epochs=secondary_learner_max_epochs,
|
||||
batch_size=secondary_learner_batch_size,
|
||||
eval_freq=100,
|
||||
igf_model_path=igf_model_path,
|
||||
)
|
||||
|
||||
del model, secondary_learner_train_data
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return secondary_learner
|
||||
|
||||
|
||||
def finetune(
|
||||
model,
|
||||
train_dataset,
|
||||
test_dataset,
|
||||
context_len=32,
|
||||
max_steps=1000,
|
||||
batch_size=16,
|
||||
threshold=1.0,
|
||||
recopy_model=recopy_gpt2,
|
||||
secondary_learner=None,
|
||||
eval_interval=10,
|
||||
finetuned_model_name="openai-community/gpt2_finetuned.pt",
|
||||
):
|
||||
"""
|
||||
fine-tune with IGF if secondary_learner is not None, else standard fine-tuning
|
||||
|
||||
Args:
|
||||
model: pre-trained GPT-2 model
|
||||
train_dataset: Data set to train GPT-2 model
|
||||
test_dataset: Evaluate GPT-2 model
|
||||
context_len: The maximum total input sequence length after tokenization. Sequences longer
|
||||
than this will be truncated, sequences shorter will be padded
|
||||
max_steps: To calculate training epochs
|
||||
batch_size: Batch size to train GPT-2 model
|
||||
threshold: The threshold value used by secondary learner to filter the train_data and allow only"
|
||||
informative data as input to the model
|
||||
recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
|
||||
secondary_learner: Selection of IGF as fine-tuning method if not None
|
||||
eval_interval: number of batches after which decay the selectivity of our secondary learner filter from
|
||||
1 standard deviation above average to 1 below average
|
||||
fine-tuned_model_name: name of the final final-tuned GPT-2 model
|
||||
|
||||
Returns:
|
||||
Fine-tuned GPT-2 model
|
||||
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
train_sampler = RandomSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler)
|
||||
|
||||
num_train_epochs = max_steps // (len(train_dataset)) + 1
|
||||
global_step = 0
|
||||
context = torch.zeros((1, context_len), dtype=torch.long, device=device)
|
||||
model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)
|
||||
|
||||
model.train()
|
||||
if secondary_learner is not None:
|
||||
secondary_learner.to(device)
|
||||
secondary_learner.eval()
|
||||
contexts = []
|
||||
examples = 0
|
||||
|
||||
observed_qs = []
|
||||
test_perps = []
|
||||
|
||||
# Compute the performance of the transformer model at the beginning
|
||||
real_perp = compute_perplexity(model, test_dataset, context_len)
|
||||
test_perps.append(real_perp)
|
||||
print("Test perplexity, step", global_step, ":", real_perp)
|
||||
for epoch in range(int(num_train_epochs)):
|
||||
for step, example in enumerate(train_dataloader):
|
||||
torch.cuda.empty_cache()
|
||||
start = random.randint(0, example.size(2) - context_len - 1)
|
||||
context[0, :] = example[0, 0, start : start + context_len]
|
||||
lm_optimizer.zero_grad()
|
||||
outputs = model(context, labels=context)
|
||||
do_backprop = True
|
||||
|
||||
if secondary_learner is not None:
|
||||
predicted_q = secondary_learner.forward(
|
||||
torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
|
||||
)[0].item()
|
||||
observed_qs.append(float(predicted_q))
|
||||
|
||||
# Here we implement the simple non-constant threshold for the predicted IG(X) value
|
||||
# We will decay the selectivity of our secondary learner filter from
|
||||
# 1 standard deviation above average to 1 below average after 10 batches.
|
||||
|
||||
if global_step == 10:
|
||||
threshold = -1
|
||||
if predicted_q < threshold:
|
||||
do_backprop = False
|
||||
|
||||
# If we passed the filter, add the context to the batch!
|
||||
if do_backprop:
|
||||
contexts.append(np.array(context.cpu()))
|
||||
lm_loss = outputs[0]
|
||||
lm_loss.backward()
|
||||
examples += 1
|
||||
|
||||
del outputs
|
||||
|
||||
# Once the batch is filled with enough contexts, backprop on the batch.
|
||||
if examples == batch_size:
|
||||
torch.cuda.empty_cache()
|
||||
examples = 0
|
||||
# Do LM backprop
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
|
||||
lm_optimizer.step()
|
||||
lm_scheduler.step() # Update learning rate schedule
|
||||
global_step += 1
|
||||
# Compute the performance of the transformer model at this batch
|
||||
if global_step % eval_interval == 0:
|
||||
real_perp = compute_perplexity(model, test_dataset, context_len)
|
||||
test_perps.append(real_perp)
|
||||
|
||||
print("Test perplexity, step", global_step, ":", real_perp)
|
||||
# Break out of the loop after 60 batches
|
||||
if max_steps > 0 and global_step > 60:
|
||||
break
|
||||
if max_steps > 0 and global_step > 60:
|
||||
break
|
||||
|
||||
# save finetuned transformer model
|
||||
torch.save(model.state_dict(), finetuned_model_name)
|
||||
torch.cuda.empty_cache()
|
||||
# Do some cleaning up so we can reinitialize for the next run of this function
|
||||
del lm_optimizer
|
||||
del lm_scheduler
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain data files for WikiText.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A jbl file containing tokenized data which can be split as objective dataset, "
|
||||
"train_dataset and test_dataset."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--igf_data_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A jbl file containing the context and information gain pairs to train secondary learner.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the final fine-tuned model is stored.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
|
||||
parser.add_argument(
|
||||
"--context_len",
|
||||
default=32,
|
||||
type=int,
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--size_objective_set",
|
||||
default=100,
|
||||
type=int,
|
||||
help="number of articles that are long enough to be used as our objective set",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
|
||||
)
|
||||
|
||||
parser.add_argument("--max_steps", default=1000, type=int, help="To calculate training epochs")
|
||||
|
||||
parser.add_argument(
|
||||
"--secondary_learner_batch_size",
|
||||
default=128,
|
||||
type=int,
|
||||
help="batch size of training data for secondary learner",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
default=16,
|
||||
type=int,
|
||||
help="batch size of training data of language model(openai-community/gpt2) ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--eval_interval",
|
||||
default=10,
|
||||
type=int,
|
||||
help=(
|
||||
"decay the selectivity of our secondary learner filter from "
|
||||
"1 standard deviation above average to 1 below average after 10 batches"
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min_len", default=1026, type=int, help="The minimum length of the article to be used as objective set"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train secondary learner"
|
||||
)
|
||||
|
||||
parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")
|
||||
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
default=1.0,
|
||||
type=float,
|
||||
help=(
|
||||
"The threshold value used by secondary learner to filter the train_data and allow only"
|
||||
" informative data as input to the model"
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--finetuned_model_name", default="openai-community/gpt2_finetuned.pt", type=str, help="finetuned_model_name"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--recopy_model",
|
||||
default=recopy_gpt2,
|
||||
type=str,
|
||||
help="Reset the model to the original pretrained GPT-2 weights after each iteration",
|
||||
)
|
||||
|
||||
# function calls
|
||||
# Collecting *n* pairs of context and information gain(X, IG(X)) for training the secondary learner
|
||||
generate_n_pairs(
|
||||
context_len=32,
|
||||
max_steps=10,
|
||||
size_objective_set=100,
|
||||
min_len=1026,
|
||||
trim=True,
|
||||
data_file="data/tokenized_stories_train_wikitext103.jbl",
|
||||
igf_data_file="igf_context_pairs.jbl",
|
||||
)
|
||||
|
||||
# Load train data for secondary learner
|
||||
secondary_learner_train_data = joblib.load("data/IGF_values.jbl")
|
||||
|
||||
# Train secondary learner
|
||||
secondary_learner = training_secondary_learner(
|
||||
secondary_learner_train_data,
|
||||
secondary_learner_max_epochs=15,
|
||||
secondary_learner_batch_size=128,
|
||||
eval_freq=100,
|
||||
igf_model_path="igf_model.pt",
|
||||
)
|
||||
|
||||
# load pretrained openai-community/gpt2 model
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
|
||||
set_seed(42)
|
||||
|
||||
# Generate train and test data to train and evaluate openai-community/gpt2 model
|
||||
train_dataset, test_dataset = generate_datasets(
|
||||
context_len=32, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
|
||||
)
|
||||
|
||||
# fine-tuning of the openai-community/gpt2 model using igf (Information Gain Filtration)
|
||||
finetune(
|
||||
model,
|
||||
train_dataset,
|
||||
test_dataset,
|
||||
context_len=32,
|
||||
max_steps=1000,
|
||||
batch_size=16,
|
||||
threshold=1.0,
|
||||
recopy_model=recopy_gpt2,
|
||||
secondary_learner=secondary_learner,
|
||||
eval_interval=10,
|
||||
finetuned_model_name="openai-community/gpt2_finetuned.pt",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,109 +0,0 @@
|
||||
# How to propose a Flax/JAX + Transformers project
|
||||
|
||||
Great that you've opened this document!
|
||||
While we at 🤗 are proposing a couple of projects, we strongly
|
||||
believe that the community can come up with much more **creative**, **fun**, and
|
||||
**impactful** projects on their own. This being said, we are really looking forward
|
||||
to seeing your project proposal!
|
||||
|
||||
## What a project should be about
|
||||
|
||||
The proposed project should fall into the machine learning fields of **Natural Language Processing (NLP)** and/or **Computer Vision (CV)** (possibly also **Speech Recognition (ASR)** depending on whether Speech Recognition models are available in Flax in due time) and aim at solving a specific task.
|
||||
Possible tasks can belong to:
|
||||
|
||||
* text classification
|
||||
* text generation
|
||||
* image recognition
|
||||
* image processing
|
||||
* image captioning
|
||||
* audio classification
|
||||
* and other tasks you can think of!
|
||||
|
||||
The clearer a task is defined, the better your project proposal is.
|
||||
*E.g.* "Using a T5 model to learn grammar correction in French" or "Adapting a pre-trained CLIP model for zero-shot image classification in Spanish" are **well-defined and clear** project proposals, while something like "Train a language model" or "Image classification" are **too vague**.
|
||||
|
||||
There is no limit to your creativity as long as the project is feasible and ethical.
|
||||
The more creative & specific your project proposal, the more interesting it will be,
|
||||
and the more likely will you find motivated team members to work on your project!
|
||||
To get an idea of how to formulate your project proposals, you can browse through
|
||||
existing project proposals on the [forum](https://discuss.huggingface.co/c/flax-jax-projects/22).
|
||||
|
||||
## How to submit a project proposal
|
||||
|
||||
First, you should make sure that you are [logged in](https://huggingface.co/login?sso=bm9uY2U9OTRlNjZjZmZhYjMwMmJmMWMyYjc5MmFiMTMyMzY5ODYmcmV0dXJuX3Nzb191cmw9aHR0cHMlM0ElMkYlMkZkaXNjdXNzLmh1Z2dpbmdmYWNlLmNvJTJGc2Vzc2lvbiUyRnNzb19sb2dpbg%3D%3D&sig=429ad8924bcb33c40f9823027ea749abb55d393f4f58924f36a2dba3ab0a48da) with your Hugging Face account on the forum.
|
||||
|
||||
Second, make sure that your project idea doesn't already exist by checking [existing projects](https://discuss.huggingface.co/c/flax-jax-projects/22).
|
||||
If your project already exists - great! This means that you can comment and improve
|
||||
the existing idea and join the project to form a team! If your project idea already
|
||||
exists for a different language, feel free to submit the same project idea, just in
|
||||
a different language.
|
||||
|
||||
Third, having ensured that your project doesn't exist, click on the *"New Topic"*
|
||||
button on the [Flax/JAX Projects Forum category](https://discuss.huggingface.co/c/flax-jax-projects/22) to create a new project proposal.
|
||||
|
||||
Fourth, make sure that your project proposal includes the following information:
|
||||
|
||||
1. *A clear description of the project*
|
||||
2. *In which language should the project be conducted?* English, German, Chinese, ...? It can also be a multi-lingual project
|
||||
3. *Which model should be used?* If you want to adapt an existing model, you can add the link to one of the 4000 available checkpoints in JAX [here](https://huggingface.co/models?filter=jax) If you want to train a model from scratch, you can simply state the model architecture to be used, *e.g.* BERT, CLIP, etc. You can also base your project on a model that is not part of transformers. For an overview of libraries based on JAX, you can take a look at [awesome-jax](https://github.com/n2cholas/awesome-jax#awesome-jax-). **Note** that for a project that is not based on Transformers it will be more difficult for the 🤗 team to help you. Also have a look at the section [Quickstart Flax & Jax in Transformers](https://github.com/huggingface/transformers/tree/main/examples/research_projects/jax-projects#quickstart-flax-and-jax-in-transformers) to see what model architectures are currently supported in 🤗 Transformers.
|
||||
4. *What data should be used?* It is important to state at least what kind of data you would like to use. Ideally, you can already point to publicly available data or a dataset in the 🤗 Datasets library.
|
||||
5. *Are similar training scripts available in Flax/JAX?* It would be important to find similar training scripts that already exist in Flax/JAX. *E.g.* if you are working on a Seq-to-Seq task, you can make use of the [`run_summarization_flax.py`](https://github.com/huggingface/transformers/blob/main/examples/flax/summarization/run_summarization_flax.py) script which is very similar to any seq2seq training. Also have a look at the section [Quickstart Flax & Jax in Transformers](https://github.com/huggingface/transformers/tree/main/examples/research_projects/jax-projects#quickstart-flax-and-jax-in-transformers) to see what training scripts are currently supported in 🤗 Transformers.
|
||||
6. *(Optionally) What are possible challenges?* List possible difficulties with your project. *E.g.* If you know that training convergence usually takes a lot of time, it is worth stating this here!
|
||||
7. *(Optionally) What is the desired project outcome?* - How would you like to demo your project? One could *e.g.* create a Streamlit application.
|
||||
8. *(Optionally) Links to read upon* - Can you provide any links that would help the reader to better understand your project idea?
|
||||
|
||||
Feel free to copy-paste the following format for your project proposal and fill out the respective sections:
|
||||
|
||||
```
|
||||
# <FILL ME: Name of project>
|
||||
|
||||
<FILL ME: A clear description of the project>
|
||||
|
||||
## 2. Language
|
||||
|
||||
The model will be trained in <FILL ME: which language?>.
|
||||
|
||||
## 3. Model
|
||||
|
||||
<FILL ME: 3. Which model should be used?>
|
||||
|
||||
## 4. Datasets
|
||||
|
||||
<FILL ME: 4. Which data should be used?>
|
||||
|
||||
Possible links to publicly available datasets include:
|
||||
- <FILL ME: Link 1 to dataset>
|
||||
- <FILL ME: Link 2 to dataset>
|
||||
- <FILL ME: Link 3 to dataset>
|
||||
|
||||
## 5. Training scripts
|
||||
|
||||
<FILL ME: 5. Are there publicly available training scripts that can be used/tweaked for the project?>
|
||||
|
||||
We can make use of <FILL ME: link to training script> to train the model.>
|
||||
|
||||
## 6. (Optional) Challenges
|
||||
|
||||
<(Optionally) FILL ME: 6. What are possible challenges?>
|
||||
|
||||
## 7. (Optional) Desired project outcome
|
||||
|
||||
<(Optionally) FILL ME: 7. What is the desired project outcome? A demo?>
|
||||
|
||||
## 8. (Optional) Reads
|
||||
|
||||
The following links can be useful to better understand the project and
|
||||
what has previously been done.
|
||||
|
||||
- <FILL ME: Link 1 to read>
|
||||
- <FILL ME: Link 2 to read>
|
||||
- <FILL ME: Link 3 to read>
|
||||
```
|
||||
|
||||
To see how a proposed project looks like, please have a look at submitted project
|
||||
proposals [here](https://discuss.huggingface.co/c/flax-jax-projects/22).
|
||||
|
||||
## Will my project proposal be selected?
|
||||
|
||||
Having submitted a project proposal, you can now promote your idea in the Slack channel `#flax-jax-community-week` to try to convince other participants to join your project!
|
||||
Once other people have joined your project, one of the organizers (`@Suzana, @valhalla, @osanseviero, @patrickvonplaten`) will officially create a team for your project and add your project to [this google sheet](https://docs.google.com/spreadsheets/d/1GpHebL7qrwJOc9olTpIPgjf8vOS0jNb6zR_B8x_Jtik/edit?usp=sharing).
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,60 +0,0 @@
|
||||
|
||||
Author: [@vasudevgupta7](https://github.com/thevasudevgupta/)
|
||||
|
||||
## Intro
|
||||
|
||||
In this project, we fine-tuned [**BigBird**](https://arxiv.org/abs/2007.14062) on [**natural-questions**](https://huggingface.co/datasets/natural_questions) dataset for **question-answering** task on long documents. **BigBird**, is a **sparse-attention based transformer** which extends Transformer based models, such as BERT to much **longer sequences**.
|
||||
|
||||
Read more about BigBird at https://huggingface.co/blog/big-bird
|
||||
|
||||
## Fine-tuning
|
||||
|
||||
**Setup**
|
||||
|
||||
You need to install jax yourself by following the official docs ([refer this](https://github.com/google/jax#installation)). Other requirements for this project can be installed by running following command:
|
||||
|
||||
```shell
|
||||
pip3 install -qr requirements.txt
|
||||
```
|
||||
|
||||
**Download & prepare dataset**
|
||||
|
||||
The Natural Questions corpus contains questions from real users, and it requires QA systems to read and comprehend an entire Wikipedia article that may or may not contain the answer to the question. This corpus takes ~100 GB on disk. We have used HuggingFace datasets to download & process the dataset.
|
||||
|
||||
```shell
|
||||
# just run following CMD
|
||||
python3 prepare_natural_questions.py
|
||||
|
||||
# this will download the whole dataset from HuggingFace Hub & will make it ready for training
|
||||
# this script takes ~3 hours to process the dataset
|
||||
```
|
||||
|
||||
**Launch Training**
|
||||
|
||||
We have trained on Cloud's TPU v3-8. Each epoch took around 4.5 hours and the model got converged in just 2 epochs. You can see complete training args in [this script](bigbird_flax.py).
|
||||
|
||||
```shell
|
||||
# just run following CMD
|
||||
python3 train.py
|
||||
|
||||
# In case, you want to try hparams tuning, you can run wandb sweep
|
||||
wandb sweep --project=bigbird sweep_flax.yaml
|
||||
wandb agent <agent-id-obtained-by-above-CMD>
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
Our evaluation script is different from the original script and we are evaluating sequences with length up to 4096 for simplicity. We managed to get the **EM score of ~55.2** using our evaluation script.
|
||||
|
||||
```shell
|
||||
# download validation-dataset first
|
||||
mkdir natural-questions-validation
|
||||
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/natural_questions-validation.arrow -P natural-questions-validation
|
||||
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/dataset_info.json -P natural-questions-validation
|
||||
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/state.json -P natural-questions-validation
|
||||
|
||||
# simply run following command
|
||||
python3 evaluate.py
|
||||
```
|
||||
|
||||
You can find our checkpoint on HuggingFace Hub ([see this](https://huggingface.co/vasudevgupta/flax-bigbird-natural-questions)). In case you are interested in PyTorch BigBird fine-tuning, you can refer to [this repository](https://github.com/thevasudevgupta/bigbird).
|
||||
@@ -1,323 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import joblib
|
||||
import optax
|
||||
import wandb
|
||||
from flax import jax_utils, struct, traverse_util
|
||||
from flax.serialization import from_bytes, to_bytes
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import shard
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
|
||||
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule
|
||||
|
||||
|
||||
class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
|
||||
"""
|
||||
BigBirdForQuestionAnswering with CLS Head over the top for predicting category
|
||||
|
||||
This way we can load its weights with FlaxBigBirdForQuestionAnswering
|
||||
"""
|
||||
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
add_pooling_layer: bool = True
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
self.cls = nn.Dense(5, dtype=self.dtype)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
outputs = super().__call__(*args, **kwargs)
|
||||
cls_out = self.cls(outputs[2])
|
||||
return outputs[:2] + (cls_out,)
|
||||
|
||||
|
||||
class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
|
||||
module_class = FlaxBigBirdForNaturalQuestionsModule
|
||||
|
||||
|
||||
def calculate_loss_for_nq(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels):
|
||||
def cross_entropy(logits, labels, reduction=None):
|
||||
"""
|
||||
Args:
|
||||
logits: bsz, seqlen, vocab_size
|
||||
labels: bsz, seqlen
|
||||
"""
|
||||
vocab_size = logits.shape[-1]
|
||||
labels = (labels[..., None] == jnp.arange(vocab_size)[None]).astype("f4")
|
||||
logits = jax.nn.log_softmax(logits, axis=-1)
|
||||
loss = -jnp.sum(labels * logits, axis=-1)
|
||||
if reduction is not None:
|
||||
loss = reduction(loss)
|
||||
return loss
|
||||
|
||||
cross_entropy = partial(cross_entropy, reduction=jnp.mean)
|
||||
start_loss = cross_entropy(start_logits, start_labels)
|
||||
end_loss = cross_entropy(end_logits, end_labels)
|
||||
pooled_loss = cross_entropy(pooled_logits, pooler_labels)
|
||||
return (start_loss + end_loss + pooled_loss) / 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
model_id: str = "google/bigbird-roberta-base"
|
||||
logging_steps: int = 3000
|
||||
save_steps: int = 10500
|
||||
|
||||
block_size: int = 128
|
||||
num_random_blocks: int = 3
|
||||
|
||||
batch_size_per_device: int = 1
|
||||
max_epochs: int = 5
|
||||
|
||||
# tx_args
|
||||
lr: float = 3e-5
|
||||
init_lr: float = 0.0
|
||||
warmup_steps: int = 20000
|
||||
weight_decay: float = 0.0095
|
||||
|
||||
save_dir: str = "bigbird-roberta-natural-questions"
|
||||
base_dir: str = "training-expt"
|
||||
tr_data_path: str = "data/nq-training.jsonl"
|
||||
val_data_path: str = "data/nq-validation.jsonl"
|
||||
|
||||
def __post_init__(self):
|
||||
os.makedirs(self.base_dir, exist_ok=True)
|
||||
self.save_dir = os.path.join(self.base_dir, self.save_dir)
|
||||
self.batch_size = self.batch_size_per_device * jax.device_count()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollator:
|
||||
pad_id: int
|
||||
max_length: int = 4096 # no dynamic padding on TPUs
|
||||
|
||||
def __call__(self, batch):
|
||||
batch = self.collate_fn(batch)
|
||||
batch = jax.tree_util.tree_map(shard, batch)
|
||||
return batch
|
||||
|
||||
def collate_fn(self, features):
|
||||
input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
|
||||
batch = {
|
||||
"input_ids": jnp.array(input_ids, dtype=jnp.int32),
|
||||
"attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
|
||||
"start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
|
||||
"end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
|
||||
"pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
|
||||
}
|
||||
return batch
|
||||
|
||||
def fetch_inputs(self, input_ids: list):
|
||||
inputs = [self._fetch_inputs(ids) for ids in input_ids]
|
||||
return zip(*inputs)
|
||||
|
||||
def _fetch_inputs(self, input_ids: list):
|
||||
attention_mask = [1 for _ in range(len(input_ids))]
|
||||
while len(input_ids) < self.max_length:
|
||||
input_ids.append(self.pad_id)
|
||||
attention_mask.append(0)
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def get_batched_dataset(dataset, batch_size, seed=None):
|
||||
if seed is not None:
|
||||
dataset = dataset.shuffle(seed=seed)
|
||||
for i in range(len(dataset) // batch_size):
|
||||
batch = dataset[i * batch_size : (i + 1) * batch_size]
|
||||
yield dict(batch)
|
||||
|
||||
|
||||
@partial(jax.pmap, axis_name="batch")
|
||||
def train_step(state, drp_rng, **model_inputs):
|
||||
def loss_fn(params):
|
||||
start_labels = model_inputs.pop("start_labels")
|
||||
end_labels = model_inputs.pop("end_labels")
|
||||
pooled_labels = model_inputs.pop("pooled_labels")
|
||||
|
||||
outputs = state.apply_fn(**model_inputs, params=params, dropout_rng=drp_rng, train=True)
|
||||
start_logits, end_logits, pooled_logits = outputs
|
||||
|
||||
return state.loss_fn(
|
||||
start_logits,
|
||||
start_labels,
|
||||
end_logits,
|
||||
end_labels,
|
||||
pooled_logits,
|
||||
pooled_labels,
|
||||
)
|
||||
|
||||
drp_rng, new_drp_rng = jax.random.split(drp_rng)
|
||||
grad_fn = jax.value_and_grad(loss_fn)
|
||||
loss, grads = grad_fn(state.params)
|
||||
metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
|
||||
grads = jax.lax.pmean(grads, "batch")
|
||||
|
||||
state = state.apply_gradients(grads=grads)
|
||||
return state, metrics, new_drp_rng
|
||||
|
||||
|
||||
@partial(jax.pmap, axis_name="batch")
|
||||
def val_step(state, **model_inputs):
|
||||
start_labels = model_inputs.pop("start_labels")
|
||||
end_labels = model_inputs.pop("end_labels")
|
||||
pooled_labels = model_inputs.pop("pooled_labels")
|
||||
|
||||
outputs = state.apply_fn(**model_inputs, params=state.params, train=False)
|
||||
start_logits, end_logits, pooled_logits = outputs
|
||||
|
||||
loss = state.loss_fn(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels)
|
||||
metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
|
||||
return metrics
|
||||
|
||||
|
||||
class TrainState(train_state.TrainState):
|
||||
loss_fn: Callable = struct.field(pytree_node=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trainer:
|
||||
args: Args
|
||||
data_collator: Callable
|
||||
train_step_fn: Callable
|
||||
val_step_fn: Callable
|
||||
model_save_fn: Callable
|
||||
logger: wandb
|
||||
scheduler_fn: Callable = None
|
||||
|
||||
def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
|
||||
params = model.params
|
||||
state = TrainState.create(
|
||||
apply_fn=model.__call__,
|
||||
params=params,
|
||||
tx=tx,
|
||||
loss_fn=calculate_loss_for_nq,
|
||||
)
|
||||
if ckpt_dir is not None:
|
||||
params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
|
||||
tx_args = {
|
||||
"lr": args.lr,
|
||||
"init_lr": args.init_lr,
|
||||
"warmup_steps": args.warmup_steps,
|
||||
"num_train_steps": num_train_steps,
|
||||
"weight_decay": args.weight_decay,
|
||||
}
|
||||
tx, lr = build_tx(**tx_args)
|
||||
state = train_state.TrainState(
|
||||
step=step,
|
||||
apply_fn=model.__call__,
|
||||
params=params,
|
||||
tx=tx,
|
||||
opt_state=opt_state,
|
||||
)
|
||||
self.args = args
|
||||
self.data_collator = data_collator
|
||||
self.scheduler_fn = lr
|
||||
model.params = params
|
||||
state = jax_utils.replicate(state)
|
||||
return state
|
||||
|
||||
def train(self, state, tr_dataset, val_dataset):
|
||||
args = self.args
|
||||
total = len(tr_dataset) // args.batch_size
|
||||
|
||||
rng = jax.random.PRNGKey(0)
|
||||
drp_rng = jax.random.split(rng, jax.device_count())
|
||||
for epoch in range(args.max_epochs):
|
||||
running_loss = jnp.array(0, dtype=jnp.float32)
|
||||
tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
|
||||
i = 0
|
||||
for batch in tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}"):
|
||||
batch = self.data_collator(batch)
|
||||
state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
|
||||
running_loss += jax_utils.unreplicate(metrics["loss"])
|
||||
i += 1
|
||||
if i % args.logging_steps == 0:
|
||||
state_step = jax_utils.unreplicate(state.step)
|
||||
tr_loss = running_loss.item() / i
|
||||
lr = self.scheduler_fn(state_step - 1)
|
||||
|
||||
eval_loss = self.evaluate(state, val_dataset)
|
||||
logging_dict = {
|
||||
"step": state_step.item(),
|
||||
"eval_loss": eval_loss.item(),
|
||||
"tr_loss": tr_loss,
|
||||
"lr": lr.item(),
|
||||
}
|
||||
tqdm.write(str(logging_dict))
|
||||
self.logger.log(logging_dict, commit=True)
|
||||
|
||||
if i % args.save_steps == 0:
|
||||
self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)
|
||||
|
||||
def evaluate(self, state, dataset):
|
||||
dataloader = get_batched_dataset(dataset, self.args.batch_size)
|
||||
total = len(dataset) // self.args.batch_size
|
||||
running_loss = jnp.array(0, dtype=jnp.float32)
|
||||
i = 0
|
||||
for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
|
||||
batch = self.data_collator(batch)
|
||||
metrics = self.val_step_fn(state, **batch)
|
||||
running_loss += jax_utils.unreplicate(metrics["loss"])
|
||||
i += 1
|
||||
return running_loss / i
|
||||
|
||||
def save_checkpoint(self, save_dir, state):
|
||||
state = jax_utils.unreplicate(state)
|
||||
print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
|
||||
self.model_save_fn(save_dir, params=state.params)
|
||||
with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
|
||||
f.write(to_bytes(state.opt_state))
|
||||
joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
|
||||
joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
|
||||
with open(os.path.join(save_dir, "training_state.json"), "w") as f:
|
||||
json.dump({"step": state.step.item()}, f)
|
||||
print("DONE")
|
||||
|
||||
|
||||
def restore_checkpoint(save_dir, state):
|
||||
print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
|
||||
with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
|
||||
params = from_bytes(state.params, f.read())
|
||||
|
||||
with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
|
||||
opt_state = from_bytes(state.opt_state, f.read())
|
||||
|
||||
args = joblib.load(os.path.join(save_dir, "args.joblib"))
|
||||
data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
|
||||
|
||||
with open(os.path.join(save_dir, "training_state.json"), "r") as f:
|
||||
training_state = json.load(f)
|
||||
step = training_state["step"]
|
||||
|
||||
print("DONE")
|
||||
return params, opt_state, step, args, data_collator
|
||||
|
||||
|
||||
def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
|
||||
decay_steps = num_train_steps - warmup_steps
|
||||
warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
|
||||
decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
|
||||
lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
|
||||
return lr
|
||||
|
||||
|
||||
def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
|
||||
def weight_decay_mask(params):
|
||||
params = traverse_util.flatten_dict(params)
|
||||
mask = {k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale")) for k, v in params.items()}
|
||||
return traverse_util.unflatten_dict(mask)
|
||||
|
||||
lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)
|
||||
|
||||
tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
|
||||
return tx, lr
|
||||
@@ -1,164 +0,0 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from bigbird_flax import FlaxBigBirdForNaturalQuestions
|
||||
from datasets import load_from_disk
|
||||
|
||||
from transformers import BigBirdTokenizerFast
|
||||
|
||||
|
||||
CATEGORY_MAPPING = {0: "null", 1: "short", 2: "long", 3: "yes", 4: "no"}
|
||||
PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))
|
||||
|
||||
|
||||
def get_sub_answers(answers, begin=0, end=None):
|
||||
return [" ".join(x.split(" ")[begin:end]) for x in answers if len(x.split(" ")) > 1]
|
||||
|
||||
|
||||
def expand_to_aliases(given_answers, make_sub_answers=False):
|
||||
if make_sub_answers:
|
||||
# if answers are longer than one word, make sure a predictions is correct if it coresponds to the complete 1: or :-1 sub word
|
||||
# *e.g.* if the correct answer contains a prefix such as "the", or "a"
|
||||
given_answers = (
|
||||
given_answers + get_sub_answers(given_answers, begin=1) + get_sub_answers(given_answers, end=-1)
|
||||
)
|
||||
answers = []
|
||||
for answer in given_answers:
|
||||
alias = answer.replace("_", " ").lower()
|
||||
alias = "".join(c if c not in PUNCTUATION_SET_TO_EXCLUDE else " " for c in alias)
|
||||
answers.append(" ".join(alias.split()).strip())
|
||||
return set(answers)
|
||||
|
||||
|
||||
def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100):
|
||||
best_start_scores, best_start_idx = jax.lax.top_k(start_scores, top_k)
|
||||
best_end_scores, best_end_idx = jax.lax.top_k(end_scores, top_k)
|
||||
|
||||
widths = best_end_idx[:, None] - best_start_idx[None, :]
|
||||
mask = jnp.logical_or(widths < 0, widths > max_size)
|
||||
scores = (best_end_scores[:, None] + best_start_scores[None, :]) - (1e8 * mask)
|
||||
best_score = jnp.argmax(scores).item()
|
||||
|
||||
return best_start_idx[best_score % top_k], best_end_idx[best_score // top_k]
|
||||
|
||||
|
||||
def format_dataset(sample):
|
||||
question = sample["question"]["text"]
|
||||
context = sample["document"]["tokens"]["token"]
|
||||
is_html = sample["document"]["tokens"]["is_html"]
|
||||
long_answers = sample["annotations"]["long_answer"]
|
||||
short_answers = sample["annotations"]["short_answers"]
|
||||
|
||||
context_string = " ".join([context[i] for i in range(len(context)) if not is_html[i]])
|
||||
|
||||
# 0 - No ; 1 - Yes
|
||||
for answer in sample["annotations"]["yes_no_answer"]:
|
||||
if answer == 0 or answer == 1:
|
||||
return {
|
||||
"question": question,
|
||||
"context": context_string,
|
||||
"short": [],
|
||||
"long": [],
|
||||
"category": "no" if answer == 0 else "yes",
|
||||
}
|
||||
|
||||
short_targets = []
|
||||
for s in short_answers:
|
||||
short_targets.extend(s["text"])
|
||||
short_targets = list(set(short_targets))
|
||||
|
||||
long_targets = []
|
||||
for s in long_answers:
|
||||
if s["start_token"] == -1:
|
||||
continue
|
||||
answer = context[s["start_token"] : s["end_token"]]
|
||||
html = is_html[s["start_token"] : s["end_token"]]
|
||||
new_answer = " ".join([answer[i] for i in range(len(answer)) if not html[i]])
|
||||
if new_answer not in long_targets:
|
||||
long_targets.append(new_answer)
|
||||
|
||||
category = "long_short" if len(short_targets + long_targets) > 0 else "null"
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"context": context_string,
|
||||
"short": short_targets,
|
||||
"long": long_targets,
|
||||
"category": category,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
dataset = load_from_disk("natural-questions-validation")
|
||||
dataset = dataset.map(format_dataset).remove_columns(["annotations", "document", "id"])
|
||||
print(dataset)
|
||||
|
||||
short_validation_dataset = dataset.filter(lambda x: (len(x["question"]) + len(x["context"])) < 4 * 4096)
|
||||
short_validation_dataset = short_validation_dataset.filter(lambda x: x["category"] != "null")
|
||||
|
||||
model_id = "vasudevgupta/flax-bigbird-natural-questions"
|
||||
model = FlaxBigBirdForNaturalQuestions.from_pretrained(model_id)
|
||||
tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)
|
||||
|
||||
@jax.jit
|
||||
def forward(*args, **kwargs):
|
||||
start_logits, end_logits, pooled_logits = model(*args, **kwargs)
|
||||
return start_logits, end_logits, jnp.argmax(pooled_logits, axis=-1)
|
||||
|
||||
def evaluate(example):
|
||||
# encode question and context so that they are separated by a tokenizer.sep_token and cut at max_length
|
||||
inputs = tokenizer(
|
||||
example["question"],
|
||||
example["context"],
|
||||
return_tensors="np",
|
||||
max_length=4096,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
start_scores, end_scores, category = forward(**inputs)
|
||||
|
||||
predicted_category = CATEGORY_MAPPING[category.item()]
|
||||
|
||||
example["targets"] = example["long"] + example["short"]
|
||||
if example["category"] in ["yes", "no", "null"]:
|
||||
example["targets"] = [example["category"]]
|
||||
example["has_tgt"] = example["category"] != "null"
|
||||
# Now target can be: "yes", "no", "null", "list of long & short answers"
|
||||
|
||||
if predicted_category in ["yes", "no", "null"]:
|
||||
example["output"] = [predicted_category]
|
||||
example["match"] = example["output"] == example["targets"]
|
||||
example["has_pred"] = predicted_category != "null"
|
||||
return example
|
||||
|
||||
max_size = 38 if predicted_category == "short" else 1024
|
||||
start_score, end_score = get_best_valid_start_end_idx(
|
||||
start_scores[0], end_scores[0], top_k=8, max_size=max_size
|
||||
)
|
||||
|
||||
input_ids = inputs["input_ids"][0].tolist()
|
||||
example["output"] = [tokenizer.decode(input_ids[start_score : end_score + 1])]
|
||||
|
||||
answers = expand_to_aliases(example["targets"], make_sub_answers=True)
|
||||
predictions = expand_to_aliases(example["output"])
|
||||
|
||||
# some preprocessing to both prediction and answer
|
||||
answers = {"".join(a.split()) for a in answers}
|
||||
predictions = {"".join(p.split()) for p in predictions}
|
||||
predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}
|
||||
|
||||
# if there is a common element, it's a exact match
|
||||
example["match"] = len(list(answers & predictions)) > 0
|
||||
example["has_pred"] = predicted_category != "null" and len(predictions) > 0
|
||||
|
||||
return example
|
||||
|
||||
short_validation_dataset = short_validation_dataset.map(evaluate)
|
||||
|
||||
total = len(short_validation_dataset)
|
||||
matched = len(short_validation_dataset.filter(lambda x: x["match"] == 1))
|
||||
print("EM score:", (matched / total) * 100, "%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,329 +0,0 @@
|
||||
import os
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
DOC_STRIDE = 2048
|
||||
MAX_LENGTH = 4096
|
||||
SEED = 42
|
||||
PROCESS_TRAIN = os.environ.pop("PROCESS_TRAIN", "false")
|
||||
CATEGORY_MAPPING = {"null": 0, "short": 1, "long": 2, "yes": 3, "no": 4}
|
||||
|
||||
|
||||
def _get_single_answer(example):
|
||||
def choose_first(answer, is_long_answer=False):
|
||||
assert isinstance(answer, list)
|
||||
if len(answer) == 1:
|
||||
answer = answer[0]
|
||||
return {k: [answer[k]] for k in answer} if is_long_answer else answer
|
||||
for a in answer:
|
||||
if is_long_answer:
|
||||
a = {k: [a[k]] for k in a}
|
||||
if len(a["start_token"]) > 0:
|
||||
break
|
||||
return a
|
||||
|
||||
answer = {"id": example["id"]}
|
||||
annotation = example["annotations"]
|
||||
yes_no_answer = annotation["yes_no_answer"]
|
||||
if 0 in yes_no_answer or 1 in yes_no_answer:
|
||||
answer["category"] = ["yes"] if 1 in yes_no_answer else ["no"]
|
||||
answer["start_token"] = answer["end_token"] = []
|
||||
answer["start_byte"] = answer["end_byte"] = []
|
||||
answer["text"] = ["<cls>"]
|
||||
else:
|
||||
answer["category"] = ["short"]
|
||||
out = choose_first(annotation["short_answers"])
|
||||
if len(out["start_token"]) == 0:
|
||||
# answer will be long if short is not available
|
||||
answer["category"] = ["long"]
|
||||
out = choose_first(annotation["long_answer"], is_long_answer=True)
|
||||
out["text"] = []
|
||||
answer.update(out)
|
||||
|
||||
# disregard some samples
|
||||
if len(answer["start_token"]) > 1 or answer["start_token"] == answer["end_token"]:
|
||||
answer["remove_it"] = True
|
||||
else:
|
||||
answer["remove_it"] = False
|
||||
|
||||
cols = ["start_token", "end_token", "start_byte", "end_byte", "text"]
|
||||
if not all(isinstance(answer[k], list) for k in cols):
|
||||
raise ValueError("Issue in ID", example["id"])
|
||||
|
||||
return answer
|
||||
|
||||
|
||||
def get_context_and_ans(example, assertion=False):
|
||||
"""Gives new context after removing <html> & new answer tokens as per new context"""
|
||||
answer = _get_single_answer(example)
|
||||
# bytes are of no use
|
||||
del answer["start_byte"]
|
||||
del answer["end_byte"]
|
||||
|
||||
# handle yes_no answers explicitly
|
||||
if answer["category"][0] in ["yes", "no"]: # category is list with one element
|
||||
doc = example["document"]["tokens"]
|
||||
context = []
|
||||
for i in range(len(doc["token"])):
|
||||
if not doc["is_html"][i]:
|
||||
context.append(doc["token"][i])
|
||||
return {
|
||||
"context": " ".join(context),
|
||||
"answer": {
|
||||
"start_token": -100, # ignore index in cross-entropy
|
||||
"end_token": -100, # ignore index in cross-entropy
|
||||
"category": answer["category"],
|
||||
"span": answer["category"], # extra
|
||||
},
|
||||
}
|
||||
|
||||
# later, help in removing all no answers
|
||||
if answer["start_token"] == [-1]:
|
||||
return {
|
||||
"context": "None",
|
||||
"answer": {
|
||||
"start_token": -1,
|
||||
"end_token": -1,
|
||||
"category": "null",
|
||||
"span": "None", # extra
|
||||
},
|
||||
}
|
||||
|
||||
# handling normal samples
|
||||
|
||||
cols = ["start_token", "end_token"]
|
||||
answer.update({k: answer[k][0] if len(answer[k]) > 0 else answer[k] for k in cols}) # e.g. [10] == 10
|
||||
|
||||
doc = example["document"]["tokens"]
|
||||
start_token = answer["start_token"]
|
||||
end_token = answer["end_token"]
|
||||
|
||||
context = []
|
||||
for i in range(len(doc["token"])):
|
||||
if not doc["is_html"][i]:
|
||||
context.append(doc["token"][i])
|
||||
else:
|
||||
if answer["start_token"] > i:
|
||||
start_token -= 1
|
||||
if answer["end_token"] > i:
|
||||
end_token -= 1
|
||||
new = " ".join(context[start_token:end_token])
|
||||
|
||||
# checking above code
|
||||
if assertion:
|
||||
"""checking if above code is working as expected for all the samples"""
|
||||
is_html = doc["is_html"][answer["start_token"] : answer["end_token"]]
|
||||
old = doc["token"][answer["start_token"] : answer["end_token"]]
|
||||
old = " ".join([old[i] for i in range(len(old)) if not is_html[i]])
|
||||
if new != old:
|
||||
print("ID:", example["id"])
|
||||
print("New:", new, end="\n")
|
||||
print("Old:", old, end="\n\n")
|
||||
|
||||
return {
|
||||
"context": " ".join(context),
|
||||
"answer": {
|
||||
"start_token": start_token,
|
||||
"end_token": end_token - 1, # this makes it inclusive
|
||||
"category": answer["category"], # either long or short
|
||||
"span": new, # extra
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_strided_contexts_and_ans(example, tokenizer, doc_stride=2048, max_length=4096, assertion=True):
|
||||
# overlap will be of doc_stride - q_len
|
||||
|
||||
out = get_context_and_ans(example, assertion=assertion)
|
||||
answer = out["answer"]
|
||||
|
||||
# later, removing these samples
|
||||
if answer["start_token"] == -1:
|
||||
return {
|
||||
"example_id": example["id"],
|
||||
"input_ids": [[-1]],
|
||||
"labels": {
|
||||
"start_token": [-1],
|
||||
"end_token": [-1],
|
||||
"category": ["null"],
|
||||
},
|
||||
}
|
||||
|
||||
input_ids = tokenizer(example["question"]["text"], out["context"]).input_ids
|
||||
q_len = input_ids.index(tokenizer.sep_token_id) + 1
|
||||
|
||||
# return yes/no
|
||||
if answer["category"][0] in ["yes", "no"]: # category is list with one element
|
||||
inputs = []
|
||||
category = []
|
||||
q_indices = input_ids[:q_len]
|
||||
doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride)
|
||||
for i in doc_start_indices:
|
||||
end_index = i + max_length - q_len
|
||||
slice = input_ids[i:end_index]
|
||||
inputs.append(q_indices + slice)
|
||||
category.append(answer["category"][0])
|
||||
if slice[-1] == tokenizer.sep_token_id:
|
||||
break
|
||||
|
||||
return {
|
||||
"example_id": example["id"],
|
||||
"input_ids": inputs,
|
||||
"labels": {
|
||||
"start_token": [-100] * len(category),
|
||||
"end_token": [-100] * len(category),
|
||||
"category": category,
|
||||
},
|
||||
}
|
||||
|
||||
splitted_context = out["context"].split()
|
||||
complete_end_token = splitted_context[answer["end_token"]]
|
||||
answer["start_token"] = len(
|
||||
tokenizer(
|
||||
" ".join(splitted_context[: answer["start_token"]]),
|
||||
add_special_tokens=False,
|
||||
).input_ids
|
||||
)
|
||||
answer["end_token"] = len(
|
||||
tokenizer(" ".join(splitted_context[: answer["end_token"]]), add_special_tokens=False).input_ids
|
||||
)
|
||||
|
||||
answer["start_token"] += q_len
|
||||
answer["end_token"] += q_len
|
||||
|
||||
# fixing end token
|
||||
num_sub_tokens = len(tokenizer(complete_end_token, add_special_tokens=False).input_ids)
|
||||
if num_sub_tokens > 1:
|
||||
answer["end_token"] += num_sub_tokens - 1
|
||||
|
||||
old = input_ids[answer["start_token"] : answer["end_token"] + 1] # right & left are inclusive
|
||||
start_token = answer["start_token"]
|
||||
end_token = answer["end_token"]
|
||||
|
||||
if assertion:
|
||||
"""This won't match exactly because of extra gaps => visaully inspect everything"""
|
||||
new = tokenizer.decode(old)
|
||||
if answer["span"] != new:
|
||||
print("ISSUE IN TOKENIZATION")
|
||||
print("OLD:", answer["span"])
|
||||
print("NEW:", new, end="\n\n")
|
||||
|
||||
if len(input_ids) <= max_length:
|
||||
return {
|
||||
"example_id": example["id"],
|
||||
"input_ids": [input_ids],
|
||||
"labels": {
|
||||
"start_token": [answer["start_token"]],
|
||||
"end_token": [answer["end_token"]],
|
||||
"category": answer["category"],
|
||||
},
|
||||
}
|
||||
|
||||
q_indices = input_ids[:q_len]
|
||||
doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride)
|
||||
|
||||
inputs = []
|
||||
answers_start_token = []
|
||||
answers_end_token = []
|
||||
answers_category = [] # null, yes, no, long, short
|
||||
for i in doc_start_indices:
|
||||
end_index = i + max_length - q_len
|
||||
slice = input_ids[i:end_index]
|
||||
inputs.append(q_indices + slice)
|
||||
assert len(inputs[-1]) <= max_length, "Issue in truncating length"
|
||||
|
||||
if start_token >= i and end_token <= end_index - 1:
|
||||
start_token = start_token - i + q_len
|
||||
end_token = end_token - i + q_len
|
||||
answers_category.append(answer["category"][0]) # ["short"] -> "short"
|
||||
else:
|
||||
start_token = -100
|
||||
end_token = -100
|
||||
answers_category.append("null")
|
||||
new = inputs[-1][start_token : end_token + 1]
|
||||
|
||||
answers_start_token.append(start_token)
|
||||
answers_end_token.append(end_token)
|
||||
if assertion:
|
||||
"""checking if above code is working as expected for all the samples"""
|
||||
if new != old and new != [tokenizer.cls_token_id]:
|
||||
print("ISSUE in strided for ID:", example["id"])
|
||||
print("New:", tokenizer.decode(new))
|
||||
print("Old:", tokenizer.decode(old), end="\n\n")
|
||||
if slice[-1] == tokenizer.sep_token_id:
|
||||
break
|
||||
|
||||
return {
|
||||
"example_id": example["id"],
|
||||
"input_ids": inputs,
|
||||
"labels": {
|
||||
"start_token": answers_start_token,
|
||||
"end_token": answers_end_token,
|
||||
"category": answers_category,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def prepare_inputs(example, tokenizer, doc_stride=2048, max_length=4096, assertion=False):
|
||||
example = get_strided_contexts_and_ans(
|
||||
example,
|
||||
tokenizer,
|
||||
doc_stride=doc_stride,
|
||||
max_length=max_length,
|
||||
assertion=assertion,
|
||||
)
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def save_to_disk(hf_data, file_name):
|
||||
with jsonlines.open(file_name, "a") as writer:
|
||||
for example in tqdm(hf_data, total=len(hf_data), desc="Saving samples ... "):
|
||||
labels = example["labels"]
|
||||
for ids, start, end, cat in zip(
|
||||
example["input_ids"],
|
||||
labels["start_token"],
|
||||
labels["end_token"],
|
||||
labels["category"],
|
||||
):
|
||||
if start == -1 and end == -1:
|
||||
continue # leave waste samples with no answer
|
||||
if cat == "null" and np.random.rand() < 0.6:
|
||||
continue # removing 50 % samples
|
||||
writer.write(
|
||||
{
|
||||
"input_ids": ids,
|
||||
"start_token": start,
|
||||
"end_token": end,
|
||||
"category": CATEGORY_MAPPING[cat],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Running area"""
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import BigBirdTokenizer
|
||||
|
||||
data = load_dataset("natural_questions")
|
||||
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
|
||||
|
||||
data = data["train" if PROCESS_TRAIN == "true" else "validation"]
|
||||
|
||||
fn_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"doc_stride": DOC_STRIDE,
|
||||
"max_length": MAX_LENGTH,
|
||||
"assertion": False,
|
||||
}
|
||||
data = data.map(prepare_inputs, fn_kwargs=fn_kwargs)
|
||||
data = data.remove_columns(["annotations", "document", "id", "question"])
|
||||
print(data)
|
||||
|
||||
np.random.seed(SEED)
|
||||
cache_file_name = "nq-training.jsonl" if PROCESS_TRAIN == "true" else "nq-validation.jsonl"
|
||||
save_to_disk(data, file_name=cache_file_name)
|
||||
@@ -1,6 +0,0 @@
|
||||
git+https://github.com/huggingface/transformers@main
|
||||
datasets
|
||||
sentencepiece
|
||||
wandb
|
||||
flax
|
||||
jsonlines
|
||||
@@ -1,16 +0,0 @@
|
||||
command:
|
||||
- python3
|
||||
- train.py
|
||||
method: random
|
||||
parameters:
|
||||
lr:
|
||||
values: [4e-5, 3e-5]
|
||||
warmup_steps:
|
||||
values: [20000, 15000, 10000, 5000]
|
||||
weight_decay:
|
||||
distribution: normal
|
||||
mu: 1e-2
|
||||
sigma: 2e-3
|
||||
metric:
|
||||
name: eval_loss
|
||||
goal: minimize
|
||||
@@ -1,78 +0,0 @@
|
||||
import os
|
||||
from dataclasses import replace
|
||||
|
||||
import jax
|
||||
import wandb
|
||||
from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
|
||||
from datasets import load_dataset
|
||||
from flax import jax_utils
|
||||
|
||||
from transformers import BigBirdTokenizerFast
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("#################### AVAILABLE DEVICES ####################")
|
||||
print(jax.devices())
|
||||
print("###########################################################")
|
||||
|
||||
# setup for wandb sweep
|
||||
args = Args()
|
||||
logger = wandb.init(project="bigbird-natural-questions", config=args.__dict__)
|
||||
wandb_args = dict(logger.config)
|
||||
del wandb_args["batch_size"]
|
||||
args = replace(args, **wandb_args)
|
||||
base_dir = args.base_dir + "-" + wandb.run.id
|
||||
args = replace(args, base_dir=base_dir)
|
||||
print(args)
|
||||
|
||||
tr_dataset = load_dataset("json", data_files=args.tr_data_path)["train"]
|
||||
val_dataset = load_dataset("json", data_files=args.val_data_path)["train"]
|
||||
|
||||
# drop extra batch for now
|
||||
indices = range(len(tr_dataset) - len(tr_dataset) % args.batch_size)
|
||||
tr_dataset = tr_dataset.shuffle().select(indices)
|
||||
indices = range(len(val_dataset) - len(val_dataset) % args.batch_size)
|
||||
val_dataset = val_dataset.shuffle().select(indices)
|
||||
|
||||
if os.environ.get("TRAIN_ON_SMALL", "false") == "true":
|
||||
tr_dataset = tr_dataset.shuffle().select(range(80000))
|
||||
val_dataset = val_dataset.shuffle().select(range(8000))
|
||||
|
||||
print(tr_dataset)
|
||||
print(val_dataset)
|
||||
|
||||
model = FlaxBigBirdForNaturalQuestions.from_pretrained(
|
||||
args.model_id, block_size=args.block_size, num_random_blocks=args.num_random_blocks
|
||||
)
|
||||
tokenizer = BigBirdTokenizerFast.from_pretrained(args.model_id)
|
||||
data_collator = DataCollator(pad_id=tokenizer.pad_token_id, max_length=4096)
|
||||
|
||||
tx_args = {
|
||||
"lr": args.lr,
|
||||
"init_lr": args.init_lr,
|
||||
"warmup_steps": args.warmup_steps,
|
||||
"num_train_steps": args.max_epochs * (len(tr_dataset) // args.batch_size),
|
||||
"weight_decay": args.weight_decay,
|
||||
}
|
||||
tx, lr = build_tx(**tx_args)
|
||||
|
||||
trainer = Trainer(
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
model_save_fn=model.save_pretrained,
|
||||
train_step_fn=train_step,
|
||||
val_step_fn=val_step,
|
||||
logger=logger,
|
||||
scheduler_fn=lr,
|
||||
)
|
||||
|
||||
ckpt_dir = None
|
||||
state = trainer.create_state(model, tx, num_train_steps=tx_args["num_train_steps"], ckpt_dir=ckpt_dir)
|
||||
try:
|
||||
trainer.train(state, tr_dataset, val_dataset)
|
||||
except KeyboardInterrupt:
|
||||
print("Oooops; TRAINING STOPPED UNFORTUNATELY")
|
||||
|
||||
print("SAVING WEIGHTS IN `final-weights`")
|
||||
params = jax_utils.unreplicate(state.params)
|
||||
model.save_pretrained(os.path.join(args.base_dir, "final-weights"), params=params)
|
||||
@@ -1,121 +0,0 @@
|
||||
<!---
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Language model training examples in streaming mode
|
||||
|
||||
The following examples showcase how to train a language model from scratch
|
||||
using the JAX/Flax backend.
|
||||
|
||||
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
|
||||
Models written in JAX/Flax are **immutable** and updated in a purely functional
|
||||
way which enables simple and efficient model parallelism.
|
||||
|
||||
All of the following examples make use of [dataset streaming](https://huggingface.co/docs/datasets/master/dataset_streaming), therefore allowing to train models on massive datasets\
|
||||
without ever having to download the full dataset.
|
||||
|
||||
## Masked language modeling
|
||||
|
||||
In the following, we demonstrate how to train a bi-directional transformer model
|
||||
using masked language modeling objective as introduced in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
|
||||
More specifically, we demonstrate how JAX/Flax and dataset streaming can be leveraged
|
||||
to pre-train [**`FacebookAI/roberta-base`**](https://huggingface.co/FacebookAI/roberta-base)
|
||||
in English on a single TPUv3-8 pod for 10000 update steps.
|
||||
|
||||
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
|
||||
|
||||
Let's start by creating a model repository to save the trained model and logs.
|
||||
Here we call the model `"english-roberta-base-dummy"`, but you can change the model name as you like.
|
||||
|
||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
||||
you are logged in) or via the command line:
|
||||
|
||||
```bash
|
||||
huggingface-cli repo create english-roberta-base-dummy
|
||||
```
|
||||
|
||||
Next we clone the model repository to add the tokenizer and model files.
|
||||
|
||||
```bash
|
||||
git clone https://huggingface.co/<your-username>/english-roberta-base-dummy
|
||||
```
|
||||
|
||||
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
||||
track them. You can run the following command inside your model repo to do so.
|
||||
|
||||
```bash
|
||||
cd english-roberta-base-dummy
|
||||
git lfs track "*tfevents*"
|
||||
```
|
||||
|
||||
Great, we have set up our model repository. During training, we will automatically
|
||||
push the training logs and model weights to the repo.
|
||||
|
||||
Next, let's add a symbolic link to the `run_mlm_flax.py`.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="./english-roberta-base-dummy"
|
||||
ln -s ~/transformers/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py ./
|
||||
```
|
||||
|
||||
### Copy config and tokenizer of existing model
|
||||
|
||||
In this example, we will simply copy an existing config and tokenizer in English.
|
||||
You can run the following code in a Python shell to do so.
|
||||
|
||||
```python
|
||||
from transformers import RobertaTokenizerFast, RobertaConfig
|
||||
|
||||
model_dir = "./english-roberta-base-dummy"
|
||||
|
||||
tokenizer = RobertaTokenizerFast.from_pretrained("FacebookAI/roberta-base")
|
||||
config = RobertaConfig.from_pretrained("FacebookAI/roberta-base")
|
||||
|
||||
tokenizer.save_pretrained(model_dir)
|
||||
config.save_pretrained(model_dir)
|
||||
```
|
||||
|
||||
### Train model
|
||||
|
||||
Next we can run the example script to pretrain the model.
|
||||
Compared to the default [`run_mlm_flax`](https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_mlm_flax.py), we introduced 4 new training settings:
|
||||
- `num_train_steps` - how many update steps should be run.
|
||||
- `num_eval_samples` - how many training samples should be taken for evaluation.
|
||||
- `logging_steps` - at what rate should the training loss be logged.
|
||||
- `eval_steps` - at what rate should evaluation be run.
|
||||
10K update steps
|
||||
|
||||
```bash
|
||||
./run_mlm_flax_stream.py \
|
||||
--output_dir="${MODEL_DIR}" \
|
||||
--model_type="roberta" \
|
||||
--config_name="${MODEL_DIR}" \
|
||||
--tokenizer_name="${MODEL_DIR}" \
|
||||
--dataset_name="oscar" \
|
||||
--dataset_config_name="unshuffled_deduplicated_en" \
|
||||
--max_seq_length="128" \
|
||||
--per_device_train_batch_size="128" \
|
||||
--per_device_eval_batch_size="128" \
|
||||
--learning_rate="3e-4" \
|
||||
--warmup_steps="1000" \
|
||||
--overwrite_output_dir \
|
||||
--adam_beta1="0.9" \
|
||||
--adam_beta2="0.98" \
|
||||
--num_train_steps="10000" \
|
||||
--num_eval_samples="5000" \
|
||||
--logging_steps="250" \
|
||||
--eval_steps="1000" \
|
||||
--push_to_hub
|
||||
```
|
||||
@@ -1,637 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
|
||||
text file or a dataset.
|
||||
|
||||
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
||||
https://huggingface.co/models?filter=fill-mask
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
from datasets import load_dataset
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
FlaxAutoModelForMaskedLM,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
TensorType,
|
||||
TrainingArguments,
|
||||
is_tensorboard_available,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
if datasets.__version__ <= "1.8.0":
|
||||
raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
|
||||
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
||||
"""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
|
||||
)
|
||||
},
|
||||
)
|
||||
model_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
dtype: Optional[str] = field(
|
||||
default="float32",
|
||||
metadata={
|
||||
"help": (
|
||||
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||
" `[float32, float16, bfloat16]`."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
train_ref_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
|
||||
)
|
||||
validation_ref_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
validation_split_percentage: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={
|
||||
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
||||
},
|
||||
)
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated. Default to the max input length of the model."
|
||||
)
|
||||
},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
mlm_probability: float = field(
|
||||
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether to pad all samples to `max_seq_length`. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||
)
|
||||
},
|
||||
)
|
||||
line_by_line: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
||||
)
|
||||
text_column_name: str = field(
|
||||
default="text", metadata={"help": "The name of the column to retrieve the training text."}
|
||||
)
|
||||
shuffle_buffer_size: int = field(
|
||||
default=10000, metadata={"help": "The number of examples to pre-load for shuffling."}
|
||||
)
|
||||
num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
|
||||
num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxDataCollatorForLanguageModeling:
|
||||
"""
|
||||
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
||||
are not all of the same length.
|
||||
|
||||
Args:
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||
The tokenizer used for encoding the data.
|
||||
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
||||
The probability with which to (randomly) mask tokens in the input.
|
||||
|
||||
.. note::
|
||||
|
||||
For best performance, this data collator should be used with a dataset having items that are dictionaries or
|
||||
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
|
||||
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
|
||||
argument :obj:`return_special_tokens_mask=True`.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
mlm_probability: float = 0.15
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer.mask_token is None:
|
||||
raise ValueError(
|
||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
||||
"You should pass `mlm=False` to train on causal language modeling instead."
|
||||
)
|
||||
|
||||
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
batch = self.tokenizer.pad(examples, return_tensors=TensorType.NUMPY)
|
||||
|
||||
# If special token mask has been preprocessed, pop it from the dict.
|
||||
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
||||
|
||||
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
||||
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
||||
)
|
||||
return batch
|
||||
|
||||
def mask_tokens(
|
||||
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
"""
|
||||
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
||||
"""
|
||||
labels = inputs.copy()
|
||||
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
||||
probability_matrix = np.full(labels.shape, self.mlm_probability)
|
||||
special_tokens_mask = special_tokens_mask.astype("bool")
|
||||
|
||||
probability_matrix[special_tokens_mask] = 0.0
|
||||
masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
|
||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||
|
||||
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
|
||||
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
||||
|
||||
# 10% of the time, we replace masked input tokens with random word
|
||||
indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
|
||||
indices_random &= masked_indices & ~indices_replaced
|
||||
|
||||
random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
|
||||
inputs[indices_random] = random_words[indices_random]
|
||||
|
||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||
return inputs, labels
|
||||
|
||||
|
||||
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||
num_samples = len(samples_idx)
|
||||
samples_to_remove = num_samples % batch_size
|
||||
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = num_samples // batch_size
|
||||
batch_idx = np.split(samples_idx, sections_split)
|
||||
return batch_idx
|
||||
|
||||
|
||||
def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
|
||||
"""
|
||||
The training iterator is advanced so that after groupifying the samples,
|
||||
`num_samples` of length `max_seq_length` are returned.
|
||||
"""
|
||||
num_total_tokens = max_seq_length * num_samples
|
||||
samples = defaultdict(list)
|
||||
|
||||
i = 0
|
||||
while i < num_total_tokens:
|
||||
tokenized_samples = next(train_iterator)
|
||||
i += len(tokenized_samples["input_ids"])
|
||||
|
||||
# concatenate tokenized samples to list (excluding "id" and "text")
|
||||
samples = {
|
||||
k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
|
||||
}
|
||||
|
||||
# Concatenated tokens are split to lists of length `max_seq_length`.
|
||||
# Note that remainedr of % max_seq_length are thrown away.
|
||||
def group_texts(examples):
|
||||
result = {
|
||||
k: [t[i : i + max_seq_length] for i in range(0, num_total_tokens, max_seq_length)]
|
||||
for k, t in examples.items()
|
||||
}
|
||||
return result
|
||||
|
||||
grouped_samples = group_texts(samples)
|
||||
return grouped_samples
|
||||
|
||||
|
||||
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
||||
summary_writer.scalar("train_time", train_time, step)
|
||||
|
||||
train_metrics = get_metrics(train_metrics)
|
||||
for key, vals in train_metrics.items():
|
||||
tag = f"train_{key}"
|
||||
for i, val in enumerate(vals):
|
||||
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
||||
|
||||
|
||||
def write_eval_metric(summary_writer, eval_metrics, step):
|
||||
for metric_name, value in eval_metrics.items():
|
||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
level="INFO",
|
||||
datefmt="[%X]",
|
||||
)
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||
# 'text' is found. You can easily tweak this behavior (see below).
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
streaming=True,
|
||||
split="train",
|
||||
)
|
||||
|
||||
if model_args.config_name:
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
||||
elif model_args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||
else:
|
||||
config = CONFIG_MAPPING[model_args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
elif model_args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script. "
|
||||
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||
)
|
||||
|
||||
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
||||
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
||||
# efficient when it receives the `special_tokens_mask`.
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
|
||||
|
||||
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=list(dataset.features.keys()))
|
||||
|
||||
shuffle_seed = training_args.seed
|
||||
tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
|
||||
|
||||
has_tensorboard = is_tensorboard_available()
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
try:
|
||||
from flax.metrics.tensorboard import SummaryWriter
|
||||
except ImportError as ie:
|
||||
has_tensorboard = False
|
||||
logger.warning(
|
||||
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
||||
)
|
||||
|
||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
||||
|
||||
# Data collator
|
||||
# This one will take care of randomly masking the tokens.
|
||||
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
||||
|
||||
# Initialize our training
|
||||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
if model_args.model_name_or_path:
|
||||
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
)
|
||||
else:
|
||||
model = FlaxAutoModelForMaskedLM.from_config(
|
||||
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
)
|
||||
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
|
||||
# define number steps per stream epoch
|
||||
num_train_steps = data_args.num_train_steps
|
||||
|
||||
# Create learning rate schedule
|
||||
warmup_fn = optax.linear_schedule(
|
||||
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
||||
)
|
||||
decay_fn = optax.linear_schedule(
|
||||
init_value=training_args.learning_rate,
|
||||
end_value=0,
|
||||
transition_steps=num_train_steps - training_args.warmup_steps,
|
||||
)
|
||||
linear_decay_lr_schedule_fn = optax.join_schedules(
|
||||
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
||||
)
|
||||
|
||||
# We use Optax's "masking" functionality to not apply weight decay
|
||||
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||
# mask boolean with the same structure as the parameters.
|
||||
# The mask is True for parameters that should be decayed.
|
||||
# Note that this mask is specifically adapted for FlaxBERT-like models.
|
||||
# For other models, one should correct the layer norm parameter naming
|
||||
# accordingly.
|
||||
def decay_mask_fn(params):
|
||||
flat_params = traverse_util.flatten_dict(params)
|
||||
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
||||
return traverse_util.unflatten_dict(flat_mask)
|
||||
|
||||
# create adam optimizer
|
||||
adamw = optax.adamw(
|
||||
learning_rate=linear_decay_lr_schedule_fn,
|
||||
b1=training_args.adam_beta1,
|
||||
b2=training_args.adam_beta2,
|
||||
eps=training_args.adam_epsilon,
|
||||
weight_decay=training_args.weight_decay,
|
||||
mask=decay_mask_fn,
|
||||
)
|
||||
|
||||
# Setup train state
|
||||
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
|
||||
|
||||
# Define gradient update step fn
|
||||
def train_step(state, batch, dropout_rng):
|
||||
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||
|
||||
def loss_fn(params):
|
||||
labels = batch.pop("labels")
|
||||
|
||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
|
||||
# compute loss, ignore padded input tokens
|
||||
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
||||
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||
|
||||
# take average
|
||||
loss = loss.sum() / label_mask.sum()
|
||||
|
||||
return loss
|
||||
|
||||
grad_fn = jax.value_and_grad(loss_fn)
|
||||
loss, grad = grad_fn(state.params)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
new_state = state.apply_gradients(grads=grad)
|
||||
|
||||
metrics = jax.lax.pmean(
|
||||
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
||||
)
|
||||
|
||||
return new_state, metrics, new_dropout_rng
|
||||
|
||||
# Create parallel version of the train step
|
||||
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
||||
|
||||
# Define eval fn
|
||||
def eval_step(params, batch):
|
||||
labels = batch.pop("labels")
|
||||
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
|
||||
# compute loss, ignore padded input tokens
|
||||
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
||||
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||
|
||||
# compute accuracy
|
||||
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
|
||||
|
||||
# summarize metrics
|
||||
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
|
||||
metrics = jax.lax.psum(metrics, axis_name="batch")
|
||||
|
||||
return metrics
|
||||
|
||||
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
||||
|
||||
# Replicate the train state on each device
|
||||
state = jax_utils.replicate(state)
|
||||
|
||||
train_time = 0
|
||||
train_start = time.time()
|
||||
train_metrics = []
|
||||
eval_metrics = []
|
||||
|
||||
training_iter = iter(tokenized_datasets)
|
||||
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
|
||||
|
||||
steps = tqdm(range(num_train_steps), desc="Training...", position=0)
|
||||
for step in range(num_train_steps):
|
||||
# ======================== Training ================================
|
||||
try:
|
||||
samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
|
||||
except StopIteration:
|
||||
# Once the end of the dataset stream is reached, the training iterator
|
||||
# is reinitialized and reshuffled and a new eval dataset is randomly chosen.
|
||||
shuffle_seed += 1
|
||||
tokenized_datasets.set_epoch(shuffle_seed)
|
||||
|
||||
training_iter = iter(tokenized_datasets)
|
||||
|
||||
eval_dataset = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
|
||||
samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
|
||||
|
||||
# process input samples
|
||||
model_inputs = data_collator(samples)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
||||
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
if step % training_args.logging_steps == 0 and step > 0:
|
||||
steps.write(
|
||||
f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
|
||||
f" {train_metric['learning_rate'].mean()})"
|
||||
)
|
||||
train_time += time.time() - train_start
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
write_train_metric(summary_writer, train_metrics, train_time, step)
|
||||
train_metrics = []
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
if step % training_args.eval_steps == 0 and step > 0:
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(data_args.num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
|
||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
|
||||
# process input samples
|
||||
batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
|
||||
model_inputs = data_collator(batch_eval_samples)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
|
||||
eval_normalizer = eval_metrics.pop("normalizer")
|
||||
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||
|
||||
# Update progress bar
|
||||
steps.desc = (
|
||||
f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc:"
|
||||
f" {eval_metrics['accuracy']})"
|
||||
)
|
||||
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
write_eval_metric(summary_writer, eval_metrics, step)
|
||||
eval_metrics = []
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(
|
||||
training_args.output_dir,
|
||||
params=params,
|
||||
push_to_hub=training_args.push_to_hub,
|
||||
commit_message=f"Saving weights and logs of step {step+1}",
|
||||
)
|
||||
|
||||
# update tqdm bar
|
||||
steps.update(1)
|
||||
@@ -1,172 +0,0 @@
|
||||
<!---
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Vision-Text dual encoder model training examples
|
||||
|
||||
> Note: This example is experimental and might not give the best possible results
|
||||
|
||||
The following example showcases how to train a CLIP like vision-text dual encoder model
|
||||
using a pre-trained vision and text encoder using the JAX/Flax backend.
|
||||
|
||||
Such a model can be used for natural language image search and potentially zero-shot image classification.
|
||||
The model is inspired by the [CLIP](https://openai.com/blog/clip/) approach, introduced by Alec Radford et al.
|
||||
The idea is to train a vision encoder and a text encoder jointly to project the representation of images and their
|
||||
captions into the same embedding space, such that the caption embeddings are located near the embeddings
|
||||
of the images they describe.
|
||||
|
||||
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
|
||||
Models written in JAX/Flax are **immutable** and updated in a purely functional
|
||||
way which enables simple and efficient model parallelism.
|
||||
|
||||
In this example we will use the vision model from [CLIP](https://huggingface.co/models?filter=clip)
|
||||
as the image encoder and [`FacebookAI/roberta-base`](https://huggingface.co/FacebookAI/roberta-base) as the text encoder.
|
||||
Note that one can also use the [ViT](https://huggingface.co/models?filter=vit) model as image encoder and any other BERT or ROBERTa model as text encoder.
|
||||
To train the model on languages other than English one should choose a text encoder trained on the desired
|
||||
language and a image-text dataset in that language. One such dataset is [WIT](https://github.com/google-research-datasets/wit).
|
||||
|
||||
Let's start by creating a model repository to save the trained model and logs.
|
||||
Here we call the model `"clip-roberta-base"`, but you can change the model name as you like.
|
||||
|
||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
||||
you are logged in) or via the command line:
|
||||
|
||||
```bash
|
||||
huggingface-cli repo create clip-roberta-base
|
||||
```
|
||||
Next we clone the model repository to add the tokenizer and model files.
|
||||
```bash
|
||||
git clone https://huggingface.co/<your-username>/clip-roberta-base
|
||||
```
|
||||
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
||||
track them. You can run the following command inside your model repo to do so.
|
||||
|
||||
```bash
|
||||
cd clip-roberta-base
|
||||
git lfs track "*tfevents*"
|
||||
```
|
||||
|
||||
Great, we have set up our model repository. During training, we will automatically
|
||||
push the training logs and model weights to the repo.
|
||||
|
||||
Next, let's add a symbolic link to the `run_hybrid_clip.py`.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="./clip-roberta-base
|
||||
ln -s ~/transformers/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py run_hybrid_clip.py
|
||||
```
|
||||
|
||||
## How to use the `FlaxHybridCLIP` model:
|
||||
|
||||
The `FlaxHybridCLIP` class let's you load any text and vision encoder model to create a dual encoder.
|
||||
Here is an example of how to load the model using pre-trained text and vision models.
|
||||
|
||||
```python
|
||||
from modeling_hybrid_clip import FlaxHybridCLIP
|
||||
|
||||
model = FlaxHybridCLIP.from_text_vision_pretrained("google-bert/bert-base-uncased", "openai/clip-vit-base-patch32")
|
||||
|
||||
# save the model
|
||||
model.save_pretrained("bert-clip")
|
||||
|
||||
# load the saved model
|
||||
model = FlaxHybridCLIP.from_pretrained("bert-clip")
|
||||
```
|
||||
|
||||
If the checkpoints are in PyTorch then one could pass `text_from_pt=True` and `vision_from_pt=True`. This will load the model
|
||||
PyTorch checkpoints convert them to flax and load the model.
|
||||
|
||||
```python
|
||||
model = FlaxHybridCLIP.from_text_vision_pretrained("google-bert/bert-base-uncased", "openai/clip-vit-base-patch32", text_from_pt=True, vision_from_pt=True)
|
||||
```
|
||||
|
||||
This loads both the text and vision encoders using pre-trained weights, the projection layers are randomly
|
||||
initialized except for CLIP's vision model. If you use CLIP to initialize the vision model then the vision projection weights are also
|
||||
loaded using the pre-trained weights.
|
||||
|
||||
## Prepare the dataset
|
||||
|
||||
We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search.
|
||||
|
||||
### Download and extract the data.
|
||||
|
||||
It consists of two compressed folders: one with images, and the other—with associated image captions. Note that the compressed images folder is 13GB in size.
|
||||
|
||||
```bash
|
||||
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
|
||||
wget http://images.cocodataset.org/zips/train2014.zip
|
||||
|
||||
unzip annotations_trainval2014.zip
|
||||
unzip train2014.zip
|
||||
|
||||
mkdir coco_dataset
|
||||
mv train2014 coco_dataset/
|
||||
mv annotations coco_dataset/
|
||||
```
|
||||
|
||||
### Prepare dataset files and split the dataset.
|
||||
|
||||
```python
|
||||
import json
|
||||
import collections
|
||||
|
||||
images_dir = "coco_dataset/train2014"
|
||||
annotation_file = "coco_dataset/annotations/captions_train2014.json"
|
||||
with open(annotation_file, "r") as f:
|
||||
annotations = json.load(f)["annotations"]
|
||||
|
||||
image_path_to_caption = collections.defaultdict(list)
|
||||
for element in annotations:
|
||||
caption = f"{element['caption'].lower().rstrip('.')}"
|
||||
image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
|
||||
image_path_to_caption[image_path].append(caption)
|
||||
|
||||
lines = []
|
||||
for image_path, captions in image_path_to_caption.items():
|
||||
lines.append(json.dumps({"image_path": image_path, "captions": captions}))
|
||||
|
||||
train_lines = lines[:-8000]
|
||||
valid_line = lines[-8000:]
|
||||
with open("coco_dataset/train_dataset.json", "w") as f:
|
||||
f.write("\n".join(train_lines))
|
||||
|
||||
with open("coco_dataset/valid_dataset.json", "w") as f:
|
||||
f.write("\n".join(valid_line))
|
||||
```
|
||||
|
||||
> Note: The data loading and processing part of this script can still be improved for maximum performance. In particular one should decode the images beforehand and use those instead decoding them each time. If the dataset is small or if you have huge disk space the you could also pre-process all the dataset beforehand and then use it.
|
||||
|
||||
## Train the model
|
||||
Next we can run the example script to train the model:
|
||||
|
||||
```bash
|
||||
python run_hybrid_clip.py \
|
||||
--output_dir ${MODEL_DIR} \
|
||||
--text_model_name_or_path="FacebookAI/roberta-base" \
|
||||
--vision_model_name_or_path="openai/clip-vit-base-patch32" \
|
||||
--tokenizer_name="FacebookAI/roberta-base" \
|
||||
--train_file="coco_dataset/train_dataset.json" \
|
||||
--validation_file="coco_dataset/validation_dataset.json" \
|
||||
--do_train --do_eval \
|
||||
--num_train_epochs="40" --max_seq_length 96 \
|
||||
--per_device_train_batch_size="64" \
|
||||
--per_device_eval_batch_size="64" \
|
||||
--learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
|
||||
--overwrite_output_dir \
|
||||
--preprocessing_num_workers 32 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
This should finish in ~1h50 mins with min validation loss 2.43. Training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experiment/RUNPYd1yRgSD5kZSb9hDig/#scalars)
|
||||
@@ -1,112 +0,0 @@
|
||||
import copy
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class HybridCLIPConfig(PretrainedConfig):
|
||||
r"""
|
||||
:class:`HybridCLIPConfig` is the configuration class to store the configuration of a
|
||||
:class:`~HybridCLIPModel`. It is used to instantiate HybridCLIPModel model according to the specified arguments,
|
||||
defining the text model and vision model configs.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
Args:
|
||||
text_config_dict (:obj:`dict`):
|
||||
Dictionary of configuration options that defines text model config.
|
||||
vision_config_dict (:obj:`dict`):
|
||||
Dictionary of configuration options that defines vison model config.
|
||||
projection_dim (:obj:`int`, `optional`, defaults to 512):
|
||||
Dimentionality of text and vision projection layers.
|
||||
kwargs (`optional`):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
|
||||
|
||||
>>> # Initializing a BERT and CLIP configuration
|
||||
>>> config_text = BertConfig()
|
||||
>>> config_vision = CLIPConfig()
|
||||
|
||||
>>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
|
||||
|
||||
>>> # Initializing a BERT and CLIPVision model
|
||||
>>> model = EncoderDecoderModel(config=config)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> config_text = model.config.text_config
|
||||
>>> config_vision = model.config.vision_config
|
||||
|
||||
>>> # Saving the model, including its configuration
|
||||
>>> model.save_pretrained('my-model')
|
||||
|
||||
>>> # loading model and config from pretrained folder
|
||||
>>> encoder_decoder_config = HybridCLIPConfig.from_pretrained('my-model')
|
||||
>>> model = FlaxHybridCLIP.from_pretrained('my-model', config=encoder_decoder_config)
|
||||
"""
|
||||
|
||||
model_type = "hybrid-clip"
|
||||
is_composition = True
|
||||
|
||||
def __init__(self, projection_dim=512, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if "text_config" not in kwargs:
|
||||
raise ValueError("`text_config` can not be `None`.")
|
||||
|
||||
if "vision_config" not in kwargs:
|
||||
raise ValueError("`vision_config` can not be `None`.")
|
||||
|
||||
text_config = kwargs.pop("text_config")
|
||||
vision_config = kwargs.pop("vision_config")
|
||||
|
||||
text_model_type = text_config.pop("model_type")
|
||||
vision_model_type = vision_config.pop("model_type")
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
||||
|
||||
if vision_model_type == "clip":
|
||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
|
||||
elif vision_model_type == "clip_vision_model":
|
||||
from transformers import CLIPVisionConfig
|
||||
|
||||
self.vision_config = CLIPVisionConfig(**vision_config)
|
||||
else:
|
||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
||||
|
||||
self.projection_dim = projection_dim
|
||||
self.initializer_factor = 1.0
|
||||
|
||||
@classmethod
|
||||
def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
|
||||
r"""
|
||||
Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
|
||||
vision model configuration.
|
||||
|
||||
Returns:
|
||||
:class:`HybridCLIPConfig`: An instance of a configuration object
|
||||
"""
|
||||
|
||||
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default
|
||||
:meth:`~transformers.PretrainedConfig.to_dict`.
|
||||
|
||||
Returns:
|
||||
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["text_config"] = self.text_config.to_dict()
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
@@ -1,420 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from configuration_hybrid_clip import HybridCLIPConfig
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
|
||||
from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
|
||||
from transformers.modeling_flax_utils import FlaxPreTrainedModel
|
||||
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FlaxHybridCLIPModule(nn.Module):
|
||||
config: HybridCLIPConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
text_config = self.config.text_config
|
||||
vision_config = self.config.vision_config
|
||||
|
||||
self.projection_dim = self.config.projection_dim
|
||||
self.text_embed_dim = text_config.hidden_size
|
||||
self.vision_embed_dim = vision_config.hidden_size
|
||||
|
||||
text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
|
||||
vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
|
||||
|
||||
self.text_model = text_module(text_config, dtype=self.dtype)
|
||||
self.vision_model = vision_module(vision_config, dtype=self.dtype)
|
||||
|
||||
self.visual_projection = nn.Dense(
|
||||
self.projection_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(0.02),
|
||||
use_bias=False,
|
||||
)
|
||||
self.text_projection = nn.Dense(
|
||||
self.projection_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(0.02),
|
||||
use_bias=False,
|
||||
)
|
||||
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
deterministic: bool = True,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
image_embeds = self.visual_projection(image_embeds)
|
||||
|
||||
text_embeds = text_outputs[1]
|
||||
text_embeds = self.text_projection(text_embeds)
|
||||
|
||||
# normalized features
|
||||
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
|
||||
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = jnp.exp(self.logit_scale)
|
||||
logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
|
||||
logits_per_image = logits_per_text.T
|
||||
|
||||
if not return_dict:
|
||||
return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
||||
|
||||
return FlaxCLIPOutput(
|
||||
logits_per_image=logits_per_image,
|
||||
logits_per_text=logits_per_text,
|
||||
text_embeds=text_embeds,
|
||||
image_embeds=image_embeds,
|
||||
text_model_output=text_outputs,
|
||||
vision_model_output=vision_outputs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
config_class = HybridCLIPConfig
|
||||
module_class = FlaxHybridCLIPModule
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HybridCLIPConfig,
|
||||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs,
|
||||
):
|
||||
if input_shape is None:
|
||||
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
||||
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensor
|
||||
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
pixel_values = jax.random.normal(rng, input_shape[1])
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(pixel_values, dtype=jnp.float32),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train=False,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||
for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
|
||||
Returns:
|
||||
text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings
|
||||
obtained by applying the projection layer to the pooled output of text model.
|
||||
"""
|
||||
if position_ids is None:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
|
||||
text_outputs = module.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
pooled_output = text_outputs[1]
|
||||
text_features = module.text_projection(pooled_output)
|
||||
return text_features
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
not train,
|
||||
method=_get_features,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
def get_image_features(
|
||||
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
|
||||
using :class:`~transformers.ImageFeatureExtractionMixin`. See
|
||||
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
|
||||
|
||||
Returns:
|
||||
image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings
|
||||
obtained by applying the projection layer to the pooled output of vision model.
|
||||
"""
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
def _get_features(module, pixel_values, deterministic):
|
||||
vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
|
||||
pooled_output = vision_outputs[1] # pooled_output
|
||||
image_features = module.visual_projection(pooled_output)
|
||||
return image_features
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(pixel_values, dtype=jnp.float32),
|
||||
not train,
|
||||
method=_get_features,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_text_vision_pretrained(
|
||||
cls,
|
||||
text_model_name_or_path: str = None,
|
||||
vision_model_name_or_path: str = None,
|
||||
*model_args,
|
||||
**kwargs,
|
||||
) -> FlaxPreTrainedModel:
|
||||
"""
|
||||
Params:
|
||||
text_model_name_or_path (:obj: `str`, `optional`):
|
||||
Information necessary to initiate the text model. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
|
||||
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
|
||||
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
|
||||
|
||||
vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
|
||||
Information necessary to initiate the vision model. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
|
||||
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
|
||||
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
|
||||
|
||||
model_args (remaining positional arguments, `optional`):
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`).
|
||||
|
||||
- To update the text configuration, use the prefix `text_` for each configuration parameter.
|
||||
- To update the vision configuration, use the prefix `vision_` for each configuration parameter.
|
||||
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
||||
|
||||
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import FlaxHybridCLIP
|
||||
>>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
|
||||
>>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
|
||||
>>> model = FlaxHybridCLIP.from_text_vision_pretrained('google-bert/bert-base-uncased', 'openai/clip-vit-base-patch32')
|
||||
>>> # saving model after fine-tuning
|
||||
>>> model.save_pretrained("./bert-clip")
|
||||
>>> # load fine-tuned model
|
||||
>>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
|
||||
"""
|
||||
|
||||
kwargs_text = {
|
||||
argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
|
||||
}
|
||||
|
||||
kwargs_vision = {
|
||||
argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
|
||||
}
|
||||
|
||||
# remove text, vision kwargs from kwargs
|
||||
for key in kwargs_text.keys():
|
||||
del kwargs["text_" + key]
|
||||
for key in kwargs_vision.keys():
|
||||
del kwargs["vision_" + key]
|
||||
|
||||
# Load and initialize the text and vision model
|
||||
text_model = kwargs_text.pop("model", None)
|
||||
if text_model is None:
|
||||
assert (
|
||||
text_model_name_or_path is not None
|
||||
), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
|
||||
from transformers import FlaxAutoModel
|
||||
|
||||
if "config" not in kwargs_text:
|
||||
from transformers import AutoConfig
|
||||
|
||||
text_config = AutoConfig.from_pretrained(text_model_name_or_path)
|
||||
kwargs_text["config"] = text_config
|
||||
|
||||
text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
|
||||
|
||||
vision_model = kwargs_vision.pop("model", None)
|
||||
if vision_model is None:
|
||||
assert (
|
||||
vision_model_name_or_path is not None
|
||||
), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
|
||||
from transformers import FlaxAutoModel
|
||||
|
||||
if "config" not in kwargs_vision:
|
||||
from transformers import AutoConfig
|
||||
|
||||
vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
|
||||
kwargs_vision["config"] = vision_config
|
||||
|
||||
vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
|
||||
|
||||
# instantiate config with corresponding kwargs
|
||||
dtype = kwargs.pop("dtype", jnp.float32)
|
||||
config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
|
||||
|
||||
# init model
|
||||
model = cls(config, *model_args, dtype=dtype, **kwargs)
|
||||
|
||||
if vision_config.model_type == "clip":
|
||||
model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
|
||||
model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
|
||||
else:
|
||||
model.params["vision_model"] = vision_model.params
|
||||
|
||||
model.params["text_model"] = text_model.params
|
||||
|
||||
return model
|
||||
@@ -1,8 +0,0 @@
|
||||
jax>=0.2.8
|
||||
jaxlib>=0.1.59
|
||||
flax>=0.3.5
|
||||
optax>=0.0.8
|
||||
-f https://download.pytorch.org/whl/torch_stable.html
|
||||
torch==2.2.0
|
||||
-f https://download.pytorch.org/whl/torch_stable.html
|
||||
torchvision==0.10.0+cpu
|
||||
@@ -1,576 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Training a CLIP like dual encoder models using text and vision encoders in the library.
|
||||
|
||||
The script can be used to train CLIP like models for languages other than english by using
|
||||
a text encoder pre-trained in the desired language. Currently this script support the following vision
|
||||
and text models:
|
||||
Vision models: ViT(https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
|
||||
Text models: BERT, ROBERTa (https://huggingface.co/models?filter=fill-mask)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import torch
|
||||
from flax import jax_utils
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
||||
from modeling_hybrid_clip import FlaxHybridCLIP
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.io import ImageReadMode, read_image
|
||||
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache the result
|
||||
has_tensorboard = is_tensorboard_available()
|
||||
if has_tensorboard:
|
||||
try:
|
||||
from flax.metrics.tensorboard import SummaryWriter
|
||||
except ImportError as ie:
|
||||
has_tensorboard = False
|
||||
print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
|
||||
|
||||
else:
|
||||
print(
|
||||
"Unable to display metrics through TensorBoard because the package is not installed: "
|
||||
"Please run pip install tensorboard to enable."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
||||
"""
|
||||
|
||||
text_model_name_or_path: str = field(
|
||||
metadata={
|
||||
"help": (
|
||||
"The text model checkpoint for weights initialization. "
|
||||
"Don't set if you want to train a model from scratch."
|
||||
)
|
||||
},
|
||||
)
|
||||
vision_model_name_or_path: str = field(
|
||||
metadata={
|
||||
"help": (
|
||||
"The vision model checkpoint for weights initialization. "
|
||||
"Don't set if you want to train a model from scratch."
|
||||
)
|
||||
},
|
||||
)
|
||||
from_pt: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
dtype: Optional[str] = field(
|
||||
default="float32",
|
||||
metadata={
|
||||
"help": (
|
||||
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||
" `[float32, float16, bfloat16]`."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
|
||||
train_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "The input training data file (a jsonlines file)."}
|
||||
)
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
|
||||
)
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=72,
|
||||
metadata={
|
||||
"help": (
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension == "json", "`train_file` should be a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension == "json", "`validation_file` should be a json file."
|
||||
|
||||
|
||||
# We use torchvision for faster image pre-processing.
|
||||
# We need to ensure faster processing speed as it can become a bottleneck on TPU
|
||||
class Transform(torch.nn.Module):
|
||||
def __init__(self, image_size):
|
||||
super().__init__()
|
||||
self.transforms = torch.nn.Sequential(
|
||||
Resize([image_size], interpolation=InterpolationMode.BICUBIC),
|
||||
CenterCrop(image_size),
|
||||
ConvertImageDtype(torch.float),
|
||||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
x = self.transforms(x)
|
||||
return x
|
||||
|
||||
|
||||
class ImageTextDataset(VisionDataset):
|
||||
"""
|
||||
Dtaset for loading image-text data for tasks like CLIP training, Image Captioning.
|
||||
|
||||
Args:
|
||||
root: (string): The root path where the dataset is stored
|
||||
file_path: (string): Path to the file containing the image_paths and associated captions.
|
||||
The expected format is jsonlines where each line is a json object containing to keys.
|
||||
`image_path`: The path to the image.
|
||||
`captions`: An `array` of captions.
|
||||
transform (callable, optional): A function/transform that takes in an PIL image
|
||||
and returns a transformed version. E.g, ``transforms.ToTensor``
|
||||
target_transform (callable, optional): A function/transform that takes in the
|
||||
target and transforms it.
|
||||
transforms (callable, optional): A function/transform that takes input sample and its target as entry
|
||||
and returns a transformed version.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
file_path: str,
|
||||
captions_per_image=2,
|
||||
transform: Optional[Callable] = None,
|
||||
target_transform: Optional[Callable] = None,
|
||||
transforms: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__(root, transforms, transform, target_transform)
|
||||
|
||||
with open(file_path, "r") as f:
|
||||
examples = [json.loads(line) for line in f.readlines()]
|
||||
|
||||
self.captions = []
|
||||
self.image_paths = []
|
||||
|
||||
for example in examples:
|
||||
captions_subset = example["captions"][:captions_per_image]
|
||||
self.captions.extend(captions_subset)
|
||||
self.image_paths.extend([example["image_path"]] * len(captions_subset))
|
||||
|
||||
def _load_image(self, idx: int):
|
||||
path = self.image_paths[idx]
|
||||
return read_image(path, mode=ImageReadMode.RGB)
|
||||
|
||||
def _load_target(self, idx):
|
||||
return self.captions[idx]
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
image = self._load_image(index)
|
||||
target = self._load_target(index)
|
||||
|
||||
if self.transforms is not None:
|
||||
image, target = self.transforms(image, target)
|
||||
|
||||
return image, target
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.captions)
|
||||
|
||||
|
||||
class TrainState(train_state.TrainState):
|
||||
dropout_rng: jnp.ndarray
|
||||
|
||||
def replicate(self):
|
||||
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
||||
|
||||
|
||||
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
||||
summary_writer.scalar("train_time", train_time, step)
|
||||
|
||||
train_metrics = get_metrics(train_metrics)
|
||||
for key, vals in train_metrics.items():
|
||||
tag = f"train_{key}"
|
||||
for i, val in enumerate(vals):
|
||||
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
||||
|
||||
for metric_name, value in eval_metrics.items():
|
||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||
|
||||
|
||||
def create_learning_rate_fn(
|
||||
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
||||
) -> Callable[[int], jnp.ndarray]:
|
||||
"""Returns a linear warmup, linear_decay learning rate function."""
|
||||
steps_per_epoch = train_ds_size // train_batch_size
|
||||
num_train_steps = steps_per_epoch * num_train_epochs
|
||||
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
||||
decay_fn = optax.linear_schedule(
|
||||
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
||||
)
|
||||
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
||||
return schedule_fn
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
# Setup logging, we only want one process per machine to log things on the screen.
|
||||
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
||||
if jax.process_index() == 0:
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
elif model_args.text_model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script. "
|
||||
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||
)
|
||||
|
||||
model = FlaxHybridCLIP.from_text_vision_pretrained(
|
||||
model_args.text_model_name_or_path,
|
||||
model_args.vision_model_name_or_path,
|
||||
seed=training_args.seed,
|
||||
dtype=getattr(jnp, model_args.dtype),
|
||||
text_from_pt=model_args.from_pt,
|
||||
vision_from_pt=model_args.from_pt,
|
||||
)
|
||||
config = model.config
|
||||
# set seed for torch dataloaders
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Initialize torchvision transforms and jit them for faster processing
|
||||
preprocess = Transform(config.vision_config.image_size)
|
||||
preprocess = torch.jit.script(preprocess)
|
||||
|
||||
# Initialize the image-text dataset
|
||||
train_dataset = ImageTextDataset(
|
||||
data_args.data_dir,
|
||||
data_args.train_file,
|
||||
captions_per_image=2,
|
||||
transform=preprocess,
|
||||
)
|
||||
|
||||
eval_dataset = ImageTextDataset(
|
||||
data_args.data_dir,
|
||||
data_args.validation_file,
|
||||
captions_per_image=1,
|
||||
transform=preprocess,
|
||||
)
|
||||
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
steps_per_epoch = len(train_dataset) // train_batch_size
|
||||
total_train_steps = steps_per_epoch * num_epochs
|
||||
|
||||
# Use collate function to tokenizer the text and convert the processed images to numpy
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
|
||||
captions = [example[1] for example in examples]
|
||||
inputs = tokenizer(
|
||||
captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True, return_tensors="np"
|
||||
)
|
||||
|
||||
batch = {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": inputs["input_ids"],
|
||||
"attention_mask": inputs["attention_mask"],
|
||||
}
|
||||
|
||||
return batch
|
||||
|
||||
# Create data loaders
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=data_args.preprocessing_num_workers,
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
eval_loader = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=eval_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=data_args.preprocessing_num_workers,
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
# Enable tensorboard only on the master node
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
||||
|
||||
# Initialize our training
|
||||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
rng, dropout_rng = jax.random.split(rng)
|
||||
|
||||
# Create learning rate schedule
|
||||
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||
len(train_dataset),
|
||||
train_batch_size,
|
||||
training_args.num_train_epochs,
|
||||
training_args.warmup_steps,
|
||||
training_args.learning_rate,
|
||||
)
|
||||
|
||||
# create adam optimizer
|
||||
adamw = optax.adamw(
|
||||
learning_rate=linear_decay_lr_schedule_fn,
|
||||
b1=training_args.adam_beta1,
|
||||
b2=training_args.adam_beta2,
|
||||
eps=training_args.adam_epsilon,
|
||||
weight_decay=training_args.weight_decay,
|
||||
)
|
||||
|
||||
# Setup train state
|
||||
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
||||
|
||||
def cross_entropy(logits, axis):
|
||||
logprobs = jax.nn.log_softmax(logits, axis=axis)
|
||||
nll = jnp.diag(logprobs)
|
||||
ce = -jnp.mean(nll)
|
||||
return ce
|
||||
|
||||
def clip_loss(similarity):
|
||||
loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
|
||||
return loss
|
||||
|
||||
# Define gradient update step fn
|
||||
def train_step(state, batch):
|
||||
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
||||
|
||||
def compute_loss(params):
|
||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
loss = clip_loss(logits)
|
||||
return loss
|
||||
|
||||
grad_fn = jax.value_and_grad(compute_loss)
|
||||
loss, grad = grad_fn(state.params)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
|
||||
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||
|
||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
|
||||
return new_state, metrics
|
||||
|
||||
# Define eval fn
|
||||
def eval_step(params, batch):
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
loss = clip_loss(logits)
|
||||
|
||||
# summarize metrics
|
||||
metrics = {"loss": loss}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
return metrics
|
||||
|
||||
# Create parallel version of the train and eval step
|
||||
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
||||
p_eval_step = jax.pmap(eval_step, "batch")
|
||||
|
||||
# Replicate the train state on each device
|
||||
state = state.replicate()
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {num_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
||||
logger.info(f" Total optimization steps = {total_train_steps}")
|
||||
|
||||
train_time = 0
|
||||
# Create sampling rng
|
||||
rng, input_rng = jax.random.split(rng)
|
||||
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
# ======================== Training ================================
|
||||
train_start = time.time()
|
||||
|
||||
# Create sampling rng
|
||||
rng, input_rng = jax.random.split(rng)
|
||||
train_metrics = []
|
||||
|
||||
steps_per_epoch = len(train_dataset) // train_batch_size
|
||||
train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
|
||||
# train
|
||||
for batch in train_loader:
|
||||
batch = shard(batch)
|
||||
state, train_metric = p_train_step(state, batch)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
train_step_progress_bar.update(1)
|
||||
|
||||
train_time += time.time() - train_start
|
||||
|
||||
train_metric = unreplicate(train_metric)
|
||||
|
||||
train_step_progress_bar.close()
|
||||
epochs.write(
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
|
||||
f" {train_metric['learning_rate']})"
|
||||
)
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
eval_metrics = []
|
||||
eval_steps = len(eval_dataset) // eval_batch_size
|
||||
eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
|
||||
for batch in eval_loader:
|
||||
# Model forward
|
||||
batch = shard(batch)
|
||||
metrics = p_eval_step(state.params, batch)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
eval_step_progress_bar.update(1)
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
||||
|
||||
# Print metrics and update progress bar
|
||||
eval_step_progress_bar.close()
|
||||
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
||||
epochs.write(desc)
|
||||
epochs.desc = desc
|
||||
|
||||
# Save metrics
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
||||
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(unreplicate(state.params))
|
||||
model.save_pretrained(
|
||||
training_args.output_dir,
|
||||
params=params,
|
||||
push_to_hub=training_args.push_to_hub,
|
||||
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,67 +0,0 @@
|
||||
<!---
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Model parallel language model training example
|
||||
|
||||
The following example showcases how to train/fine-tune GPTNeo model with model parallelism using
|
||||
the JAX/Flax backend and the [`pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) transformation.
|
||||
|
||||
> Note: The example is experimental and might have bugs. Also currently it only supports single V3-8.
|
||||
|
||||
The `partition.py` file defines the `PyTree` of `ParitionSpec` for the GPTNeo model which describes how the model will be sharded.
|
||||
The actual sharding is auto-matically handled by `pjit`. The weights are sharded across all local devices.
|
||||
To adapt the script for other models, we need to also change the `ParitionSpec` accordingly.
|
||||
|
||||
TODO: Add more explantion.
|
||||
|
||||
Before training, let's prepare our model first. To be able to shard the model, the sharded dimension needs to be a multiple of devices it'll be sharded on. But GPTNeo's vocab size is 50257, so we need to resize the embeddings accordingly.
|
||||
|
||||
```python
|
||||
from transformers import FlaxGPTNeoForCausalLM, GPTNeoConfig
|
||||
model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
||||
|
||||
emb = jnp.zeros((50264, model.config.hidden_size))
|
||||
# update the first 50257 weights using pre-trained weights
|
||||
emb = emb.at[:50257, :].set(model.params["transformer"]["wte"]["embedding"])
|
||||
params = model.params
|
||||
params["transformer"]["wte"]["embedding"] = emb
|
||||
|
||||
# initialize a random model with the right vocab_size
|
||||
config = GPTNeoConfig.from_pretrained("EleutherAI/gpt-neo-1.3B", vocab_size=50264)
|
||||
model = FlaxGPTNeoForCausalLM(config)
|
||||
|
||||
# assign the pre-trained weights and save the model.
|
||||
model.params = params
|
||||
model.save_pretrained("gpt-neo-1.3B")
|
||||
```
|
||||
|
||||
|
||||
### Train Model
|
||||
|
||||
```bash
|
||||
python run_clm_mp.py \
|
||||
--model_name_or_path gpt-neo-1.3B \
|
||||
--tokenizer_name openai-community/gpt2 \
|
||||
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
|
||||
--do_train --do_eval \
|
||||
--block_size 1024 \
|
||||
--num_train_epochs 5 \
|
||||
--learning_rate 4e-6 \
|
||||
--per_device_train_batch_size 3 --per_device_eval_batch_size 3 \
|
||||
--overwrite_output_dir --output_dir ~/tmp/flax-clm \
|
||||
--cache_dir ~/datasets_cache/wikitext --dtype bfloat16 \
|
||||
--logging_steps 96 --eval_steps 96
|
||||
```
|
||||
@@ -1,85 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Google Research Authors and The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for constructing PyTrees of PartitionSpecs."""
|
||||
|
||||
# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
|
||||
|
||||
import re
|
||||
|
||||
from flax.core.frozen_dict import freeze
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax.experimental import PartitionSpec as P
|
||||
|
||||
|
||||
# Sentinels
|
||||
_unmatched = object()
|
||||
|
||||
# For specifying empty leaf dict `{}`
|
||||
empty_dict = object()
|
||||
|
||||
|
||||
def _match(qs, ks):
|
||||
"""Return True if regexes in qs match any window of strings in tuple ks."""
|
||||
# compile regexes and force complete match
|
||||
qts = tuple((re.compile(x + "$") for x in qs))
|
||||
for i in range(len(ks) - len(qs) + 1):
|
||||
matches = [x.match(y) for x, y in zip(qts, ks[i:])]
|
||||
if matches and all(matches):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _replacement_rules(rules):
|
||||
def replace(key, val):
|
||||
for rule, replacement in rules:
|
||||
if _match(rule, key):
|
||||
return replacement
|
||||
return val
|
||||
|
||||
return replace
|
||||
|
||||
|
||||
# PartitionSpec for GPTNeo
|
||||
# replicate the hidden dim and shard feed-forward and head dim
|
||||
def _get_partition_rules():
|
||||
return [
|
||||
# embeddings
|
||||
(("transformer", "wpe", "embedding"), P("mp", None)),
|
||||
(("transformer", "wte", "embedding"), P("mp", None)),
|
||||
# atention
|
||||
(("attention", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
||||
(("attention", "out_proj", "kernel"), P("mp", None)),
|
||||
(("attention", "out_proj", "bias"), None),
|
||||
# mlp
|
||||
(("mlp", "c_fc", "kernel"), P(None, "mp")),
|
||||
(("mlp", "c_fc", "bias"), P("mp")),
|
||||
(("mlp", "c_proj", "kernel"), P("mp", None)),
|
||||
(("mlp", "c_proj", "bias"), None),
|
||||
# layer norms
|
||||
((r"ln_\d+", "bias"), None),
|
||||
((r"\d+", r"ln_\d+", "scale"), None),
|
||||
(("ln_f", "bias"), None),
|
||||
(("ln_f", "scale"), None),
|
||||
]
|
||||
|
||||
|
||||
def set_partitions(in_dict):
|
||||
rules = _get_partition_rules()
|
||||
replace = _replacement_rules(rules)
|
||||
initd = {k: _unmatched for k in flatten_dict(in_dict)}
|
||||
result = {k: replace(k, v) for k, v in initd.items()}
|
||||
assert _unmatched not in result.values(), "Incomplete partition spec."
|
||||
return freeze(unflatten_dict(result))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user