fix examples/rag imports, tests (#7712)
This commit is contained in:
@@ -65,26 +65,41 @@ Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greates
|
|||||||
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).
|
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).
|
||||||
|
|
||||||
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
|
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
|
||||||
|
```bash
|
||||||
|
wget https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz && gzip -d biencoder-nq-dev.json.gz
|
||||||
|
```
|
||||||
|
|
||||||
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
|
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
|
||||||
```bash
|
```bash
|
||||||
|
mkdir output # or wherever you want to save this
|
||||||
python examples/rag/parse_dpr_relevance_data.py \
|
python examples/rag/parse_dpr_relevance_data.py \
|
||||||
--src_path path/to/unziped/biencoder-nq-dev.json \
|
--src_path biencoder-nq-dev.json \
|
||||||
--evaluation_set path/to/output/biencoder-nq-dev.questions \
|
--evaluation_set output/biencoder-nq-dev.questions \
|
||||||
--gold_data_path path/to/output/biencoder-nq-dev.pages
|
--gold_data_path output/biencoder-nq-dev.pages
|
||||||
```
|
```
|
||||||
3. Run evaluation:
|
3. Run evaluation:
|
||||||
```bash
|
```bash
|
||||||
|
python examples/rag/eval_rag.py \
|
||||||
|
--model_name_or_path facebook/rag-sequence-nq \
|
||||||
|
--model_type rag_sequence \
|
||||||
|
--evaluation_set output/biencoder-nq-dev.questions \
|
||||||
|
--gold_data_path output/biencoder-nq-dev.pages \
|
||||||
|
--predictions_path output/retrieval_preds.tsv \
|
||||||
|
--eval_mode retrieval \
|
||||||
|
--k 1
|
||||||
|
```
|
||||||
|
```bash
|
||||||
|
# EXPLANATION
|
||||||
python examples/rag/eval_rag.py \
|
python examples/rag/eval_rag.py \
|
||||||
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
|
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
|
||||||
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
|
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
|
||||||
--evaluation_set path/to/output/biencoder-nq-dev.questions \ # an input dataset for evaluation
|
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
|
||||||
--gold_data_path path/to/output/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set
|
--gold_data_path poutput/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set
|
||||||
--predictions_path path/to/retrieval_preds.tsv \ # name of file where predictions will be stored
|
--predictions_path output/retrieval_preds.tsv \ # name of file where predictions will be stored
|
||||||
--eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation
|
--eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation
|
||||||
--k 1 # parameter k for the precision@k metric
|
--k 1 # parameter k for the precision@k metric
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## End-to-end evaluation
|
## End-to-end evaluation
|
||||||
|
|
||||||
We support two formats of the gold data file (controlled by the `gold_data_mode` parameter):
|
We support two formats of the gold data file (controlled by the `gold_data_mode` parameter):
|
||||||
@@ -97,7 +112,9 @@ who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiul
|
|||||||
Xiu Li Dai
|
Xiu Li Dai
|
||||||
```
|
```
|
||||||
|
|
||||||
Predictions of the model for the samples from the `evaluation_set` will be saved under the path specified by the `predictions_path` parameter. If this path already exists, the script will use saved predictions to calculate metrics. Add `--recalculate` parameter to force the script to perform inference from scratch.
|
Predictions of the model for the samples from the `evaluation_set` will be saved under the path specified by the `predictions_path` parameter.
|
||||||
|
If this path already exists, the script will use saved predictions to calculate metrics.
|
||||||
|
Add `--recalculate` parameter to force the script to perform inference from scratch.
|
||||||
|
|
||||||
An example e2e evaluation run could look as follows:
|
An example e2e evaluation run could look as follows:
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from transformers import logging as transformers_logging
|
|||||||
|
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
|
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
|
||||||
from examples.rag.utils import exact_match_score, f1_score # noqa: E402 # isort:skip
|
from utils import exact_match_score, f1_score # noqa: E402 # isort:skip
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -31,16 +31,13 @@ from transformers import (
|
|||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
from callbacks import ( # noqa: E402 # isort:skipq
|
||||||
|
|
||||||
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip
|
|
||||||
from examples.rag.callbacks import ( # noqa: E402 # isort:skip
|
|
||||||
get_checkpoint_callback,
|
get_checkpoint_callback,
|
||||||
get_early_stopping_callback,
|
get_early_stopping_callback,
|
||||||
Seq2SeqLoggingCallback,
|
Seq2SeqLoggingCallback,
|
||||||
)
|
)
|
||||||
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||||
from examples.rag.utils import ( # noqa: E402 # isort:skip
|
from utils import ( # noqa: E402 # isort:skip
|
||||||
calculate_exact_match,
|
calculate_exact_match,
|
||||||
flatten_list,
|
flatten_list,
|
||||||
get_git_info,
|
get_git_info,
|
||||||
@@ -53,6 +50,11 @@ from examples.rag.utils import ( # noqa: E402 # isort:skip
|
|||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# need the parent dir module
|
||||||
|
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||||
|
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FI
|
|||||||
|
|
||||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||||
|
|
||||||
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||||
|
|
||||||
|
|
||||||
def require_distributed_retrieval(test_case):
|
def require_distributed_retrieval(test_case):
|
||||||
|
|||||||
Reference in New Issue
Block a user