[BART] generation_mode as a kwarg not a class attribute (#3278)
This commit is contained in:
@@ -846,7 +846,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
attention_mask = attention_mask.contiguous().view(
|
||||
effective_batch_size * num_beams, input_ids_len
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
|
||||
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
|
||||
@@ -859,9 +858,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
)
|
||||
cur_len = 1
|
||||
|
||||
# put model in generation mode if it has one
|
||||
if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "generation_mode"):
|
||||
self.model.decoder.generation_mode = True
|
||||
else:
|
||||
encoder_inputs = None
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
Reference in New Issue
Block a user