Black preview (#17217)
* Black preview * Fixup too! * Fix check copies * Use the same version as the CI * Bump black
This commit is contained in:
@@ -392,13 +392,14 @@ class BeamSearchScorerTS(torch.nn.Module):
|
||||
|
||||
if not isinstance(num_beams, int) or num_beams <= 1:
|
||||
raise ValueError(
|
||||
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
|
||||
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
|
||||
" one should make use of `greedy_search` instead."
|
||||
)
|
||||
|
||||
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
||||
raise ValueError(
|
||||
f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
|
||||
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
|
||||
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||
)
|
||||
|
||||
def hypo_len(self, hypo_idx: int):
|
||||
@@ -508,7 +509,8 @@ class BeamSearchScorerTS(torch.nn.Module):
|
||||
|
||||
if beam_idx < self.group_size:
|
||||
raise ValueError(
|
||||
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
||||
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
|
||||
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
||||
)
|
||||
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
|
||||
@@ -53,14 +53,16 @@ def parse_args():
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=5,
|
||||
help=("The maximum total input sequence length after tokenization."),
|
||||
help="The maximum total input sequence length after tokenization.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of beams to use for evaluation. This argument will be "
|
||||
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
|
||||
help=(
|
||||
"Number of beams to use for evaluation. This argument will be "
|
||||
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
|
||||
Reference in New Issue
Block a user