Reorganize examples (#9010)
* Reorganize example folder * Continue reorganization * Change requirements for tests * Final cleanup * Finish regroup with tests all passing * Copyright * Requirements and readme * Make a full link for the documentation * Address review comments * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Add symlink * Reorg again * Apply suggestions from code review Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> * Adapt title * Update to new strucutre * Remove test * Update READMEs Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
161
examples/research_projects/rag/README.md
Normal file
161
examples/research_projects/rag/README.md
Normal file
@@ -0,0 +1,161 @@
|
||||
# Intro
|
||||
|
||||
Authors: @patrickvonplaten and @lhoestq
|
||||
|
||||
Aimed at tackling the knowledge-intensive NLP tasks (think tasks a human wouldn't be expected to solve without access to external knowledge sources), RAG models are seq2seq models with access to a retrieval mechanism providing relevant context documents at training and evaluation time.
|
||||
|
||||
A RAG model encapsulates two core components: a question encoder and a generator.
|
||||
During a forward pass, we encode the input with the question encoder and pass it
|
||||
to the retriever to extract relevant context documents. The documents are then prepended to the input.
|
||||
Such contextualized inputs are passed to the generator.
|
||||
|
||||
Read more about RAG at https://arxiv.org/abs/2005.11401.
|
||||
|
||||
# Finetuning
|
||||
|
||||
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files:
|
||||
```bash
|
||||
train.source
|
||||
train.target
|
||||
val.source
|
||||
val.target
|
||||
test.source
|
||||
test.target
|
||||
```
|
||||
|
||||
A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
|
||||
|
||||
```bash
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8
|
||||
```
|
||||
We publish two `base` models which can serve as a starting point for finetuning on downstream tasks (use them as `model_name_or_path`):
|
||||
- [`facebook/rag-sequence-base`](https://huggingface.co/facebook/rag-sequence-base) - a base for finetuning `RagSequenceForGeneration` models,
|
||||
- [`facebook/rag-token-base`](https://huggingface.co/facebook/rag-token-base) - a base for finetuning `RagTokenForGeneration` models.
|
||||
|
||||
The `base` models initialize the question encoder with [`facebook/dpr-question_encoder-single-nq-base`](https://huggingface.co/facebook/dpr-question_encoder-single-nq-base) and the generator with [`facebook/bart-large`](https://huggingface.co/facebook/bart-large).
|
||||
|
||||
If you would like to initialize finetuning with a base model using different question encoder and generator architectures, you can build it with a consolidation script, e.g.:
|
||||
```
|
||||
python examples/rag/consolidate_rag_checkpoint.py \
|
||||
--model_type rag_sequence \
|
||||
--generator_name_or_path facebook/bart-large-cnn \
|
||||
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
|
||||
--dest path/to/checkpoint
|
||||
```
|
||||
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
|
||||
|
||||
|
||||
# Evaluation
|
||||
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.
|
||||
|
||||
The evaluation script expects paths to two files:
|
||||
- `evaluation_set` - a path to a file specifying the evaluation dataset, a single input per line.
|
||||
- `gold_data_path` - a path to a file contaning ground truth answers for datapoints from the `evaluation_set`, a single output per line. Check below for expected formats of the gold data files.
|
||||
|
||||
|
||||
## Retrieval evaluation
|
||||
For `retrieval` evaluation, we expect a gold data file where each line will consist of a tab-separated list of document titles constituting positive contexts for respective datapoints from the `evaluation_set`. E.g. given a question `who sings does he love me with reba` in the `evaluation_set`, a respective ground truth line could look as follows:
|
||||
```
|
||||
Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greatest Hits Volume Two (Reba McEntire album) Shoot for the Moon (album)
|
||||
```
|
||||
|
||||
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).
|
||||
|
||||
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
|
||||
```bash
|
||||
wget https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz && gzip -d biencoder-nq-dev.json.gz
|
||||
```
|
||||
|
||||
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
|
||||
```bash
|
||||
mkdir output # or wherever you want to save this
|
||||
python examples/rag/parse_dpr_relevance_data.py \
|
||||
--src_path biencoder-nq-dev.json \
|
||||
--evaluation_set output/biencoder-nq-dev.questions \
|
||||
--gold_data_path output/biencoder-nq-dev.pages
|
||||
```
|
||||
3. Run evaluation:
|
||||
```bash
|
||||
python examples/rag/eval_rag.py \
|
||||
--model_name_or_path facebook/rag-sequence-nq \
|
||||
--model_type rag_sequence \
|
||||
--evaluation_set output/biencoder-nq-dev.questions \
|
||||
--gold_data_path output/biencoder-nq-dev.pages \
|
||||
--predictions_path output/retrieval_preds.tsv \
|
||||
--eval_mode retrieval \
|
||||
--k 1
|
||||
```
|
||||
```bash
|
||||
# EXPLANATION
|
||||
python examples/rag/eval_rag.py \
|
||||
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
|
||||
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
|
||||
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
|
||||
--gold_data_path poutput/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set
|
||||
--predictions_path output/retrieval_preds.tsv \ # name of file where predictions will be stored
|
||||
--eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation
|
||||
--k 1 # parameter k for the precision@k metric
|
||||
|
||||
```
|
||||
## End-to-end evaluation
|
||||
|
||||
We support two formats of the gold data file (controlled by the `gold_data_mode` parameter):
|
||||
- `qa` - where a single line has the following format: `input [tab] output_list`, e.g.:
|
||||
```
|
||||
who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Yongge Dai']
|
||||
```
|
||||
- `ans` - where a single line contains a single expected answer, e.g.:
|
||||
```
|
||||
Xiu Li Dai
|
||||
```
|
||||
|
||||
Predictions of the model for the samples from the `evaluation_set` will be saved under the path specified by the `predictions_path` parameter.
|
||||
If this path already exists, the script will use saved predictions to calculate metrics.
|
||||
Add `--recalculate` parameter to force the script to perform inference from scratch.
|
||||
|
||||
An example e2e evaluation run could look as follows:
|
||||
```bash
|
||||
python examples/rag/eval_rag.py \
|
||||
--model_name_or_path facebook/rag-sequence-nq \
|
||||
--model_type rag_sequence \
|
||||
--evaluation_set path/to/test.source \
|
||||
--gold_data_path path/to/gold_data \
|
||||
--predictions_path path/to/e2e_preds.txt \
|
||||
--eval_mode e2e \
|
||||
--gold_data_mode qa \
|
||||
--n_docs 5 \ # You can experiment with retrieving different number of documents at evaluation time
|
||||
--print_predictions \
|
||||
--recalculate \ # adding this parameter will force recalculating predictions even if predictions_path already exists
|
||||
```
|
||||
|
||||
# Use your own knowledge source
|
||||
|
||||
By default, RAG uses the English Wikipedia as a knowledge source, known as the 'wiki_dpr' dataset.
|
||||
With `use_custom_knowledge_dataset.py` you can build your own knowledge source, *e.g.* for RAG.
|
||||
|
||||
For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
|
||||
```bash
|
||||
python examples/rag/use_own_knowledge_dataset.py \
|
||||
--csv_path path/to/my_csv \
|
||||
--output_dir path/to/my_knowledge_dataset \
|
||||
```
|
||||
|
||||
The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
|
||||
```bash
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8
|
||||
--index_name custom
|
||||
--passages_path path/to/data/my_knowledge_dataset
|
||||
--index_path path/to/my_knowledge_dataset_hnsw_index.faiss
|
||||
```
|
||||
5
examples/research_projects/rag/__init__.py
Normal file
5
examples/research_projects/rag/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))
|
||||
96
examples/research_projects/rag/_test_finetune_rag.py
Normal file
96
examples/research_projects/rag/_test_finetune_rag.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import finetune_rag
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class RagFinetuneExampleTests(TestCasePlus):
|
||||
def _create_dummy_data(self, data_dir):
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
contents = {"source": "What is love ?", "target": "life"}
|
||||
n_lines = {"train": 12, "val": 2, "test": 2}
|
||||
for split in ["train", "test", "val"]:
|
||||
for field in ["source", "target"]:
|
||||
content = "\n".join([contents[field]] * n_lines[split])
|
||||
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
|
||||
f.write(content)
|
||||
|
||||
def _run_finetune(self, gpus: int):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
output_dir = os.path.join(tmp_dir, "output")
|
||||
data_dir = os.path.join(tmp_dir, "data")
|
||||
self._create_dummy_data(data_dir=data_dir)
|
||||
|
||||
testargs = f"""
|
||||
--data_dir {data_dir} \
|
||||
--output_dir {output_dir} \
|
||||
--model_name_or_path facebook/rag-sequence-base \
|
||||
--model_type rag_sequence \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--val_check_interval 1.0 \
|
||||
--train_batch_size 2 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 25 \
|
||||
--max_target_length 25 \
|
||||
--val_max_target_length 25 \
|
||||
--test_max_target_length 25 \
|
||||
--label_smoothing 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
--weight_decay 0.001 \
|
||||
--adam_epsilon 1e-08 \
|
||||
--max_grad_norm 0.1 \
|
||||
--lr_scheduler polynomial \
|
||||
--learning_rate 3e-04 \
|
||||
--num_train_epochs 1 \
|
||||
--warmup_steps 4 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--distributed-port 8787 \
|
||||
--use_dummy_dataset 1 \
|
||||
""".split()
|
||||
|
||||
if gpus > 0:
|
||||
testargs.append(f"--gpus={gpus}")
|
||||
if is_apex_available():
|
||||
testargs.append("--fp16")
|
||||
else:
|
||||
testargs.append("--gpus=0")
|
||||
testargs.append("--distributed_backend=ddp_cpu")
|
||||
testargs.append("--num_processes=2")
|
||||
|
||||
cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
metrics_save_path = os.path.join(output_dir, "metrics.json")
|
||||
with open(metrics_save_path) as f:
|
||||
result = json.load(f)
|
||||
return result
|
||||
|
||||
@require_torch_gpu
|
||||
def test_finetune_gpu(self):
|
||||
result = self._run_finetune(gpus=1)
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_finetune_multigpu(self):
|
||||
result = self._run_finetune(gpus=2)
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
116
examples/research_projects/rag/callbacks_rag.py
Normal file
116
examples/research_projects/rag/callbacks_rag.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from utils_rag import save_json
|
||||
|
||||
|
||||
def count_trainable_parameters(model):
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
return params
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_checkpoint_callback(output_dir, metric):
|
||||
"""Saves the best model by validation EM score."""
|
||||
if metric == "rouge2":
|
||||
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
||||
elif metric == "bleu":
|
||||
exp = "{val_avg_bleu:.4f}-{step_count}"
|
||||
elif metric == "em":
|
||||
exp = "{val_avg_em:.4f}-{step_count}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=os.path.join(output_dir, exp),
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
save_top_k=3,
|
||||
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
)
|
||||
return checkpoint_callback
|
||||
|
||||
|
||||
def get_early_stopping_callback(metric, patience):
|
||||
return EarlyStopping(
|
||||
monitor=f"val_{metric}", # does this need avg?
|
||||
mode="min" if "loss" in metric else "max",
|
||||
patience=patience,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
class Seq2SeqLoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
@rank_zero_only
|
||||
def _write_logs(
|
||||
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||
) -> None:
|
||||
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
|
||||
metrics = trainer.callback_metrics
|
||||
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
|
||||
# Log results
|
||||
od = Path(pl_module.hparams.output_dir)
|
||||
if type_path == "test":
|
||||
results_file = od / "test_results.txt"
|
||||
generations_file = od / "test_generations.txt"
|
||||
else:
|
||||
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
|
||||
# If people want this it will be easy enough to add back.
|
||||
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
|
||||
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
|
||||
results_file.parent.mkdir(exist_ok=True)
|
||||
generations_file.parent.mkdir(exist_ok=True)
|
||||
with open(results_file, "a+") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key in ["log", "progress_bar", "preds"]:
|
||||
continue
|
||||
val = metrics[key]
|
||||
if isinstance(val, torch.Tensor):
|
||||
val = val.item()
|
||||
msg = f"{key}: {val:.6f}\n"
|
||||
writer.write(msg)
|
||||
|
||||
if not save_generations:
|
||||
return
|
||||
|
||||
if "preds" in metrics:
|
||||
content = "\n".join(metrics["preds"])
|
||||
generations_file.open("w+").write(content)
|
||||
|
||||
@rank_zero_only
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
try:
|
||||
npars = pl_module.model.model.num_parameters()
|
||||
except AttributeError:
|
||||
npars = pl_module.model.num_parameters()
|
||||
|
||||
n_trainable_pars = count_trainable_parameters(pl_module)
|
||||
# mp stands for million parameters
|
||||
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
return self._write_logs(trainer, pl_module, "test")
|
||||
|
||||
@rank_zero_only
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
# Uncommenting this will save val generations
|
||||
# return self._write_logs(trainer, pl_module, "valid")
|
||||
99
examples/research_projects/rag/consolidate_rag_checkpoint.py
Normal file
99
examples/research_projects/rag/consolidate_rag_checkpoint.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
A script creating a RAG checkpoint from a generator and a question encoder checkpoints.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer, RagConfig, RagSequenceForGeneration, RagTokenForGeneration
|
||||
|
||||
|
||||
def consolidate(
|
||||
model_type,
|
||||
generator_name_or_path: str,
|
||||
question_encoder_name_or_path: str,
|
||||
dest_dir: Path,
|
||||
config_name_or_path: str = None,
|
||||
generator_tokenizer_name_or_path: str = None,
|
||||
question_encoder_tokenizer_name_or_path: str = None,
|
||||
):
|
||||
|
||||
if config_name_or_path is None:
|
||||
config_name_or_path = "facebook/rag-token-base" if model_type == "rag_token" else "facebook/rag-sequence-base"
|
||||
|
||||
if generator_tokenizer_name_or_path is None:
|
||||
generator_tokenizer_name_or_path = generator_name_or_path
|
||||
|
||||
if question_encoder_tokenizer_name_or_path is None:
|
||||
question_encoder_tokenizer_name_or_path = question_encoder_name_or_path
|
||||
|
||||
model_class = RagTokenForGeneration if model_type == "rag_token" else RagSequenceForGeneration
|
||||
|
||||
# Save model.
|
||||
rag_config = RagConfig.from_pretrained(config_name_or_path)
|
||||
gen_config = AutoConfig.from_pretrained(generator_name_or_path)
|
||||
question_encoder_config = AutoConfig.from_pretrained(question_encoder_name_or_path)
|
||||
|
||||
rag_config.generator = gen_config
|
||||
rag_config.question_encoder = question_encoder_config
|
||||
|
||||
rag_model = model_class.from_pretrained_question_encoder_generator(
|
||||
question_encoder_name_or_path, generator_name_or_path, config=rag_config
|
||||
)
|
||||
rag_model.save_pretrained(dest_dir)
|
||||
|
||||
# Sanity check.
|
||||
model_class.from_pretrained(dest_dir)
|
||||
|
||||
# Save tokenizers.
|
||||
gen_tokenizer = AutoTokenizer.from_pretrained(generator_tokenizer_name_or_path)
|
||||
gen_tokenizer.save_pretrained(dest_dir / "generator_tokenizer/")
|
||||
question_encoder_tokenizer = AutoTokenizer.from_pretrained(question_encoder_tokenizer_name_or_path)
|
||||
question_encoder_tokenizer.save_pretrained(dest_dir / "question_encoder_tokenizer/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token"],
|
||||
required=True,
|
||||
type=str,
|
||||
help="RAG model type: rag_sequence, rag_token",
|
||||
)
|
||||
parser.add_argument("--dest", type=str, required=True, help="Path to the output checkpoint directory.")
|
||||
parser.add_argument("--generator_name_or_path", type=str, required=True, help="Generator model identifier")
|
||||
parser.add_argument(
|
||||
"--question_encoder_name_or_path", type=str, required=True, help="Question encoder model identifier"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--generator_tokenizer_name_or_path",
|
||||
type=str,
|
||||
help="Generator tokenizer identifier, if not specified, resolves to ``generator_name_or_path``",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--question_encoder_tokenizer_name_or_path",
|
||||
type=str,
|
||||
help="Question encoder tokenizer identifier, if not specified, resolves to ``question_encoder_name_or_path``",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name_or_path",
|
||||
type=str,
|
||||
help="Identifier of the model config to use, if not provided, resolves to a base config for a given ``model_type``",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dest_dir = Path(args.dest)
|
||||
dest_dir.mkdir(exist_ok=True)
|
||||
|
||||
consolidate(
|
||||
args.model_type,
|
||||
args.generator_name_or_path,
|
||||
args.question_encoder_name_or_path,
|
||||
dest_dir,
|
||||
args.config_name_or_path,
|
||||
args.generator_tokenizer_name_or_path,
|
||||
args.question_encoder_tokenizer_name_or_path,
|
||||
)
|
||||
139
examples/research_projects/rag/distributed_retriever.py
Normal file
139
examples/research_projects/rag/distributed_retriever.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from transformers import RagRetriever
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
"""
|
||||
A distributed retriever built on top of the ``torch.distributed`` communication package. During training all workers
|
||||
initialize their own instance of the retriever, however, only the main worker loads the index into memory. The index is stored
|
||||
in cpu memory. The index will also work well in a non-distributed setup.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.RagConfig`):
|
||||
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
|
||||
question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer that was used to tokenize the question.
|
||||
It is used to decode the question and then use the generator_tokenizer.
|
||||
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
index (:class:`~transformers.models.rag.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
|
||||
If specified, use this index instead of the one built using the configuration
|
||||
"""
|
||||
|
||||
_init_retrieval = False
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
||||
super().__init__(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
)
|
||||
self.process_group = None
|
||||
|
||||
def init_retrieval(self, distributed_port: int):
|
||||
"""
|
||||
Retriever initialization function, needs to be called from the training process. The function sets some common parameters
|
||||
and environment variables. On top of that, (only) the main process in the process group loads the index into memory.
|
||||
|
||||
Args:
|
||||
distributed_port (:obj:`int`):
|
||||
The port on which the main communication of the training run is carried out. We set the port for retrieval-related
|
||||
communication as ``distributed_port + 1``.
|
||||
"""
|
||||
|
||||
logger.info("initializing retrieval")
|
||||
|
||||
# initializing a separate process group for retrieval as the default
|
||||
# nccl backend doesn't support gather/scatter operations while gloo
|
||||
# is too slow to replace nccl for the core gpu communication
|
||||
if dist.is_initialized():
|
||||
logger.info("dist initialized")
|
||||
# needs to be set manually
|
||||
os.environ["GLOO_SOCKET_IFNAME"] = self._infer_socket_ifname()
|
||||
# avoid clash with the NCCL port
|
||||
os.environ["MASTER_PORT"] = str(distributed_port + 1)
|
||||
self.process_group = dist.new_group(ranks=None, backend="gloo")
|
||||
|
||||
# initialize retriever only on the main worker
|
||||
if not dist.is_initialized() or self._is_main():
|
||||
logger.info("dist not initialized / main")
|
||||
self.index.init_index()
|
||||
|
||||
# all processes wait untill the retriever is initialized by the main process
|
||||
if dist.is_initialized():
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
def _is_main(self):
|
||||
return dist.get_rank(group=self.process_group) == 0
|
||||
|
||||
def _scattered(self, scatter_list, target_shape, target_type=torch.float32):
|
||||
target_tensor = torch.empty(target_shape, dtype=target_type)
|
||||
dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group)
|
||||
return target_tensor
|
||||
|
||||
def _infer_socket_ifname(self):
|
||||
addrs = psutil.net_if_addrs()
|
||||
# a hacky way to deal with varying network interface names
|
||||
ifname = next((addr for addr in addrs if addr.startswith("e")), None)
|
||||
return ifname
|
||||
|
||||
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:
|
||||
"""
|
||||
Retrieves documents for specified ``question_hidden_states``. The main process, which has the access to the index stored in memory, gathers queries
|
||||
from all the processes in the main training process group, performs the retrieval and scatters back the results.
|
||||
|
||||
Args:
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
|
||||
A batch of query vectors to retrieve with.
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Output:
|
||||
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
|
||||
The retrieval embeddings of the retrieved docs per query.
|
||||
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
|
||||
The ids of the documents in the index
|
||||
doc_dicts (:obj:`List[dict]`):
|
||||
The retrieved_doc_embeds examples per query.
|
||||
"""
|
||||
|
||||
# single GPU training
|
||||
if not dist.is_initialized():
|
||||
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
|
||||
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
|
||||
|
||||
# distributed training
|
||||
world_size = dist.get_world_size(group=self.process_group)
|
||||
|
||||
# gather logic
|
||||
gather_list = None
|
||||
if self._is_main():
|
||||
gather_list = [torch.empty(question_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
|
||||
dist.gather(torch.tensor(question_hidden_states), dst=0, gather_list=gather_list, group=self.process_group)
|
||||
|
||||
# scatter logic
|
||||
n_queries = question_hidden_states.shape[0]
|
||||
scatter_ids = []
|
||||
scatter_vectors = []
|
||||
if self._is_main():
|
||||
assert len(gather_list) == world_size
|
||||
ids, vectors = self._main_retrieve(torch.cat(gather_list).numpy(), n_docs)
|
||||
ids, vectors = torch.tensor(ids), torch.tensor(vectors)
|
||||
scatter_ids = self._chunk_tensor(ids, n_queries)
|
||||
scatter_vectors = self._chunk_tensor(vectors, n_queries)
|
||||
doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64)
|
||||
retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]])
|
||||
|
||||
return retrieved_doc_embeds.numpy(), doc_ids.numpy(), self.index.get_doc_dicts(doc_ids)
|
||||
314
examples/research_projects/rag/eval_rag.py
Normal file
314
examples/research_projects/rag/eval_rag.py
Normal file
@@ -0,0 +1,314 @@
|
||||
""" Evaluation script for RAG models."""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
|
||||
from utils_rag import exact_match_score, f1_score # noqa: E402 # isort:skip
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
def infer_model_type(model_name_or_path):
|
||||
if "token" in model_name_or_path:
|
||||
return "rag_token"
|
||||
if "sequence" in model_name_or_path:
|
||||
return "rag_sequence"
|
||||
if "bart" in model_name_or_path:
|
||||
return "bart"
|
||||
return None
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
return max(metric_fn(prediction, gt) for gt in ground_truths)
|
||||
|
||||
|
||||
def get_scores(args, preds_path, gold_data_path):
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
answers = []
|
||||
|
||||
if args.gold_data_mode == "qa":
|
||||
data = pd.read_csv(gold_data_path, sep="\t", header=None)
|
||||
for answer_list in data[1]:
|
||||
ground_truths = ast.literal_eval(answer_list)
|
||||
answers.append(ground_truths)
|
||||
else:
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
answers = [[reference] for reference in references]
|
||||
|
||||
f1 = em = total = 0
|
||||
for prediction, ground_truths in zip(hypos, answers):
|
||||
total += 1
|
||||
em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
|
||||
|
||||
em = 100.0 * em / total
|
||||
f1 = 100.0 * f1 / total
|
||||
|
||||
logger.info(f"F1: {f1:.2f}")
|
||||
logger.info(f"EM: {em:.2f}")
|
||||
|
||||
|
||||
def get_precision_at_k(args, preds_path, gold_data_path):
|
||||
k = args.k
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
|
||||
em = total = 0
|
||||
for hypo, reference in zip(hypos, references):
|
||||
hypo_provenance = set(hypo.split("\t")[:k])
|
||||
ref_provenance = set(reference.split("\t"))
|
||||
total += 1
|
||||
em += len(hypo_provenance & ref_provenance) / k
|
||||
|
||||
em = 100.0 * em / total
|
||||
logger.info(f"Precision@{k}: {em: .2f}")
|
||||
|
||||
|
||||
def evaluate_batch_retrieval(args, rag_model, questions):
|
||||
def strip_title(title):
|
||||
if title.startswith('"'):
|
||||
title = title[1:]
|
||||
if title.endswith('"'):
|
||||
title = title[:-1]
|
||||
return title
|
||||
|
||||
retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)["input_ids"].to(args.device)
|
||||
|
||||
question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids)
|
||||
question_enc_pool_output = question_enc_outputs.pooler_output
|
||||
|
||||
result = rag_model.retriever(
|
||||
retriever_input_ids,
|
||||
question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
|
||||
prefix=rag_model.rag.generator.config.prefix,
|
||||
n_docs=rag_model.config.n_docs,
|
||||
return_tensors="pt",
|
||||
)
|
||||
all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
|
||||
provenance_strings = []
|
||||
for docs in all_docs:
|
||||
provenance = [strip_title(title) for title in docs["title"]]
|
||||
provenance_strings.append("\t".join(provenance))
|
||||
return provenance_strings
|
||||
|
||||
|
||||
def evaluate_batch_e2e(args, rag_model, questions):
|
||||
with torch.no_grad():
|
||||
inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions, return_tensors="pt", padding=True, truncation=True
|
||||
)
|
||||
|
||||
input_ids = inputs_dict.input_ids.to(args.device)
|
||||
attention_mask = inputs_dict.attention_mask.to(args.device)
|
||||
outputs = rag_model.generate( # rag_model overwrites generate
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=args.num_beams,
|
||||
min_length=args.min_length,
|
||||
max_length=args.max_length,
|
||||
early_stopping=False,
|
||||
num_return_sequences=1,
|
||||
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, dont allow it to generate more than one
|
||||
clean_up_tokenization=True,
|
||||
print_docs=args.print_docs,
|
||||
)
|
||||
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
if args.print_predictions:
|
||||
for q, a in zip(questions, answers):
|
||||
logger.info("Q: {} - A: {}".format(q, a))
|
||||
|
||||
return answers
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart"],
|
||||
type=str,
|
||||
help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
default=None,
|
||||
choices=["exact", "compressed", "legacy"],
|
||||
type=str,
|
||||
help="RAG model retriever type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the retrieval index",
|
||||
)
|
||||
parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained checkpoints or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_mode",
|
||||
choices=["e2e", "retrieval"],
|
||||
default="e2e",
|
||||
type=str,
|
||||
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
|
||||
)
|
||||
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
|
||||
parser.add_argument(
|
||||
"--evaluation_set",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a file containing evaluation samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a tab-separated file with gold samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_mode",
|
||||
default="qa",
|
||||
type=str,
|
||||
choices=["qa", "ans"],
|
||||
help="Format of the gold data file"
|
||||
"qa - a single line in the following format: question [tab] answer_list"
|
||||
"ans - a single line of the gold file contains the expected answer string",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predictions_path",
|
||||
type=str,
|
||||
default="predictions.txt",
|
||||
help="Name of the predictions file, to be stored in the checkpoints directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recalculate",
|
||||
help="Recalculate predictions even if the prediction file exists",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Number of beams to be used when generating answers",
|
||||
)
|
||||
parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers")
|
||||
parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers")
|
||||
|
||||
parser.add_argument(
|
||||
"--print_predictions",
|
||||
action="store_true",
|
||||
help="If True, prints predictions while evaluating.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_docs",
|
||||
action="store_true",
|
||||
help="If True, prints docs retried while generating.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
model_kwargs = {}
|
||||
if args.model_type is None:
|
||||
args.model_type = infer_model_type(args.model_name_or_path)
|
||||
assert args.model_type is not None
|
||||
if args.model_type.startswith("rag"):
|
||||
model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
|
||||
model_kwargs["n_docs"] = args.n_docs
|
||||
if args.index_name is not None:
|
||||
model_kwargs["index_name"] = args.index_name
|
||||
if args.index_path is not None:
|
||||
model_kwargs["index_path"] = args.index_path
|
||||
else:
|
||||
model_class = BartForConditionalGeneration
|
||||
|
||||
checkpoints = (
|
||||
[f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
|
||||
if args.eval_all_checkpoints
|
||||
else [args.model_name_or_path]
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
|
||||
evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
if os.path.exists(args.predictions_path) and (not args.recalculate):
|
||||
logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
continue
|
||||
|
||||
logger.info("***** Running evaluation for {} *****".format(checkpoint))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
logger.info(" Predictions will be stored under {}".format(args.predictions_path))
|
||||
|
||||
if args.model_type.startswith("rag"):
|
||||
retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
|
||||
model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
|
||||
model.retriever.init_retrieval()
|
||||
else:
|
||||
model = model_class.from_pretrained(checkpoint, **model_kwargs)
|
||||
model.to(args.device)
|
||||
|
||||
with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
|
||||
questions = []
|
||||
for line in tqdm(eval_file):
|
||||
questions.append(line.strip())
|
||||
if len(questions) == args.eval_batch_size:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers) + "\n")
|
||||
preds_file.flush()
|
||||
questions = []
|
||||
if len(questions) > 0:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers))
|
||||
preds_file.flush()
|
||||
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
512
examples/research_projects/rag/finetune_rag.py
Normal file
512
examples/research_projects/rag/finetune_rag.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
|
||||
from pytorch_lightning.cluster_environments import TorchElasticEnvironment
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
BartForConditionalGeneration,
|
||||
BatchEncoding,
|
||||
RagConfig,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenForGeneration,
|
||||
RagTokenizer,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||
get_checkpoint_callback,
|
||||
get_early_stopping_callback,
|
||||
Seq2SeqLoggingCallback,
|
||||
)
|
||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
from utils_rag import ( # noqa: E402 # isort:skip
|
||||
calculate_exact_match,
|
||||
flatten_list,
|
||||
get_git_info,
|
||||
is_rag_model,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
set_extra_model_params,
|
||||
Seq2SeqDataset,
|
||||
)
|
||||
|
||||
# need the parent dir module
|
||||
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
|
||||
# is no longer used, and is moved into DDPAccelerator instead.
|
||||
# We override DDPAccelerator to add our custom logic for initializing the
|
||||
# retriever.
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
|
||||
|
||||
|
||||
class CustomAccel(DDPAccelerator):
|
||||
def __init__(self, trainer=None, **kwargs):
|
||||
# Trainer is set later.
|
||||
super().__init__(trainer, **kwargs)
|
||||
|
||||
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
|
||||
logger.info("Custom init_ddp_connection.")
|
||||
module = self.trainer.model
|
||||
if self.cluster_environment is None:
|
||||
self.cluster_environment = TorchElasticEnvironment()
|
||||
self.distributed_port = module.hparams.distributed_port
|
||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
||||
if module.is_rag_model:
|
||||
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||
|
||||
|
||||
class GenerativeQAModule(BaseTransformer):
|
||||
mode = "generative_qa"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["em"]
|
||||
val_metric = "em"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
# when loading from a pytorch lightning checkpoint, hparams are passed as dict
|
||||
if isinstance(hparams, dict):
|
||||
hparams = AttrDict(hparams)
|
||||
if hparams.model_type == "rag_sequence":
|
||||
self.model_class = RagSequenceForGeneration
|
||||
elif hparams.model_type == "rag_token":
|
||||
self.model_class = RagTokenForGeneration
|
||||
elif hparams.model_type == "bart":
|
||||
self.model_class = BartForConditionalGeneration
|
||||
else:
|
||||
self.model_class = T5ForConditionalGeneration
|
||||
self.is_rag_model = is_rag_model(hparams.model_type)
|
||||
|
||||
config_class = RagConfig if self.is_rag_model else AutoConfig
|
||||
config = config_class.from_pretrained(hparams.model_name_or_path)
|
||||
|
||||
# set retriever parameters
|
||||
config.index_name = hparams.index_name or config.index_name
|
||||
config.passages_path = hparams.passages_path or config.passages_path
|
||||
config.index_path = hparams.index_path or config.index_path
|
||||
config.use_dummy_dataset = hparams.use_dummy_dataset
|
||||
|
||||
# set extra_model_params for generator configs and load_model
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
|
||||
if self.is_rag_model:
|
||||
if hparams.prefix is not None:
|
||||
config.generator.prefix = hparams.prefix
|
||||
config.label_smoothing = hparams.label_smoothing
|
||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||
prefix = config.question_encoder.prefix
|
||||
else:
|
||||
if hparams.prefix is not None:
|
||||
config.prefix = hparams.prefix
|
||||
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
prefix = config.prefix
|
||||
|
||||
tokenizer = (
|
||||
RagTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
if self.is_rag_model
|
||||
else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
)
|
||||
|
||||
super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)
|
||||
|
||||
save_git_info(self.hparams.output_dir)
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
|
||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||
pickle_save(self.hparams, self.hparams_save_path)
|
||||
self.step_count = 0
|
||||
self.metrics = defaultdict(list)
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
max_source_length=self.hparams.max_source_length,
|
||||
prefix=prefix or "",
|
||||
)
|
||||
n_observations_per_split = {
|
||||
"train": self.hparams.n_train,
|
||||
"val": self.hparams.n_val,
|
||||
"test": self.hparams.n_test,
|
||||
}
|
||||
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
|
||||
|
||||
self.target_lens = {
|
||||
"train": self.hparams.max_target_length,
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.distributed_port = self.hparams.distributed_port
|
||||
|
||||
# For single GPU training, init_ddp_connection is not called.
|
||||
# So we need to initialize the retrievers here.
|
||||
if hparams.gpus <= 1:
|
||||
self.model.retriever.init_retrieval(self.distributed_port)
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
|
||||
def ids_to_clean_text(self, generated_ids: List[int]):
|
||||
gen_text = self.tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
return lmap(str.strip, gen_text)
|
||||
|
||||
def _step(self, batch: dict) -> Tuple:
|
||||
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
|
||||
rag_kwargs = {}
|
||||
if isinstance(self.model, T5ForConditionalGeneration):
|
||||
decoder_input_ids = self.model._shift_right(target_ids)
|
||||
lm_labels = target_ids
|
||||
elif isinstance(self.model, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous()
|
||||
lm_labels = target_ids[:, 1:].clone()
|
||||
else:
|
||||
assert self.is_rag_model
|
||||
generator = self.model.rag.generator
|
||||
if isinstance(generator, T5ForConditionalGeneration):
|
||||
decoder_start_token_id = generator.config.decoder_start_token_id
|
||||
decoder_input_ids = (
|
||||
torch.cat(
|
||||
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
|
||||
dim=1,
|
||||
)
|
||||
if target_ids.shape[0] < self.target_lens["train"]
|
||||
else generator._shift_right(target_ids)
|
||||
)
|
||||
elif isinstance(generator, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids
|
||||
lm_labels = decoder_input_ids
|
||||
rag_kwargs["reduce_loss"] = True
|
||||
|
||||
assert decoder_input_ids is not None
|
||||
|
||||
outputs = self(
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
use_cache=False,
|
||||
labels=lm_labels,
|
||||
**rag_kwargs,
|
||||
)
|
||||
|
||||
loss = outputs["loss"]
|
||||
return (loss,)
|
||||
|
||||
@property
|
||||
def pad(self) -> int:
|
||||
raise NotImplementedError("pad not implemented")
|
||||
|
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
loss_tensors = self._step(batch)
|
||||
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
# tokens per batch
|
||||
tgt_pad_token_id = (
|
||||
self.tokenizer.generator.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
src_pad_token_id = (
|
||||
self.tokenizer.question_encoder.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
logs["tpb"] = (
|
||||
batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
|
||||
)
|
||||
|
||||
return {"loss": loss_tensors[0], "log": logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
return self._generative_step(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
gen_metrics = {
|
||||
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
|
||||
}
|
||||
metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
|
||||
gen_metrics.update({k: v.item() for k, v in losses.items()})
|
||||
|
||||
# fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
|
||||
metrics_tensor = metrics_tensor / dist.get_world_size()
|
||||
gen_metrics.update({self.val_metric: metrics_tensor.item()})
|
||||
|
||||
losses.update(gen_metrics)
|
||||
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
metrics["step_count"] = self.step_count
|
||||
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metrics_tensor}
|
||||
|
||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||
self.metrics[type_path].append(latest_metrics)
|
||||
save_json(self.metrics, self.metrics_save_path)
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_exact_match(preds, target)
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
start_time = time.time()
|
||||
batch = BatchEncoding(batch).to(device=self.model.device)
|
||||
generated_ids = self.model.generate(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
do_deduplication=False, # rag specific parameter
|
||||
use_cache=True,
|
||||
min_length=1,
|
||||
max_length=self.target_lens["val"],
|
||||
)
|
||||
|
||||
gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
|
||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
gen_metrics: Dict = self.calc_generative_metrics(preds, target)
|
||||
|
||||
summ_len = np.mean(lmap(len, generated_ids))
|
||||
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
|
||||
return base_metrics
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_epoch_end(outputs, prefix="test")
|
||||
|
||||
def get_dataset(self, type_path) -> Seq2SeqDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = Seq2SeqDataset(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
max_target_length=max_target_length,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=shuffle,
|
||||
num_workers=self.num_workers,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
return dataloader
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||
add_generic_args(parser, root_dir)
|
||||
parser.add_argument(
|
||||
"--max_source_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prefix added at the beginning of each text, typically used with T5-based models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early_stopping_patience",
|
||||
type=int,
|
||||
default=-1,
|
||||
required=False,
|
||||
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart", "t5"],
|
||||
type=str,
|
||||
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def add_retriever_specific_args(parser):
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--passages_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dummy_dataset",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(args=None, model=None) -> GenerativeQAModule:
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||
|
||||
args = args or parser.parse_args()
|
||||
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if model is None:
|
||||
model: GenerativeQAModule = GenerativeQAModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger_name == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
|
||||
es_callback = (
|
||||
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||
if args.early_stopping_patience >= 0
|
||||
else False
|
||||
)
|
||||
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
accelerator=CustomAccel() if args.gpus > 1 else None,
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
|
||||
if not args.do_predict:
|
||||
return model
|
||||
|
||||
# test() without a model tests using the best checkpoint automatically
|
||||
trainer.test()
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
34
examples/research_projects/rag/finetune_rag.sh
Executable file
34
examples/research_projects/rag/finetune_rag.sh
Executable file
@@ -0,0 +1,34 @@
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./examples/rag/finetune.sh --help to see all the possible options
|
||||
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--val_check_interval 0.25 \
|
||||
--train_batch_size 8 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 128 \
|
||||
--max_target_length 25 \
|
||||
--val_max_target_length 25 \
|
||||
--test_max_target_length 25 \
|
||||
--label_smoothing 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
--weight_decay 0.001 \
|
||||
--adam_epsilon 1e-08 \
|
||||
--max_grad_norm 0.1 \
|
||||
--lr_scheduler polynomial \
|
||||
--learning_rate 3e-05 \
|
||||
--num_train_epochs 100 \
|
||||
--warmup_steps 500 \
|
||||
--gradient_accumulation_steps 1
|
||||
391
examples/research_projects/rag/lightning_base.py
Normal file
391
examples/research_projects/rag/lightning_base.py
Normal file
@@ -0,0 +1,391 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
from transformers import (
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForPreTraining,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
from transformers.optimization import (
|
||||
Adafactor,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
)
|
||||
from transformers.utils.versions import require_version_examples
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
require_version_examples("pytorch_lightning>=1.0.4")
|
||||
|
||||
MODEL_MODES = {
|
||||
"base": AutoModel,
|
||||
"sequence-classification": AutoModelForSequenceClassification,
|
||||
"question-answering": AutoModelForQuestionAnswering,
|
||||
"pretraining": AutoModelForPreTraining,
|
||||
"token-classification": AutoModelForTokenClassification,
|
||||
"language-modeling": AutoModelWithLMHead,
|
||||
"summarization": AutoModelForSeq2SeqLM,
|
||||
"translation": AutoModelForSeq2SeqLM,
|
||||
}
|
||||
|
||||
|
||||
# update this and the import above to support new schedulers from transformers.optimization
|
||||
arg_to_scheduler = {
|
||||
"linear": get_linear_schedule_with_warmup,
|
||||
"cosine": get_cosine_schedule_with_warmup,
|
||||
"cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
"polynomial": get_polynomial_decay_schedule_with_warmup,
|
||||
# '': get_constant_schedule, # not supported for now
|
||||
# '': get_constant_schedule_with_warmup, # not supported for now
|
||||
}
|
||||
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
|
||||
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
|
||||
|
||||
|
||||
class BaseTransformer(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
hparams: argparse.Namespace,
|
||||
num_labels=None,
|
||||
mode="base",
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
# TODO: move to self.save_hyperparameters()
|
||||
# self.save_hyperparameters()
|
||||
# can also expand arguments into trainer signature for easier reading
|
||||
|
||||
self.save_hyperparameters(hparams)
|
||||
self.step_count = 0
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||
if config is None:
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||
cache_dir=cache_dir,
|
||||
**config_kwargs,
|
||||
)
|
||||
else:
|
||||
self.config: PretrainedConfig = config
|
||||
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
if getattr(self.hparams, p, None):
|
||||
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
|
||||
setattr(self.config, p, getattr(self.hparams, p))
|
||||
|
||||
if tokenizer is None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
else:
|
||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||
self.model_type = MODEL_MODES[mode]
|
||||
if model is None:
|
||||
self.model = self.model_type.from_pretrained(
|
||||
self.hparams.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||
config=self.config,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
else:
|
||||
self.model = model
|
||||
|
||||
def load_hf_checkpoint(self, *args, **kwargs):
|
||||
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||
|
||||
def get_lr_scheduler(self):
|
||||
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
|
||||
scheduler = get_schedule_func(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps()
|
||||
)
|
||||
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
|
||||
return scheduler
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Prepare optimizer and schedule (linear warmup and decay)"""
|
||||
model = self.model
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": self.hparams.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
if self.hparams.adafactor:
|
||||
optimizer = Adafactor(
|
||||
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
|
||||
)
|
||||
|
||||
else:
|
||||
optimizer = AdamW(
|
||||
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
|
||||
)
|
||||
self.opt = optimizer
|
||||
|
||||
scheduler = self.get_lr_scheduler()
|
||||
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def test_step(self, batch, batch_nb):
|
||||
return self.validation_step(batch, batch_nb)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_end(outputs)
|
||||
|
||||
def total_steps(self) -> int:
|
||||
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
|
||||
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
|
||||
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
|
||||
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
|
||||
|
||||
def setup(self, mode):
|
||||
if mode == "test":
|
||||
self.dataset_size = len(self.test_dataloader().dataset)
|
||||
else:
|
||||
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
||||
self.dataset_size = len(self.train_dataloader().dataset)
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
|
||||
raise NotImplementedError("You must implement this for your task")
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.train_loader
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)
|
||||
|
||||
def _feature_file(self, mode):
|
||||
return os.path.join(
|
||||
self.hparams.data_dir,
|
||||
"cached_{}_{}_{}".format(
|
||||
mode,
|
||||
list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
|
||||
str(self.hparams.max_seq_length),
|
||||
),
|
||||
)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("best_tfmr")
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_layerdrop",
|
||||
type=float,
|
||||
help="Encoder layer dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_layerdrop",
|
||||
type=float,
|
||||
help="Decoder layer dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dropout",
|
||||
type=float,
|
||||
help="Dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention_dropout",
|
||||
type=float,
|
||||
help="Attention dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
default="linear",
|
||||
choices=arg_to_scheduler_choices,
|
||||
metavar=arg_to_scheduler_metavar,
|
||||
type=str,
|
||||
help="Learning rate scheduler",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
|
||||
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
|
||||
parser.add_argument("--train_batch_size", default=32, type=int)
|
||||
parser.add_argument("--eval_batch_size", default=32, type=int)
|
||||
parser.add_argument("--adafactor", action="store_true")
|
||||
|
||||
|
||||
class LoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
|
||||
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
rank_zero_info("***** Validation results *****")
|
||||
metrics = trainer.callback_metrics
|
||||
# Log results
|
||||
for key in sorted(metrics):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
rank_zero_info("***** Test results *****")
|
||||
metrics = trainer.callback_metrics
|
||||
# Log and save results to file
|
||||
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
|
||||
def add_generic_args(parser, root_dir) -> None:
|
||||
# To allow all pl args uncomment the following line
|
||||
# parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O2",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
|
||||
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
dest="accumulate_grad_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||
)
|
||||
|
||||
|
||||
def generic_train(
|
||||
model: BaseTransformer,
|
||||
args: argparse.Namespace,
|
||||
early_stopping_callback=None,
|
||||
logger=True, # can pass WandbLogger() here
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
# init model
|
||||
odir = Path(model.hparams.output_dir)
|
||||
odir.mkdir(exist_ok=True)
|
||||
|
||||
# add custom checkpoints
|
||||
if checkpoint_callback is None:
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
||||
)
|
||||
if early_stopping_callback:
|
||||
extra_callbacks.append(early_stopping_callback)
|
||||
if logging_callback is None:
|
||||
logging_callback = LoggingCallback()
|
||||
|
||||
train_params = {}
|
||||
|
||||
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
||||
if args.fp16:
|
||||
train_params["precision"] = 16
|
||||
train_params["amp_level"] = args.fp16_opt_level
|
||||
|
||||
if args.gpus > 1:
|
||||
train_params["distributed_backend"] = "ddp"
|
||||
|
||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
|
||||
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
weights_summary=None,
|
||||
callbacks=[logging_callback] + extra_callbacks,
|
||||
logger=logger,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
**train_params,
|
||||
)
|
||||
|
||||
if args.do_train:
|
||||
trainer.fit(model)
|
||||
|
||||
return trainer
|
||||
47
examples/research_projects/rag/parse_dpr_relevance_data.py
Normal file
47
examples/research_projects/rag/parse_dpr_relevance_data.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
This script reads DPR retriever training data and parses each datapoint. We save a line per datapoint.
|
||||
Each line consists of the query followed by a tab-separated list of Wikipedia page titles constituting
|
||||
positive contexts for a given query.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--src_path",
|
||||
type=str,
|
||||
default="biencoder-nq-dev.json",
|
||||
help="Path to raw DPR training data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evaluation_set",
|
||||
type=str,
|
||||
help="where to store parsed evaluation_set file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_path",
|
||||
type=str,
|
||||
help="where to store parsed gold_data_path file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.src_path, "r") as src_file, open(args.evaluation_set, "w") as eval_file, open(
|
||||
args.gold_data_path, "w"
|
||||
) as gold_file:
|
||||
dpr_records = json.load(src_file)
|
||||
for dpr_record in tqdm(dpr_records):
|
||||
question = dpr_record["question"]
|
||||
contexts = [context["title"] for context in dpr_record["positive_ctxs"]]
|
||||
eval_file.write(question + "\n")
|
||||
gold_file.write("\t".join(contexts) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
examples/research_projects/rag/requirements.txt
Normal file
6
examples/research_projects/rag/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
faiss-cpu >= 1.6.3
|
||||
datasets >= 1.0.1
|
||||
psutil >= 5.7.0
|
||||
torch >= 1.4.0
|
||||
transformers
|
||||
pytorch-lightning==1.0.4
|
||||
@@ -0,0 +1,2 @@
|
||||
Aaron Aaron Aaron ( or ; "Ahärôn") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman ("prophet") to the Pharaoh. Part of the Law (Torah) that Moses received from God at Sinai granted Aaron the priesthood for himself and his male descendants, and he became the first High Priest of the Israelites. Aaron died before the Israelites crossed the North Jordan river and he was buried on Mount Hor (Numbers 33:39; Deuteronomy 10:6 says he died and was buried at Moserah). Aaron is also mentioned in the New Testament of the Bible. According to the Book of Exodus, Aaron first functioned as Moses' assistant. Because Moses complained that he could not speak well, God appointed Aaron as Moses' "prophet" (Exodus 4:10-17; 7:1). At the command of Moses, he let his rod turn into a snake. Then he stretched out his rod in order to bring on the first three plagues. After that, Moses tended to act and speak for himself. During the journey in the wilderness, Aaron was not always prominent or active. At the battle with Amalek, he was chosen with Hur to support the hand of Moses that held the "rod of God". When the revelation was given to Moses at biblical Mount Sinai, he headed the elders of Israel who accompanied Moses on the way to the summit.
|
||||
"Pokémon" Pokémon , also known as in Japan, is a media franchise managed by The Pokémon Company, a Japanese consortium between Nintendo, Game Freak, and Creatures. The franchise copyright is shared by all three companies, but Nintendo is the sole owner of the trademark. The franchise was created by Satoshi Tajiri in 1995, and is centered on fictional creatures called "Pokémon", which humans, known as Pokémon Trainers, catch and train to battle each other for sport. The English slogan for the franchise is "Gotta Catch 'Em All". Works within the franchise are set in the Pokémon universe. The franchise began as "Pokémon Red" and "Green" (released outside of Japan as "Pokémon Red" and "Blue"), a pair of video games for the original Game Boy that were developed by Game Freak and published by Nintendo in February 1996. "Pokémon" has since gone on to become the highest-grossing media franchise of all time, with over in revenue up until March 2017. The original video game series is the second best-selling video game franchise (behind Nintendo's "Mario" franchise) with more than 300million copies sold and over 800million mobile downloads. In addition, the "Pokémon" franchise includes the world's top-selling toy brand, the top-selling trading card game with over 25.7billion cards sold, an anime television series that has become the most successful video game adaptation with over 20 seasons and 1,000 episodes in 124 countries, as well as an anime film series, a , books, manga comics, music, and merchandise. The franchise is also represented in other Nintendo media, such as the "Super Smash Bros." series. In November 2005, 4Kids Entertainment, which had managed the non-game related licensing of "Pokémon", announced that it had agreed not to renew the "Pokémon" representation agreement. The Pokémon Company International oversees all "Pokémon" licensing outside Asia.
|
||||
|
Can't render this file because it contains an unexpected character in line 1 and column 35.
|
224
examples/research_projects/rag/test_distributed_retriever.py
Normal file
224
examples/research_projects/rag/test_distributed_retriever.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
import faiss
|
||||
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
|
||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
|
||||
|
||||
def require_distributed_retrieval(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a set of dependencies necessary for pefrorm retrieval with
|
||||
:class:`~transformers.RagRetriever`.
|
||||
|
||||
These tests are skipped when respective libraries are not installed.
|
||||
|
||||
"""
|
||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
@require_distributed_retrieval
|
||||
class RagRetrieverTest(TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.retrieval_vector_size = 8
|
||||
|
||||
# DPR tok
|
||||
vocab_tokens = [
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[PAD]",
|
||||
"[MASK]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
"wa",
|
||||
"un",
|
||||
"runn",
|
||||
"##ing",
|
||||
",",
|
||||
"low",
|
||||
"lowest",
|
||||
]
|
||||
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
|
||||
os.makedirs(dpr_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
# BART tok
|
||||
vocab = [
|
||||
"l",
|
||||
"o",
|
||||
"w",
|
||||
"e",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"i",
|
||||
"d",
|
||||
"n",
|
||||
"\u0120",
|
||||
"\u0120l",
|
||||
"\u0120n",
|
||||
"\u0120lo",
|
||||
"\u0120low",
|
||||
"er",
|
||||
"\u0120lowest",
|
||||
"\u0120newer",
|
||||
"\u0120wider",
|
||||
"<unk>",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||
|
||||
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
|
||||
os.makedirs(bart_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
|
||||
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
|
||||
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
|
||||
|
||||
def get_bart_tokenizer(self) -> BartTokenizer:
|
||||
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_dummy_dataset(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
"text": ["foo", "bar"],
|
||||
"title": ["Foo", "Bar"],
|
||||
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
return dataset
|
||||
|
||||
def get_dummy_pytorch_distributed_retriever(
|
||||
self, init_retrieval: bool, port=12345
|
||||
) -> RagPyTorchDistributedRetriever:
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
)
|
||||
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = dataset
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="custom",
|
||||
)
|
||||
if from_disk:
|
||||
config.passages_path = os.path.join(self.tmpdirname, "dataset")
|
||||
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
|
||||
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
|
||||
dataset.drop_index("embeddings")
|
||||
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
|
||||
del dataset
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
)
|
||||
else:
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
index=CustomHFIndex(config.retrieval_vector_size, dataset),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_hf_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
204
examples/research_projects/rag/use_own_knowledge_dataset.py
Normal file
204
examples/research_projects/rag/use_own_knowledge_dataset.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Features, Sequence, Value, load_dataset
|
||||
|
||||
import faiss
|
||||
from transformers import (
|
||||
DPRContextEncoder,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
HfArgumentParser,
|
||||
RagRetriever,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenizer,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
torch.set_grad_enabled(False)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def split_text(text: str, n=100, character=" ") -> List[str]:
|
||||
"""Split the text every ``n``-th occurrence of ``character``"""
|
||||
text = text.split(character)
|
||||
return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
|
||||
|
||||
|
||||
def split_documents(documents: dict) -> dict:
|
||||
"""Split documents into passages"""
|
||||
titles, texts = [], []
|
||||
for title, text in zip(documents["title"], documents["text"]):
|
||||
if text is not None:
|
||||
for passage in split_text(text):
|
||||
titles.append(title if title is not None else "")
|
||||
texts.append(passage)
|
||||
return {"title": titles, "text": texts}
|
||||
|
||||
|
||||
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
|
||||
"""Compute the DPR embeddings of document passages"""
|
||||
input_ids = ctx_tokenizer(
|
||||
documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
|
||||
)["input_ids"]
|
||||
embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
|
||||
return {"embeddings": embeddings.detach().cpu().numpy()}
|
||||
|
||||
|
||||
def main(
|
||||
rag_example_args: "RagExampleArguments",
|
||||
processing_args: "ProcessingArguments",
|
||||
index_hnsw_args: "IndexHnswArguments",
|
||||
):
|
||||
|
||||
######################################
|
||||
logger.info("Step 1 - Create the dataset")
|
||||
######################################
|
||||
|
||||
# The dataset needed for RAG must have three columns:
|
||||
# - title (string): title of the document
|
||||
# - text (string): text of a passage of the document
|
||||
# - embeddings (array of dimension d): DPR representation of the passage
|
||||
|
||||
# Let's say you have documents in tab-separated csv files with columns "title" and "text"
|
||||
assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"
|
||||
|
||||
# You can load a Dataset object this way
|
||||
dataset = load_dataset(
|
||||
"csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
|
||||
)
|
||||
|
||||
# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
|
||||
|
||||
# Then split the documents into passages of 100 words
|
||||
dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)
|
||||
|
||||
# And compute the embeddings
|
||||
ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
|
||||
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
|
||||
new_features = Features(
|
||||
{"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
|
||||
) # optional, save as float32 instead of float64 to save space
|
||||
dataset = dataset.map(
|
||||
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
|
||||
batched=True,
|
||||
batch_size=processing_args.batch_size,
|
||||
features=new_features,
|
||||
)
|
||||
|
||||
# And finally save your dataset
|
||||
passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
|
||||
dataset.save_to_disk(passages_path)
|
||||
# from datasets import load_from_disk
|
||||
# dataset = load_from_disk(passages_path) # to reload the dataset
|
||||
|
||||
######################################
|
||||
logger.info("Step 2 - Index the dataset")
|
||||
######################################
|
||||
|
||||
# Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
|
||||
index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
|
||||
dataset.add_faiss_index("embeddings", custom_index=index)
|
||||
|
||||
# And save the index
|
||||
index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
|
||||
dataset.get_index("embeddings").save(index_path)
|
||||
# dataset.load_faiss_index("embeddings", index_path) # to reload the index
|
||||
|
||||
######################################
|
||||
logger.info("Step 3 - Load RAG")
|
||||
######################################
|
||||
|
||||
# Easy way to load the model
|
||||
retriever = RagRetriever.from_pretrained(
|
||||
rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset
|
||||
)
|
||||
model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever)
|
||||
tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)
|
||||
|
||||
# For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
|
||||
# retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)
|
||||
|
||||
######################################
|
||||
logger.info("Step 4 - Have fun")
|
||||
######################################
|
||||
|
||||
question = rag_example_args.question or "What does Moses' rod turn into ?"
|
||||
input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
|
||||
generated = model.generate(input_ids)
|
||||
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
|
||||
logger.info("Q: " + question)
|
||||
logger.info("A: " + generated_string)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RagExampleArguments:
|
||||
csv_path: str = field(
|
||||
default=str(Path(__file__).parent / "test_data" / "my_knowledge_dataset.csv"),
|
||||
metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
|
||||
)
|
||||
question: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
|
||||
)
|
||||
rag_model_name: str = field(
|
||||
default="facebook/rag-sequence-nq",
|
||||
metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
|
||||
)
|
||||
dpr_ctx_encoder_model_name: str = field(
|
||||
default="facebook/dpr-ctx_encoder-multiset-base",
|
||||
metadata={
|
||||
"help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
|
||||
},
|
||||
)
|
||||
output_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingArguments:
|
||||
num_proc: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The number of processes to use to split the documents into passages. Default is single process."
|
||||
},
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=16,
|
||||
metadata={
|
||||
"help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexHnswArguments:
|
||||
d: int = field(
|
||||
default=768,
|
||||
metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
|
||||
)
|
||||
m: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The number of bi-directional links created for every new element during the HNSW index construction."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
|
||||
rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
|
||||
main(rag_example_args, processing_args, index_hnsw_args)
|
||||
244
examples/research_projects/rag/utils_rag.py
Normal file
244
examples/research_projects/rag/utils_rag.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import itertools
|
||||
import json
|
||||
import linecache
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import socket
|
||||
import string
|
||||
from collections import Counter
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
|
||||
import git
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
|
||||
|
||||
|
||||
def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"):
|
||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {}
|
||||
tokenizer.padding_side = padding_side
|
||||
return tokenizer(
|
||||
[line],
|
||||
max_length=max_length,
|
||||
padding="max_length" if pad_to_max_length else None,
|
||||
truncation=True,
|
||||
return_tensors=return_tensors,
|
||||
add_special_tokens=True,
|
||||
**extra_kw,
|
||||
)
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids,
|
||||
pad_token_id,
|
||||
attention_mask=None,
|
||||
):
|
||||
"""Remove columns that are populated exclusively by pad_token_id"""
|
||||
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
|
||||
if attention_mask is None:
|
||||
return input_ids[:, keep_column_mask]
|
||||
else:
|
||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
||||
|
||||
|
||||
class Seq2SeqDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
data_dir,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
type_path="train",
|
||||
n_obs=None,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
prefix="",
|
||||
):
|
||||
super().__init__()
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
|
||||
self.src_lens = self.get_char_lens(self.src_file)
|
||||
self.max_source_length = max_source_length
|
||||
self.max_target_length = max_target_length
|
||||
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||
self.tokenizer = tokenizer
|
||||
self.prefix = prefix
|
||||
if n_obs is not None:
|
||||
self.src_lens = self.src_lens[:n_obs]
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src_lens)
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
assert source_line, f"empty source line for index {index}"
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
|
||||
# Need to add eos token manually for T5
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
source_line += self.tokenizer.eos_token
|
||||
tgt_line += self.tokenizer.eos_token
|
||||
|
||||
# Pad source and target to the right
|
||||
source_tokenizer = (
|
||||
self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
|
||||
)
|
||||
target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
|
||||
|
||||
source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right")
|
||||
target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right")
|
||||
|
||||
source_ids = source_inputs["input_ids"].squeeze()
|
||||
target_ids = target_inputs["input_ids"].squeeze()
|
||||
src_mask = source_inputs["attention_mask"].squeeze()
|
||||
return {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": src_mask,
|
||||
"decoder_input_ids": target_ids,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
||||
tgt_pad_token_id = (
|
||||
self.tokenizer.generator.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
src_pad_token_id = (
|
||||
self.tokenizer.question_encoder.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
y = trim_batch(target_ids, tgt_pad_token_id)
|
||||
source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks)
|
||||
batch = {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": source_mask,
|
||||
"decoder_input_ids": y,
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def flatten_list(summary_ids: List[List]):
|
||||
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
||||
|
||||
|
||||
def save_git_info(folder_path: str) -> None:
|
||||
"""Save git information to output_dir/git_log.json"""
|
||||
repo_infos = get_git_info()
|
||||
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
|
||||
|
||||
|
||||
def save_json(content, path, indent=4, **json_dump_kwargs):
|
||||
with open(path, "w") as f:
|
||||
json.dump(content, f, indent=indent, **json_dump_kwargs)
|
||||
|
||||
|
||||
def load_json(path):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def get_git_info():
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
repo_infos = {
|
||||
"repo_id": str(repo),
|
||||
"repo_sha": str(repo.head.object.hexsha),
|
||||
"repo_branch": str(repo.active_branch),
|
||||
"hostname": str(socket.gethostname()),
|
||||
}
|
||||
return repo_infos
|
||||
|
||||
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
"""list(map(f, x))"""
|
||||
return list(map(f, x))
|
||||
|
||||
|
||||
def pickle_save(obj, path):
|
||||
"""pickle.dump(obj, path)"""
|
||||
with open(path, "wb") as f:
|
||||
return pickle.dump(obj, f)
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def f1_score(prediction, ground_truth):
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||
|
||||
|
||||
def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict:
|
||||
assert len(output_lns) == len(reference_lns)
|
||||
em = 0
|
||||
for hypo, pred in zip(output_lns, reference_lns):
|
||||
em += exact_match_score(hypo, pred)
|
||||
if len(output_lns) > 0:
|
||||
em /= len(output_lns)
|
||||
return {"em": em}
|
||||
|
||||
|
||||
def is_rag_model(model_prefix):
|
||||
return model_prefix.startswith("rag")
|
||||
|
||||
|
||||
def set_extra_model_params(extra_params, hparams, config):
|
||||
equivalent_param = {p: p for p in extra_params}
|
||||
# T5 models don't have `dropout` param, they have `dropout_rate` instead
|
||||
equivalent_param["dropout"] = "dropout_rate"
|
||||
for p in extra_params:
|
||||
if getattr(hparams, p, None):
|
||||
if not hasattr(config, p) and not hasattr(config, equivalent_param[p]):
|
||||
logger.info("config doesn't have a `{}` attribute".format(p))
|
||||
delattr(hparams, p)
|
||||
continue
|
||||
set_p = p if hasattr(config, p) else equivalent_param[p]
|
||||
setattr(config, set_p, getattr(hparams, p))
|
||||
delattr(hparams, p)
|
||||
return hparams, config
|
||||
Reference in New Issue
Block a user