Black preview (#17217)
* Black preview * Fixup too! * Fix check copies * Use the same version as the CI * Bump black
This commit is contained in:
@@ -60,7 +60,7 @@ class GroupedBatchSampler(BatchSampler):
|
||||
def __init__(self, sampler, group_ids, batch_size):
|
||||
if not isinstance(sampler, Sampler):
|
||||
raise ValueError(
|
||||
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
||||
"sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
||||
)
|
||||
self.sampler = sampler
|
||||
self.group_ids = group_ids
|
||||
|
||||
@@ -518,7 +518,10 @@ def main():
|
||||
"--teacher_type",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
|
||||
help=(
|
||||
"Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
|
||||
" distillation."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--teacher_name_or_path",
|
||||
@@ -590,8 +593,10 @@ def main():
|
||||
"--max_seq_length",
|
||||
default=384,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--doc_stride",
|
||||
@@ -603,8 +608,10 @@ def main():
|
||||
"--max_query_length",
|
||||
default=64,
|
||||
type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.",
|
||||
help=(
|
||||
"The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
@@ -649,14 +656,18 @@ def main():
|
||||
"--max_answer_length",
|
||||
default=30,
|
||||
type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.",
|
||||
help=(
|
||||
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose_logging",
|
||||
action="store_true",
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.",
|
||||
help=(
|
||||
"If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
@@ -685,8 +696,10 @@ def main():
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
help=(
|
||||
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
|
||||
@@ -25,7 +25,10 @@ from transformers import GPT2LMHeadModel, RobertaForMaskedLM
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
|
||||
description=(
|
||||
"Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned"
|
||||
" Distillation"
|
||||
)
|
||||
)
|
||||
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
|
||||
parser.add_argument("--model_name", default="roberta-large", type=str)
|
||||
|
||||
@@ -25,7 +25,10 @@ from transformers import BertForMaskedLM
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
|
||||
description=(
|
||||
"Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned"
|
||||
" Distillation"
|
||||
)
|
||||
)
|
||||
parser.add_argument("--model_type", default="bert", choices=["bert"])
|
||||
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
|
||||
|
||||
@@ -207,8 +207,10 @@ def main():
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
help=(
|
||||
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
|
||||
@@ -226,8 +228,8 @@ def main():
|
||||
if os.path.exists(args.dump_path):
|
||||
if not args.force:
|
||||
raise ValueError(
|
||||
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it"
|
||||
"Use `--force` if you want to overwrite it"
|
||||
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite"
|
||||
" itUse `--force` if you want to overwrite it"
|
||||
)
|
||||
else:
|
||||
shutil.rmtree(args.dump_path)
|
||||
|
||||
Reference in New Issue
Block a user