From ffa17fe322ecf57a10be80d4c328828ebd7c81f0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 25 Mar 2020 21:32:04 +0100 Subject: [PATCH] Extend config with task specific configs. (#3433) * add new default configs * change prefix default to None --- src/transformers/configuration_utils.py | 13 ++++++++++--- src/transformers/modeling_tf_utils.py | 13 ++++++++----- src/transformers/modeling_utils.py | 14 +++++++++----- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 83185bc10f..d8341a17e8 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -78,9 +78,6 @@ class PretrainedConfig(object): self.top_k = kwargs.pop("top_k", 50) self.top_p = kwargs.pop("top_p", 1.0) self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) - self.bos_token_id = kwargs.pop("bos_token_id", None) - self.pad_token_id = kwargs.pop("pad_token_id", None) - self.eos_token_id = kwargs.pop("eos_token_id", None) self.length_penalty = kwargs.pop("length_penalty", 1.0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.num_return_sequences = kwargs.pop("num_return_sequences", 1) @@ -94,6 +91,16 @@ class PretrainedConfig(object): self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys()))) self.label2id = dict((key, int(value)) for key, value in self.label2id.items()) + # Tokenizer arguments TODO: eventually tokenizer and models should share the same config + self.prefix = kwargs.pop("prefix", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) + + # task specific arguments + self.task_specific_params = kwargs.pop("task_specific_params", None) + # Additional attributes without default values for key, value in kwargs.items(): try: diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 07d722f945..6a441c9fe1 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -610,7 +610,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) - decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) if input_ids is not None: batch_size = shape_list(input_ids)[0] # overriden by the input batch_size @@ -635,9 +637,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): assert (eos_token_id is None) or ( isinstance(eos_token_id, int) and (eos_token_id >= 0) ), "`eos_token_id` should be a positive integer." - assert ( - decoder_start_token_id is not None or self.config.is_encoder_decoder is False - ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model" assert length_penalty > 0, "`length_penalty` should be strictely positive." assert ( isinstance(num_return_sequences, int) and num_return_sequences > 0 @@ -708,8 +707,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) if self.config.is_encoder_decoder: + if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id - assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id" + assert ( + decoder_start_token_id is not None + ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 808c160094..61f00c6eb6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -809,7 +809,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) - decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) if input_ids is not None: batch_size = input_ids.shape[0] # overriden by the input batch_size @@ -831,9 +833,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): assert pad_token_id is None or ( isinstance(pad_token_id, int) and (pad_token_id >= 0) ), "`pad_token_id` should be a positive integer." - assert ( - decoder_start_token_id is not None or self.config.is_encoder_decoder is False - ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model" assert (eos_token_id is None) or ( isinstance(eos_token_id, int) and (eos_token_id >= 0) ), "`eos_token_id` should be a positive integer." @@ -912,7 +911,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) if self.config.is_encoder_decoder: - assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id" + if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id + + assert ( + decoder_start_token_id is not None + ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)