From bd5363cc8330229b14c668bdc3340e8a4902e608 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 7 Oct 2019 15:37:30 +0200 Subject: [PATCH] update CTRL configuration --- transformers/configuration_ctrl.py | 41 +++++++++++++++--------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/transformers/configuration_ctrl.py b/transformers/configuration_ctrl.py index 4525936885..b305bf3ad3 100644 --- a/transformers/configuration_ctrl.py +++ b/transformers/configuration_ctrl.py @@ -95,33 +95,32 @@ class CTRLConfig(PretrainedConfig): """ super(CTRLConfig, self).__init__(**kwargs) + self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1 + self.n_ctx = n_ctx + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.dff = dff + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + + self.num_labels = num_labels + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 and isinstance(vocab_size_or_config_json_file, unicode)): with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: json_config = json.loads(reader.read()) for key, value in json_config.items(): self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.vocab_size = vocab_size_or_config_json_file - self.n_ctx = n_ctx - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_head = n_head - self.dff = dff - self.resid_pdrop = resid_pdrop - self.embd_pdrop = embd_pdrop - self.attn_pdrop = attn_pdrop - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_range = initializer_range - - self.num_labels = num_labels - self.summary_type = summary_type - self.summary_use_proj = summary_use_proj - self.summary_activation = summary_activation - self.summary_first_dropout = summary_first_dropout - self.summary_proj_to_labels = summary_proj_to_labels - else: + elif not isinstance(vocab_size_or_config_json_file, int): raise ValueError( "First argument must be either a vocabulary size (int)" "or the path to a pretrained model config file (str)"