better naming

This commit is contained in:
Patrick von Platen
2020-03-05 15:48:00 +01:00
parent ff648221bd
commit 7cba11fb9b
3 changed files with 6 additions and 7 deletions

View File

@@ -411,7 +411,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else:
raise EnvironmentError(
"Error no file named {} found in directory {} or `from_tf` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index",],
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
pretrained_model_name_or_path,
)
)
@@ -816,7 +816,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
# TODO (PVP): check eos_token_id
# TODO (PVP): probably not the best way to check whether model is encoder decoder
is_encoder_decoder = (
hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
@@ -829,7 +828,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_inputs = input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
# eos_token_id, # Why eos_token_id here? bos_token_id makes more sense no?
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
bos_token_id,
dtype=torch.long,
device=next(self.parameters()).device,