Extend config with task specific configs. (#3433)
* add new default configs * change prefix default to None
This commit is contained in:
committed by
GitHub
parent
83272a3853
commit
ffa17fe322
@@ -610,7 +610,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
num_return_sequences = (
|
||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||
)
|
||||
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
|
||||
decoder_start_token_id = (
|
||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
|
||||
@@ -635,9 +637,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
assert (eos_token_id is None) or (
|
||||
isinstance(eos_token_id, int) and (eos_token_id >= 0)
|
||||
), "`eos_token_id` should be a positive integer."
|
||||
assert (
|
||||
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
||||
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
||||
assert length_penalty > 0, "`length_penalty` should be strictely positive."
|
||||
assert (
|
||||
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
||||
@@ -708,8 +707,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
if decoder_start_token_id is None:
|
||||
decoder_start_token_id = bos_token_id
|
||||
|
||||
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
|
||||
assert (
|
||||
decoder_start_token_id is not None
|
||||
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
||||
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
||||
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user