cleaning up configuration classes

This commit is contained in:
thomwolf
2019-12-13 14:33:24 +01:00
parent 7296f1010b
commit 47f0e3cfb7
43 changed files with 224 additions and 329 deletions

View File

@@ -39,7 +39,7 @@ class XxxConfig(PretrainedConfig):
Arguments:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XxxModel`.
vocab_size: Vocabulary size of `inputs_ids` in `XxxModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
@@ -64,7 +64,7 @@ class XxxConfig(PretrainedConfig):
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=50257,
vocab_size=50257,
n_positions=1024,
n_ctx=1024,
n_embd=768,
@@ -84,7 +84,7 @@ class XxxConfig(PretrainedConfig):
summary_first_dropout=0.1,
**kwargs):
super(XxxConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, six.string_types) else -1
self.vocab_size = vocab_size if isinstance(vocab_size, six.string_types) else -1
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
@@ -102,12 +102,12 @@ class XxxConfig(PretrainedConfig):
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
if isinstance(vocab_size_or_config_json_file, six.string_types):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
if isinstance(vocab_size, six.string_types):
with open(vocab_size, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif not isinstance(vocab_size_or_config_json_file, int):
elif not isinstance(vocab_size, int):
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"

View File

@@ -111,7 +111,7 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = XxxConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,

View File

@@ -109,7 +109,7 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = XxxConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,