Black preview (#17217)

* Black preview

* Fixup too!

* Fix check copies

* Use the same version as the CI

* Bump black
This commit is contained in:
Sylvain Gugger
2022-05-12 16:25:55 -04:00
committed by GitHub
parent 9bd67ac7bb
commit afe5d42d8d
578 changed files with 8274 additions and 3296 deletions

View File

@@ -90,31 +90,39 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
"help": (
"The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
val_max_target_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. "
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
"help": (
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. "
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
)
},
)
test_max_target_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
"help": (
"The maximum total sequence length for test target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})

View File

@@ -22,15 +22,30 @@ from utils import calculate_rouge
PRED = [
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe depression" German airline confirms it knew of Andreas Lubitz\'s depression years before he took control.',
"The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the body.",
"Amnesty International releases its annual report on the death penalty. The report catalogs the use of state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital punishment.",
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the'
' final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe'
" depression\" German airline confirms it knew of Andreas Lubitz's depression years before he took control.",
"The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal"
" accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's"
" founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the"
" body.",
"Amnesty International releases its annual report on the death penalty. The report catalogs the use of"
" state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the"
" world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital"
" punishment.",
]
TGT = [
'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says .',
"Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June . Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .",
"Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to death . Organization claims that governments around the world are using the threat of terrorism to advance executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death sentences up by 28% .",
'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .'
' Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz'
" had informed his Lufthansa training school of an episode of severe depression, airline says .",
"Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June ."
" Israel and the United States opposed the move, which could open the door to war crimes investigations against"
" Israelis .",
"Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to"
" death . Organization claims that governments around the world are using the threat of terrorism to advance"
" executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death"
" sentences up by 28% .",
]
@@ -65,7 +80,8 @@ def test_single_sent_scores_dont_depend_on_newline_sep():
]
tgt = [
"Margot Frank, died in 1945, a month earlier than previously thought.",
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525.',
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of'
" the final seconds on board Flight 9525.",
]
assert calculate_rouge(pred, tgt, newline_sep=True) == calculate_rouge(pred, tgt, newline_sep=False)

View File

@@ -121,7 +121,10 @@ def run_generate(verbose=True):
nargs="?",
type=str,
const=datetime_now(),
help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
help=(
"use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g."
" lang=en-ru. If no value is passed, the current datetime string will be used."
),
)
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()

View File

@@ -35,7 +35,7 @@ def parse_search_arg(search):
groups = search.split()
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
entry_names = list(entries.keys())
sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()]
sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
matrix = [list(x) for x in itertools.product(*sets)]
return matrix, entry_names
@@ -66,7 +66,10 @@ def run_search():
prog = sys.argv[0]
parser = argparse.ArgumentParser(
usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list."
usage=(
"\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore"
" refer to `run_eval.py -h` for the complete list."
)
)
parser.add_argument(
"--search",
@@ -83,7 +86,10 @@ def run_search():
nargs="?",
type=str,
const=datetime_now(),
help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
help=(
"add custom notes to be printed before the results table. If no value is passed, the current datetime"
" string will be used."
),
)
args, args_main = parser.parse_known_args()
# we share some of the args

View File

@@ -57,9 +57,10 @@ class Seq2SeqTrainer(Trainer):
super().__init__(*args, **kwargs)
if config is None:
assert isinstance(
self.model, PreTrainedModel
), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
assert isinstance(self.model, PreTrainedModel), (
"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is"
f" {self.model.__class__}"
)
self.config = self.model.config
else:
self.config = config
@@ -68,13 +69,15 @@ class Seq2SeqTrainer(Trainer):
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
assert (
self.config.pad_token_id is not None
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
assert self.config.pad_token_id is not None, (
"Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss"
" calculation or doing label smoothing."
)
if self.config.pad_token_id is None and self.config.eos_token_id is not None:
logger.warning(
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for"
" padding.."
)
if self.args.label_smoothing == 0:
@@ -248,7 +251,8 @@ class Seq2SeqTrainer(Trainer):
if pad_token_id is None:
raise ValueError(
f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be"
f" padded to `max_length`={max_length}"
)
padded_tensor = pad_token_id * torch.ones(

View File

@@ -39,9 +39,7 @@ def parse_args():
"""
parser = ArgumentParser(
description=(
"PyTorch TPU distributed training launch "
"helper utility that will spawn up "
"multiple distributed processes"
"PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes"
)
)