Black preview (#17217)

* Black preview

* Fixup too!

* Fix check copies

* Use the same version as the CI

* Bump black
This commit is contained in:
Sylvain Gugger
2022-05-12 16:25:55 -04:00
committed by GitHub
parent 9bd67ac7bb
commit afe5d42d8d
578 changed files with 8274 additions and 3296 deletions

View File

@@ -60,7 +60,7 @@ class GroupedBatchSampler(BatchSampler):
def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler):
raise ValueError(
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
"sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler
self.group_ids = group_ids

View File

@@ -518,7 +518,10 @@ def main():
"--teacher_type",
default=None,
type=str,
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
help=(
"Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
" distillation."
),
)
parser.add_argument(
"--teacher_name_or_path",
@@ -590,8 +593,10 @@ def main():
"--max_seq_length",
default=384,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.",
help=(
"The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded."
),
)
parser.add_argument(
"--doc_stride",
@@ -603,8 +608,10 @@ def main():
"--max_query_length",
default=64,
type=int,
help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.",
help=(
"The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length."
),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -649,14 +656,18 @@ def main():
"--max_answer_length",
default=30,
type=int,
help="The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another.",
help=(
"The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another."
),
)
parser.add_argument(
"--verbose_logging",
action="store_true",
help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.",
help=(
"If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation."
),
)
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
@@ -685,8 +696,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
help=(
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html"
),
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")

View File

@@ -25,7 +25,10 @@ from transformers import GPT2LMHeadModel, RobertaForMaskedLM
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
description=(
"Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned"
" Distillation"
)
)
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
parser.add_argument("--model_name", default="roberta-large", type=str)

View File

@@ -25,7 +25,10 @@ from transformers import BertForMaskedLM
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
description=(
"Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned"
" Distillation"
)
)
parser.add_argument("--model_type", default="bert", choices=["bert"])
parser.add_argument("--model_name", default="bert-base-uncased", type=str)

View File

@@ -207,8 +207,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
help=(
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html"
),
)
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
@@ -226,8 +228,8 @@ def main():
if os.path.exists(args.dump_path):
if not args.force:
raise ValueError(
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it"
"Use `--force` if you want to overwrite it"
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite"
" itUse `--force` if you want to overwrite it"
)
else:
shutil.rmtree(args.dump_path)