From 8feb0cc96768a9d69fd1574b2727c6413131ff61 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 14 Oct 2020 11:35:00 -0400 Subject: [PATCH] fix examples/rag imports, tests (#7712) --- examples/rag/README.md | 37 ++++++++++++++++------ examples/rag/__init__.py | 5 +++ examples/rag/eval_rag.py | 2 +- examples/rag/finetune.py | 14 ++++---- examples/rag/test_distributed_retriever.py | 2 +- 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/examples/rag/README.md b/examples/rag/README.md index c35fe63005..65b126666e 100644 --- a/examples/rag/README.md +++ b/examples/rag/README.md @@ -65,26 +65,41 @@ Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greates We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45). 1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz. + ```bash + wget https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz && gzip -d biencoder-nq-dev.json.gz + ``` + 2. Parse the unziped file using the `parse_dpr_relevance_data.py` ```bash + mkdir output # or wherever you want to save this python examples/rag/parse_dpr_relevance_data.py \ - --src_path path/to/unziped/biencoder-nq-dev.json \ - --evaluation_set path/to/output/biencoder-nq-dev.questions \ - --gold_data_path path/to/output/biencoder-nq-dev.pages + --src_path biencoder-nq-dev.json \ + --evaluation_set output/biencoder-nq-dev.questions \ + --gold_data_path output/biencoder-nq-dev.pages ``` 3. Run evaluation: - ```bash + ```bash + python examples/rag/eval_rag.py \ + --model_name_or_path facebook/rag-sequence-nq \ + --model_type rag_sequence \ + --evaluation_set output/biencoder-nq-dev.questions \ + --gold_data_path output/biencoder-nq-dev.pages \ + --predictions_path output/retrieval_preds.tsv \ + --eval_mode retrieval \ + --k 1 + ``` + ```bash + # EXPLANATION python examples/rag/eval_rag.py \ --model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating --model_type rag_sequence \ # RAG model type (rag_token or rag_sequence) - --evaluation_set path/to/output/biencoder-nq-dev.questions \ # an input dataset for evaluation - --gold_data_path path/to/output/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set - --predictions_path path/to/retrieval_preds.tsv \ # name of file where predictions will be stored + --evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation + --gold_data_path poutput/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set + --predictions_path output/retrieval_preds.tsv \ # name of file where predictions will be stored --eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation --k 1 # parameter k for the precision@k metric + ``` - - ## End-to-end evaluation We support two formats of the gold data file (controlled by the `gold_data_mode` parameter): @@ -97,7 +112,9 @@ who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiul Xiu Li Dai ``` -Predictions of the model for the samples from the `evaluation_set` will be saved under the path specified by the `predictions_path` parameter. If this path already exists, the script will use saved predictions to calculate metrics. Add `--recalculate` parameter to force the script to perform inference from scratch. +Predictions of the model for the samples from the `evaluation_set` will be saved under the path specified by the `predictions_path` parameter. +If this path already exists, the script will use saved predictions to calculate metrics. +Add `--recalculate` parameter to force the script to perform inference from scratch. An example e2e evaluation run could look as follows: ```bash diff --git a/examples/rag/__init__.py b/examples/rag/__init__.py index e69de29bb2..3cee09bb7f 100644 --- a/examples/rag/__init__.py +++ b/examples/rag/__init__.py @@ -0,0 +1,5 @@ +import os +import sys + + +sys.path.insert(1, os.path.dirname(os.path.realpath(__file__))) diff --git a/examples/rag/eval_rag.py b/examples/rag/eval_rag.py index 452baf7cb6..baa956ecab 100644 --- a/examples/rag/eval_rag.py +++ b/examples/rag/eval_rag.py @@ -15,7 +15,7 @@ from transformers import logging as transformers_logging sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip -from examples.rag.utils import exact_match_score, f1_score # noqa: E402 # isort:skip +from utils import exact_match_score, f1_score # noqa: E402 # isort:skip logger = logging.getLogger(__name__) diff --git a/examples/rag/finetune.py b/examples/rag/finetune.py index c76045fc3d..a56fe4eac9 100644 --- a/examples/rag/finetune.py +++ b/examples/rag/finetune.py @@ -31,16 +31,13 @@ from transformers import ( from transformers import logging as transformers_logging -sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip - -from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip -from examples.rag.callbacks import ( # noqa: E402 # isort:skip +from callbacks import ( # noqa: E402 # isort:skipq get_checkpoint_callback, get_early_stopping_callback, Seq2SeqLoggingCallback, ) -from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip -from examples.rag.utils import ( # noqa: E402 # isort:skip +from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip +from utils import ( # noqa: E402 # isort:skip calculate_exact_match, flatten_list, get_git_info, @@ -53,6 +50,11 @@ from examples.rag.utils import ( # noqa: E402 # isort:skip Seq2SeqDataset, ) +# need the parent dir module +sys.path.insert(2, str(Path(__file__).resolve().parents[1])) +from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa + + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/examples/rag/test_distributed_retriever.py b/examples/rag/test_distributed_retriever.py index 387c91abbe..be0ec99ba8 100644 --- a/examples/rag/test_distributed_retriever.py +++ b/examples/rag/test_distributed_retriever.py @@ -23,7 +23,7 @@ from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FI sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip -from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip +from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip def require_distributed_retrieval(test_case):