[BART] generation_mode as a kwarg not a class attribute (#3278)

This commit is contained in:
Sam Shleifer
2020-03-16 12:47:53 -04:00
committed by GitHub
parent de697935a2
commit 11573231c6
2 changed files with 9 additions and 8 deletions

View File

@@ -846,7 +846,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
attention_mask = attention_mask.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
@@ -859,9 +858,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
)
cur_len = 1
# put model in generation mode if it has one
if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "generation_mode"):
self.model.decoder.generation_mode = True
else:
encoder_inputs = None
cur_len = input_ids.shape[-1]