update CTRL configuration
This commit is contained in:
@@ -95,33 +95,32 @@ class CTRLConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
super(CTRLConfig, self).__init__(**kwargs)
|
super(CTRLConfig, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
|
||||||
|
self.n_ctx = n_ctx
|
||||||
|
self.n_positions = n_positions
|
||||||
|
self.n_embd = n_embd
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.n_head = n_head
|
||||||
|
self.dff = dff
|
||||||
|
self.resid_pdrop = resid_pdrop
|
||||||
|
self.embd_pdrop = embd_pdrop
|
||||||
|
self.attn_pdrop = attn_pdrop
|
||||||
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.summary_type = summary_type
|
||||||
|
self.summary_use_proj = summary_use_proj
|
||||||
|
self.summary_activation = summary_activation
|
||||||
|
self.summary_first_dropout = summary_first_dropout
|
||||||
|
self.summary_proj_to_labels = summary_proj_to_labels
|
||||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||||
and isinstance(vocab_size_or_config_json_file, unicode)):
|
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||||
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||||
json_config = json.loads(reader.read())
|
json_config = json.loads(reader.read())
|
||||||
for key, value in json_config.items():
|
for key, value in json_config.items():
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
elif isinstance(vocab_size_or_config_json_file, int):
|
elif not isinstance(vocab_size_or_config_json_file, int):
|
||||||
self.vocab_size = vocab_size_or_config_json_file
|
|
||||||
self.n_ctx = n_ctx
|
|
||||||
self.n_positions = n_positions
|
|
||||||
self.n_embd = n_embd
|
|
||||||
self.n_layer = n_layer
|
|
||||||
self.n_head = n_head
|
|
||||||
self.dff = dff
|
|
||||||
self.resid_pdrop = resid_pdrop
|
|
||||||
self.embd_pdrop = embd_pdrop
|
|
||||||
self.attn_pdrop = attn_pdrop
|
|
||||||
self.layer_norm_epsilon = layer_norm_epsilon
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
|
|
||||||
self.num_labels = num_labels
|
|
||||||
self.summary_type = summary_type
|
|
||||||
self.summary_use_proj = summary_use_proj
|
|
||||||
self.summary_activation = summary_activation
|
|
||||||
self.summary_first_dropout = summary_first_dropout
|
|
||||||
self.summary_proj_to_labels = summary_proj_to_labels
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"First argument must be either a vocabulary size (int)"
|
"First argument must be either a vocabulary size (int)"
|
||||||
"or the path to a pretrained model config file (str)"
|
"or the path to a pretrained model config file (str)"
|
||||||
|
|||||||
Reference in New Issue
Block a user