diff --git a/examples/summarization/configuration_bertabs.py b/examples/summarization/configuration_bertabs.py index 054763ea93..b862d58d2b 100644 --- a/examples/summarization/configuration_bertabs.py +++ b/examples/summarization/configuration_bertabs.py @@ -33,6 +33,8 @@ class BertAbsConfig(PretrainedConfig): r""" Class to store the configuration of the BertAbs model. Arguments: + vocab_size: int + Number of tokens in the vocabulary. max_pos: int The maximum sequence length that this model will be used with. enc_layer: int @@ -81,39 +83,17 @@ class BertAbsConfig(PretrainedConfig): ): super(BertAbsConfig, self).__init__(**kwargs) - if self._input_is_path_to_json(vocab_size): - path_to_json = vocab_size - with open(path_to_json, "r", encoding="utf-8") as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size, int): - self.vocab_size = vocab_size - self.max_pos = max_pos + self.vocab_size = vocab_size + self.max_pos = max_pos - self.enc_layers = enc_layers - self.enc_hidden_size = enc_hidden_size - self.enc_heads = enc_heads - self.enc_ff_size = enc_ff_size - self.enc_dropout = enc_dropout + self.enc_layers = enc_layers + self.enc_hidden_size = enc_hidden_size + self.enc_heads = enc_heads + self.enc_ff_size = enc_ff_size + self.enc_dropout = enc_dropout - self.dec_layers = dec_layers - self.dec_hidden_size = dec_hidden_size - self.dec_heads = dec_heads - self.dec_ff_size = dec_ff_size - self.dec_dropout = dec_dropout - else: - raise ValueError( - "First argument must be either a vocabulary size (int)" - "or the path to a pretrained model config file (str)" - ) - - def _input_is_path_to_json(self, first_argument): - """ Checks whether the first argument passed to config - is the path to a JSON file that contains the config. - """ - is_python_2 = sys.version_info[0] == 2 - if is_python_2: - return isinstance(first_argument, unicode) - else: - return isinstance(first_argument, str) + self.dec_layers = dec_layers + self.dec_hidden_size = dec_hidden_size + self.dec_heads = dec_heads + self.dec_ff_size = dec_ff_size + self.dec_dropout = dec_dropout