Black preview (#17217)
* Black preview * Fixup too! * Fix check copies * Use the same version as the CI * Bump black
This commit is contained in:
@@ -93,7 +93,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
|
||||
exp = "{val_avg_loss:.4f}-{step_count}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
|
||||
f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to"
|
||||
" this function."
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
|
||||
@@ -52,9 +52,10 @@ class SummarizationDistiller(SummarizationModule):
|
||||
student.config.length_penalty = hparams.length_penalty
|
||||
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
|
||||
super().__init__(hparams, model=student, config=student.config)
|
||||
assert (
|
||||
student.config.model_type == teacher.config.model_type
|
||||
), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}"
|
||||
assert student.config.model_type == teacher.config.model_type, (
|
||||
f"teacher, student model types should be the same, got {student.config.model_type} !="
|
||||
f" {teacher.config.model_type}"
|
||||
)
|
||||
|
||||
if student.config.model_type == "t5":
|
||||
student_encoder_layers = len(student.get_encoder().block)
|
||||
|
||||
@@ -303,29 +303,37 @@ class SummarizationModule(BaseTransformer):
|
||||
"--max_source_length",
|
||||
default=1024,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
default=56,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=142,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--freeze_encoder", action="store_true")
|
||||
parser.add_argument("--freeze_embeds", action="store_true")
|
||||
@@ -353,7 +361,10 @@ class SummarizationModule(BaseTransformer):
|
||||
type=int,
|
||||
default=-1,
|
||||
required=False,
|
||||
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
|
||||
help=(
|
||||
"-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
|
||||
" val_check_interval will effect it."
|
||||
),
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -312,8 +312,10 @@ def add_generic_args(parser, root_dir) -> None:
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O2",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
help=(
|
||||
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
|
||||
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
||||
|
||||
@@ -58,7 +58,8 @@ def pick_layers_to_copy(n_student, n_teacher):
|
||||
except KeyError:
|
||||
if n_student != n_teacher:
|
||||
warnings.warn(
|
||||
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
|
||||
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first"
|
||||
f" {n_student}"
|
||||
)
|
||||
return list(range(n_student))
|
||||
|
||||
@@ -144,7 +145,8 @@ def create_student_by_copying_alternating_layers(
|
||||
if copy_first_teacher_layers: # Our copying is done. We just log and save
|
||||
e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d))
|
||||
logger.info(
|
||||
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
|
||||
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to"
|
||||
f" {save_path}"
|
||||
)
|
||||
student.save_pretrained(save_path)
|
||||
return student, e_layers_to_copy, d_layers_to_copy
|
||||
|
||||
@@ -108,7 +108,10 @@ def run_generate(verbose=True):
|
||||
nargs="?",
|
||||
type=str,
|
||||
const=datetime_now(),
|
||||
help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
|
||||
help=(
|
||||
"use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g."
|
||||
" lang=en-ru. If no value is passed, the current datetime string will be used."
|
||||
),
|
||||
)
|
||||
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
|
||||
args, rest = parser.parse_known_args()
|
||||
|
||||
Reference in New Issue
Block a user