Black preview (#17217)

* Black preview

* Fixup too!

* Fix check copies

* Use the same version as the CI

* Bump black
This commit is contained in:
Sylvain Gugger
2022-05-12 16:25:55 -04:00
committed by GitHub
parent 9bd67ac7bb
commit afe5d42d8d
578 changed files with 8274 additions and 3296 deletions

View File

@@ -93,7 +93,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
exp = "{val_avg_loss:.4f}-{step_count}"
else:
raise NotImplementedError(
f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to"
" this function."
)
checkpoint_callback = ModelCheckpoint(

View File

@@ -52,9 +52,10 @@ class SummarizationDistiller(SummarizationModule):
student.config.length_penalty = hparams.length_penalty
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
super().__init__(hparams, model=student, config=student.config)
assert (
student.config.model_type == teacher.config.model_type
), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}"
assert student.config.model_type == teacher.config.model_type, (
f"teacher, student model types should be the same, got {student.config.model_type} !="
f" {teacher.config.model_type}"
)
if student.config.model_type == "t5":
student_encoder_layers = len(student.get_encoder().block)

View File

@@ -303,29 +303,37 @@ class SummarizationModule(BaseTransformer):
"--max_source_length",
default=1024,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
help=(
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument(
"--max_target_length",
default=56,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
help=(
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument(
"--val_max_target_length",
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
help=(
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument(
"--test_max_target_length",
default=142,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
help=(
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
@@ -353,7 +361,10 @@ class SummarizationModule(BaseTransformer):
type=int,
default=-1,
required=False,
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
help=(
"-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
" val_check_interval will effect it."
),
)
return parser

View File

@@ -312,8 +312,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
help=(
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html"
),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")

View File

@@ -58,7 +58,8 @@ def pick_layers_to_copy(n_student, n_teacher):
except KeyError:
if n_student != n_teacher:
warnings.warn(
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first"
f" {n_student}"
)
return list(range(n_student))
@@ -144,7 +145,8 @@ def create_student_by_copying_alternating_layers(
if copy_first_teacher_layers: # Our copying is done. We just log and save
e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d))
logger.info(
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to"
f" {save_path}"
)
student.save_pretrained(save_path)
return student, e_layers_to_copy, d_layers_to_copy

View File

@@ -108,7 +108,10 @@ def run_generate(verbose=True):
nargs="?",
type=str,
const=datetime_now(),
help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
help=(
"use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g."
" lang=en-ru. If no value is passed, the current datetime string will be used."
),
)
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()