Fix typos in README and bugs in RAG example code for end-to-end evaluation and finetuning (#9355)
* fix a bug in eval_batch_retrieval * should return parser as well as other staticmethod * remove duplicate argument * these kwargs are no longer accepted (cause TypeError in self.generator.generate of modeling_rag.py) * fixed file paths in README * moved an arg to add_ray_specific_args
This commit is contained in:
committed by
GitHub
parent
c4fd609afb
commit
d944966b19
@@ -23,10 +23,10 @@ test.source
|
|||||||
test.target
|
test.target
|
||||||
```
|
```
|
||||||
|
|
||||||
A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
|
A sample finetuning command (run ` ./examples/research_projects/rag/finetune_rag.py --help` to list all available options):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/rag/finetune_rag.py \
|
python examples/research_projects/rag/finetune_rag.py \
|
||||||
--data_dir $DATA_DIR \
|
--data_dir $DATA_DIR \
|
||||||
--output_dir $OUTPUT_DIR \
|
--output_dir $OUTPUT_DIR \
|
||||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||||
@@ -42,7 +42,7 @@ The `base` models initialize the question encoder with [`facebook/dpr-question_e
|
|||||||
|
|
||||||
If you would like to initialize finetuning with a base model using different question encoder and generator architectures, you can build it with a consolidation script, e.g.:
|
If you would like to initialize finetuning with a base model using different question encoder and generator architectures, you can build it with a consolidation script, e.g.:
|
||||||
```
|
```
|
||||||
python examples/rag/consolidate_rag_checkpoint.py \
|
python examples/research_projects/rag/consolidate_rag_checkpoint.py \
|
||||||
--model_type rag_sequence \
|
--model_type rag_sequence \
|
||||||
--generator_name_or_path facebook/bart-large-cnn \
|
--generator_name_or_path facebook/bart-large-cnn \
|
||||||
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
|
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
|
||||||
@@ -71,7 +71,7 @@ Also make sure to start the Ray cluster before running fine-tuning.
|
|||||||
# Start a single-node Ray cluster.
|
# Start a single-node Ray cluster.
|
||||||
ray start --head
|
ray start --head
|
||||||
|
|
||||||
python examples/rag/finetune_rag.py \
|
python examples/research_projects/rag/finetune_rag.py \
|
||||||
--data_dir $DATA_DIR \
|
--data_dir $DATA_DIR \
|
||||||
--output_dir $OUTPUT_DIR \
|
--output_dir $OUTPUT_DIR \
|
||||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||||
@@ -113,14 +113,14 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
|
|||||||
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
|
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
|
||||||
```bash
|
```bash
|
||||||
mkdir output # or wherever you want to save this
|
mkdir output # or wherever you want to save this
|
||||||
python examples/rag/parse_dpr_relevance_data.py \
|
python examples/research_projects/rag/parse_dpr_relevance_data.py \
|
||||||
--src_path biencoder-nq-dev.json \
|
--src_path biencoder-nq-dev.json \
|
||||||
--evaluation_set output/biencoder-nq-dev.questions \
|
--evaluation_set output/biencoder-nq-dev.questions \
|
||||||
--gold_data_path output/biencoder-nq-dev.pages
|
--gold_data_path output/biencoder-nq-dev.pages
|
||||||
```
|
```
|
||||||
3. Run evaluation:
|
3. Run evaluation:
|
||||||
```bash
|
```bash
|
||||||
python examples/rag/eval_rag.py \
|
python examples/research_projects/rag/eval_rag.py \
|
||||||
--model_name_or_path facebook/rag-sequence-nq \
|
--model_name_or_path facebook/rag-sequence-nq \
|
||||||
--model_type rag_sequence \
|
--model_type rag_sequence \
|
||||||
--evaluation_set output/biencoder-nq-dev.questions \
|
--evaluation_set output/biencoder-nq-dev.questions \
|
||||||
@@ -131,7 +131,7 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
|
|||||||
```
|
```
|
||||||
```bash
|
```bash
|
||||||
# EXPLANATION
|
# EXPLANATION
|
||||||
python examples/rag/eval_rag.py \
|
python examples/research_projects/rag/eval_rag.py \
|
||||||
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
|
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
|
||||||
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
|
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
|
||||||
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
|
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
|
||||||
@@ -159,7 +159,7 @@ Add `--recalculate` parameter to force the script to perform inference from scra
|
|||||||
|
|
||||||
An example e2e evaluation run could look as follows:
|
An example e2e evaluation run could look as follows:
|
||||||
```bash
|
```bash
|
||||||
python examples/rag/eval_rag.py \
|
python examples/research_projects/rag/eval_rag.py \
|
||||||
--model_name_or_path facebook/rag-sequence-nq \
|
--model_name_or_path facebook/rag-sequence-nq \
|
||||||
--model_type rag_sequence \
|
--model_type rag_sequence \
|
||||||
--evaluation_set path/to/test.source \
|
--evaluation_set path/to/test.source \
|
||||||
@@ -179,14 +179,14 @@ With `use_custom_knowledge_dataset.py` you can build your own knowledge source,
|
|||||||
|
|
||||||
For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
|
For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
|
||||||
```bash
|
```bash
|
||||||
python examples/rag/use_own_knowledge_dataset.py \
|
python examples/research_projects/rag/use_own_knowledge_dataset.py \
|
||||||
--csv_path path/to/my_csv \
|
--csv_path path/to/my_csv \
|
||||||
--output_dir path/to/my_knowledge_dataset \
|
--output_dir path/to/my_knowledge_dataset \
|
||||||
```
|
```
|
||||||
|
|
||||||
The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
|
The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
|
||||||
```bash
|
```bash
|
||||||
python examples/rag/finetune_rag.py \
|
python examples/research_projects/rag/finetune_rag.py \
|
||||||
--data_dir $DATA_DIR \
|
--data_dir $DATA_DIR \
|
||||||
--output_dir $OUTPUT_DIR \
|
--output_dir $OUTPUT_DIR \
|
||||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||||
|
|||||||
@@ -130,8 +130,6 @@ def evaluate_batch_e2e(args, rag_model, questions):
|
|||||||
early_stopping=False,
|
early_stopping=False,
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, dont allow it to generate more than one
|
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, dont allow it to generate more than one
|
||||||
clean_up_tokenization=True,
|
|
||||||
print_docs=args.print_docs,
|
|
||||||
)
|
)
|
||||||
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
|
||||||
|
|||||||
@@ -443,7 +443,6 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
type=str,
|
type=str,
|
||||||
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -486,27 +485,10 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
default=False,
|
default=False,
|
||||||
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||||
)
|
)
|
||||||
|
return parser
|
||||||
parser.add_argument(
|
|
||||||
"--num_retrieval_workers",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="The number of retrieval actors to use when Ray is selected"
|
|
||||||
"for the distributed retriever. Has no effect when "
|
|
||||||
"distributed_retriever is set to pytorch.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_ray_specific_args(parser):
|
def add_ray_specific_args(parser):
|
||||||
parser.add_argument(
|
|
||||||
"--num_retrieval_workers",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="The number of retrieval actors to use when Ray is selected"
|
|
||||||
"for the distributed retriever. Has no effect when "
|
|
||||||
"distributed_retriever is set to pytorch.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ray cluster address.
|
# Ray cluster address.
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ray-address",
|
"--ray-address",
|
||||||
@@ -517,12 +499,18 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
"cluster. Has no effect if pytorch is used as the distributed "
|
"cluster. Has no effect if pytorch is used as the distributed "
|
||||||
"retriever.",
|
"retriever.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_retrieval_workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The number of retrieval actors to use when Ray is selected"
|
||||||
|
"for the distributed retriever. Has no effect when "
|
||||||
|
"distributed_retriever is set to pytorch.",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main(args=None, model=None) -> GenerativeQAModule:
|
def main(args=None, model=None) -> GenerativeQAModule:
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||||
|
|||||||
Reference in New Issue
Block a user