Black preview (#17217)

* Black preview

* Fixup too!

* Fix check copies

* Use the same version as the CI

* Bump black
This commit is contained in:
Sylvain Gugger
2022-05-12 16:25:55 -04:00
committed by GitHub
parent 9bd67ac7bb
commit afe5d42d8d
578 changed files with 8274 additions and 3296 deletions

View File

@@ -392,13 +392,14 @@ class BeamSearchScorerTS(torch.nn.Module):
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
" one should make use of `greedy_search` instead."
)
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
raise ValueError(
f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)
def hypo_len(self, hypo_idx: int):
@@ -508,7 +509,8 @@ class BeamSearchScorerTS(torch.nn.Module):
if beam_idx < self.group_size:
raise ValueError(
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)
# Check if we are done so that we can save a pad step if all(done)

View File

@@ -53,14 +53,16 @@ def parse_args():
"--max_length",
type=int,
default=5,
help=("The maximum total input sequence length after tokenization."),
help="The maximum total input sequence length after tokenization.",
)
parser.add_argument(
"--num_beams",
type=int,
default=None,
help="Number of beams to use for evaluation. This argument will be "
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
help=(
"Number of beams to use for evaluation. This argument will be "
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
),
)
parser.add_argument(
"--model_name_or_path",