Black preview (#17217)

* Black preview

* Fixup too!

* Fix check copies

* Use the same version as the CI

* Bump black
This commit is contained in:
Sylvain Gugger
2022-05-12 16:25:55 -04:00
committed by GitHub
parent 9bd67ac7bb
commit afe5d42d8d
578 changed files with 8274 additions and 3296 deletions

View File

@@ -29,7 +29,8 @@ def get_checkpoint_callback(output_dir, metric):
exp = "{val_avg_em:.4f}-{step_count}"
else:
raise NotImplementedError(
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this"
" function."
)
checkpoint_callback = ModelCheckpoint(

View File

@@ -80,7 +80,10 @@ if __name__ == "__main__":
parser.add_argument(
"--config_name_or_path",
type=str,
help="Identifier of the model config to use, if not provided, resolves to a base config for a given ``model_type``",
help=(
"Identifier of the model config to use, if not provided, resolves to a base config for a given"
" ``model_type``"
),
)
args = parser.parse_args()

View File

@@ -146,7 +146,10 @@ def get_args():
"--model_type",
choices=["rag_sequence", "rag_token", "bart"],
type=str,
help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
help=(
"RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the"
" model_name_or_path"
),
)
parser.add_argument(
"--index_name",
@@ -174,7 +177,10 @@ def get_args():
choices=["e2e", "retrieval"],
default="e2e",
type=str,
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
help=(
"Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates"
" precision@k."
),
)
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
parser.add_argument(
@@ -196,9 +202,11 @@ def get_args():
default="qa",
type=str,
choices=["qa", "ans"],
help="Format of the gold data file"
"qa - a single line in the following format: question [tab] answer_list"
"ans - a single line of the gold file contains the expected answer string",
help=(
"Format of the gold data file"
"qa - a single line in the following format: question [tab] answer_list"
"ans - a single line of the gold file contains the expected answer string"
),
)
parser.add_argument(
"--predictions_path",

View File

@@ -383,29 +383,37 @@ class GenerativeQAModule(BaseTransformer):
"--max_source_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
help=(
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument(
"--max_target_length",
default=25,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
help=(
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument(
"--val_max_target_length",
default=25,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
help=(
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument(
"--test_max_target_length",
default=25,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
help=(
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -423,7 +431,10 @@ class GenerativeQAModule(BaseTransformer):
type=int,
default=-1,
required=False,
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
help=(
"-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
" val_check_interval will effect it."
),
)
parser.add_argument(
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
@@ -432,7 +443,10 @@ class GenerativeQAModule(BaseTransformer):
"--model_type",
choices=["rag_sequence", "rag_token", "bart", "t5"],
type=str,
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
help=(
"RAG model type: sequence or token, if none specified, the type is inferred from the"
" model_name_or_path"
),
)
return parser
@@ -442,39 +456,53 @@ class GenerativeQAModule(BaseTransformer):
"--index_name",
type=str,
default=None,
help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
help=(
"Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom'"
" for a local index, or 'legacy' for the orignal one)"
),
)
parser.add_argument(
"--passages_path",
type=str,
default=None,
help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
help=(
"Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever"
" documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
),
)
parser.add_argument(
"--index_path",
type=str,
default=None,
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
help=(
"Path to the faiss index for custom index. More info about custom indexes in the RagRetriever"
" documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
),
)
parser.add_argument(
"--distributed_retriever",
choices=["ray", "pytorch"],
type=str,
default="pytorch",
help="What implementation to use for distributed retriever? If "
"pytorch is selected, the index is loaded on training "
"worker 0, and torch.distributed is used to handle "
"communication between training worker 0, and the other "
"training workers. If ray is selected, the Ray library is "
"used to create load the index on separate processes, "
"and Ray handles the communication between the training "
"workers and the retrieval actors.",
help=(
"What implementation to use for distributed retriever? If "
"pytorch is selected, the index is loaded on training "
"worker 0, and torch.distributed is used to handle "
"communication between training worker 0, and the other "
"training workers. If ray is selected, the Ray library is "
"used to create load the index on separate processes, "
"and Ray handles the communication between the training "
"workers and the retrieval actors."
),
)
parser.add_argument(
"--use_dummy_dataset",
type=bool,
default=False,
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
help=(
"Whether to use the dummy version of the dataset index. More info about custom indexes in the"
" RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
),
)
return parser
@@ -485,18 +513,22 @@ class GenerativeQAModule(BaseTransformer):
"--ray-address",
default="auto",
type=str,
help="The address of the Ray cluster to connect to. If not "
"specified, Ray will attempt to automatically detect the "
"cluster. Has no effect if pytorch is used as the distributed "
"retriever.",
help=(
"The address of the Ray cluster to connect to. If not "
"specified, Ray will attempt to automatically detect the "
"cluster. Has no effect if pytorch is used as the distributed "
"retriever."
),
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
help=(
"The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch."
),
)
return parser
@@ -514,7 +546,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
named_actors = []
if args.distributed_retriever == "ray" and args.gpus > 1:
if not is_ray_available():
raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
# Connect to an existing Ray cluster.
try:
ray.init(address=args.ray_address, namespace="rag")

View File

@@ -321,8 +321,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
help=(
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html"
),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")

View File

@@ -154,7 +154,10 @@ class RagExampleArguments:
dpr_ctx_encoder_model_name: str = field(
default="facebook/dpr-ctx_encoder-multiset-base",
metadata={
"help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
"help": (
"The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or"
" 'facebook/dpr-ctx_encoder-multiset-base'"
)
},
)
output_dir: Optional[str] = field(
@@ -188,7 +191,9 @@ class IndexHnswArguments:
m: int = field(
default=128,
metadata={
"help": "The number of bi-directional links created for every new element during the HNSW index construction."
"help": (
"The number of bi-directional links created for every new element during the HNSW index construction."
)
},
)