Black preview (#17217)
* Black preview * Fixup too! * Fix check copies * Use the same version as the CI * Bump black
This commit is contained in:
@@ -29,7 +29,8 @@ def get_checkpoint_callback(output_dir, metric):
|
||||
exp = "{val_avg_em:.4f}-{step_count}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
|
||||
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this"
|
||||
" function."
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
|
||||
@@ -80,7 +80,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--config_name_or_path",
|
||||
type=str,
|
||||
help="Identifier of the model config to use, if not provided, resolves to a base config for a given ``model_type``",
|
||||
help=(
|
||||
"Identifier of the model config to use, if not provided, resolves to a base config for a given"
|
||||
" ``model_type``"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -146,7 +146,10 @@ def get_args():
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart"],
|
||||
type=str,
|
||||
help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
|
||||
help=(
|
||||
"RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the"
|
||||
" model_name_or_path"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
@@ -174,7 +177,10 @@ def get_args():
|
||||
choices=["e2e", "retrieval"],
|
||||
default="e2e",
|
||||
type=str,
|
||||
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
|
||||
help=(
|
||||
"Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates"
|
||||
" precision@k."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
|
||||
parser.add_argument(
|
||||
@@ -196,9 +202,11 @@ def get_args():
|
||||
default="qa",
|
||||
type=str,
|
||||
choices=["qa", "ans"],
|
||||
help="Format of the gold data file"
|
||||
"qa - a single line in the following format: question [tab] answer_list"
|
||||
"ans - a single line of the gold file contains the expected answer string",
|
||||
help=(
|
||||
"Format of the gold data file"
|
||||
"qa - a single line in the following format: question [tab] answer_list"
|
||||
"ans - a single line of the gold file contains the expected answer string"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predictions_path",
|
||||
|
||||
@@ -383,29 +383,37 @@ class GenerativeQAModule(BaseTransformer):
|
||||
"--max_source_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
@@ -423,7 +431,10 @@ class GenerativeQAModule(BaseTransformer):
|
||||
type=int,
|
||||
default=-1,
|
||||
required=False,
|
||||
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
|
||||
help=(
|
||||
"-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
|
||||
" val_check_interval will effect it."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
|
||||
@@ -432,7 +443,10 @@ class GenerativeQAModule(BaseTransformer):
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart", "t5"],
|
||||
type=str,
|
||||
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
||||
help=(
|
||||
"RAG model type: sequence or token, if none specified, the type is inferred from the"
|
||||
" model_name_or_path"
|
||||
),
|
||||
)
|
||||
return parser
|
||||
|
||||
@@ -442,39 +456,53 @@ class GenerativeQAModule(BaseTransformer):
|
||||
"--index_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
|
||||
help=(
|
||||
"Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom'"
|
||||
" for a local index, or 'legacy' for the orignal one)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--passages_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
help=(
|
||||
"Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever"
|
||||
" documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
help=(
|
||||
"Path to the faiss index for custom index. More info about custom indexes in the RagRetriever"
|
||||
" documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed_retriever",
|
||||
choices=["ray", "pytorch"],
|
||||
type=str,
|
||||
default="pytorch",
|
||||
help="What implementation to use for distributed retriever? If "
|
||||
"pytorch is selected, the index is loaded on training "
|
||||
"worker 0, and torch.distributed is used to handle "
|
||||
"communication between training worker 0, and the other "
|
||||
"training workers. If ray is selected, the Ray library is "
|
||||
"used to create load the index on separate processes, "
|
||||
"and Ray handles the communication between the training "
|
||||
"workers and the retrieval actors.",
|
||||
help=(
|
||||
"What implementation to use for distributed retriever? If "
|
||||
"pytorch is selected, the index is loaded on training "
|
||||
"worker 0, and torch.distributed is used to handle "
|
||||
"communication between training worker 0, and the other "
|
||||
"training workers. If ray is selected, the Ray library is "
|
||||
"used to create load the index on separate processes, "
|
||||
"and Ray handles the communication between the training "
|
||||
"workers and the retrieval actors."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dummy_dataset",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
help=(
|
||||
"Whether to use the dummy version of the dataset index. More info about custom indexes in the"
|
||||
" RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
|
||||
),
|
||||
)
|
||||
return parser
|
||||
|
||||
@@ -485,18 +513,22 @@ class GenerativeQAModule(BaseTransformer):
|
||||
"--ray-address",
|
||||
default="auto",
|
||||
type=str,
|
||||
help="The address of the Ray cluster to connect to. If not "
|
||||
"specified, Ray will attempt to automatically detect the "
|
||||
"cluster. Has no effect if pytorch is used as the distributed "
|
||||
"retriever.",
|
||||
help=(
|
||||
"The address of the Ray cluster to connect to. If not "
|
||||
"specified, Ray will attempt to automatically detect the "
|
||||
"cluster. Has no effect if pytorch is used as the distributed "
|
||||
"retriever."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_retrieval_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of retrieval actors to use when Ray is selected"
|
||||
"for the distributed retriever. Has no effect when "
|
||||
"distributed_retriever is set to pytorch.",
|
||||
help=(
|
||||
"The number of retrieval actors to use when Ray is selected"
|
||||
"for the distributed retriever. Has no effect when "
|
||||
"distributed_retriever is set to pytorch."
|
||||
),
|
||||
)
|
||||
return parser
|
||||
|
||||
@@ -514,7 +546,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
named_actors = []
|
||||
if args.distributed_retriever == "ray" and args.gpus > 1:
|
||||
if not is_ray_available():
|
||||
raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
|
||||
raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
|
||||
# Connect to an existing Ray cluster.
|
||||
try:
|
||||
ray.init(address=args.ray_address, namespace="rag")
|
||||
|
||||
@@ -321,8 +321,10 @@ def add_generic_args(parser, root_dir) -> None:
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O2",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
help=(
|
||||
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
|
||||
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
||||
|
||||
@@ -154,7 +154,10 @@ class RagExampleArguments:
|
||||
dpr_ctx_encoder_model_name: str = field(
|
||||
default="facebook/dpr-ctx_encoder-multiset-base",
|
||||
metadata={
|
||||
"help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
|
||||
"help": (
|
||||
"The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or"
|
||||
" 'facebook/dpr-ctx_encoder-multiset-base'"
|
||||
)
|
||||
},
|
||||
)
|
||||
output_dir: Optional[str] = field(
|
||||
@@ -188,7 +191,9 @@ class IndexHnswArguments:
|
||||
m: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The number of bi-directional links created for every new element during the HNSW index construction."
|
||||
"help": (
|
||||
"The number of bi-directional links created for every new element during the HNSW index construction."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user