Black preview (#17217)
* Black preview * Fixup too! * Fix check copies * Use the same version as the CI * Bump black
This commit is contained in:
@@ -57,9 +57,10 @@ class Seq2SeqTrainer(Trainer):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if config is None:
|
||||
assert isinstance(
|
||||
self.model, PreTrainedModel
|
||||
), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
|
||||
assert isinstance(self.model, PreTrainedModel), (
|
||||
"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is"
|
||||
f" {self.model.__class__}"
|
||||
)
|
||||
self.config = self.model.config
|
||||
else:
|
||||
self.config = config
|
||||
@@ -68,13 +69,15 @@ class Seq2SeqTrainer(Trainer):
|
||||
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
|
||||
|
||||
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
|
||||
assert (
|
||||
self.config.pad_token_id is not None
|
||||
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
|
||||
assert self.config.pad_token_id is not None, (
|
||||
"Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss"
|
||||
" calculation or doing label smoothing."
|
||||
)
|
||||
|
||||
if self.config.pad_token_id is None and self.config.eos_token_id is not None:
|
||||
logger.warning(
|
||||
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
|
||||
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for"
|
||||
" padding.."
|
||||
)
|
||||
|
||||
if self.args.label_smoothing == 0:
|
||||
@@ -248,7 +251,8 @@ class Seq2SeqTrainer(Trainer):
|
||||
|
||||
if pad_token_id is None:
|
||||
raise ValueError(
|
||||
f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
|
||||
"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be"
|
||||
f" padded to `max_length`={max_length}"
|
||||
)
|
||||
|
||||
padded_tensor = pad_token_id * torch.ones(
|
||||
|
||||
Reference in New Issue
Block a user