update bertabs
This commit is contained in:
@@ -33,6 +33,8 @@ class BertAbsConfig(PretrainedConfig):
|
|||||||
r""" Class to store the configuration of the BertAbs model.
|
r""" Class to store the configuration of the BertAbs model.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
vocab_size: int
|
||||||
|
Number of tokens in the vocabulary.
|
||||||
max_pos: int
|
max_pos: int
|
||||||
The maximum sequence length that this model will be used with.
|
The maximum sequence length that this model will be used with.
|
||||||
enc_layer: int
|
enc_layer: int
|
||||||
@@ -81,13 +83,6 @@ class BertAbsConfig(PretrainedConfig):
|
|||||||
):
|
):
|
||||||
super(BertAbsConfig, self).__init__(**kwargs)
|
super(BertAbsConfig, self).__init__(**kwargs)
|
||||||
|
|
||||||
if self._input_is_path_to_json(vocab_size):
|
|
||||||
path_to_json = vocab_size
|
|
||||||
with open(path_to_json, "r", encoding="utf-8") as reader:
|
|
||||||
json_config = json.loads(reader.read())
|
|
||||||
for key, value in json_config.items():
|
|
||||||
self.__dict__[key] = value
|
|
||||||
elif isinstance(vocab_size, int):
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_pos = max_pos
|
self.max_pos = max_pos
|
||||||
|
|
||||||
@@ -102,18 +97,3 @@ class BertAbsConfig(PretrainedConfig):
|
|||||||
self.dec_heads = dec_heads
|
self.dec_heads = dec_heads
|
||||||
self.dec_ff_size = dec_ff_size
|
self.dec_ff_size = dec_ff_size
|
||||||
self.dec_dropout = dec_dropout
|
self.dec_dropout = dec_dropout
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"First argument must be either a vocabulary size (int)"
|
|
||||||
"or the path to a pretrained model config file (str)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _input_is_path_to_json(self, first_argument):
|
|
||||||
""" Checks whether the first argument passed to config
|
|
||||||
is the path to a JSON file that contains the config.
|
|
||||||
"""
|
|
||||||
is_python_2 = sys.version_info[0] == 2
|
|
||||||
if is_python_2:
|
|
||||||
return isinstance(first_argument, unicode)
|
|
||||||
else:
|
|
||||||
return isinstance(first_argument, str)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user