add to_json_file method to configuration classes
This commit is contained in:
@@ -220,6 +220,11 @@ class BertConfig(object):
|
|||||||
"""Serializes this instance to a JSON string."""
|
"""Serializes this instance to a JSON string."""
|
||||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
def to_json_file(self, json_file_path):
|
||||||
|
""" Save this instance to a json file."""
|
||||||
|
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||||
|
writer.write(self.to_json_string())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
|
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@@ -180,6 +180,11 @@ class GPT2Config(object):
|
|||||||
"""Serializes this instance to a JSON string."""
|
"""Serializes this instance to a JSON string."""
|
||||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
def to_json_file(self, json_file_path):
|
||||||
|
""" Save this instance to a json file."""
|
||||||
|
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||||
|
writer.write(self.to_json_string())
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(nn.Module):
|
class Conv1D(nn.Module):
|
||||||
def __init__(self, nf, nx):
|
def __init__(self, nf, nx):
|
||||||
|
|||||||
@@ -225,6 +225,11 @@ class OpenAIGPTConfig(object):
|
|||||||
"""Serializes this instance to a JSON string."""
|
"""Serializes this instance to a JSON string."""
|
||||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
def to_json_file(self, json_file_path):
|
||||||
|
""" Save this instance to a json file."""
|
||||||
|
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||||
|
writer.write(self.to_json_string())
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(nn.Module):
|
class Conv1D(nn.Module):
|
||||||
def __init__(self, nf, rf, nx):
|
def __init__(self, nf, rf, nx):
|
||||||
|
|||||||
@@ -316,6 +316,11 @@ class TransfoXLConfig(object):
|
|||||||
"""Serializes this instance to a JSON string."""
|
"""Serializes this instance to a JSON string."""
|
||||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
def to_json_file(self, json_file_path):
|
||||||
|
""" Save this instance to a json file."""
|
||||||
|
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||||
|
writer.write(self.to_json_string())
|
||||||
|
|
||||||
|
|
||||||
class PositionalEmbedding(nn.Module):
|
class PositionalEmbedding(nn.Module):
|
||||||
def __init__(self, demb):
|
def __init__(self, demb):
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
@@ -176,6 +177,14 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
self.assertEqual(obj["vocab_size"], 99)
|
self.assertEqual(obj["vocab_size"], 99)
|
||||||
self.assertEqual(obj["n_embd"], 37)
|
self.assertEqual(obj["n_embd"], 37)
|
||||||
|
|
||||||
|
def test_config_to_json_file(self):
|
||||||
|
config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
|
||||||
|
json_file_path = "/tmp/config.json"
|
||||||
|
config_first.to_json_file(json_file_path)
|
||||||
|
config_second = GPT2Config.from_json_file(json_file_path)
|
||||||
|
os.remove(json_file_path)
|
||||||
|
self.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||||
|
|
||||||
def run_tester(self, tester):
|
def run_tester(self, tester):
|
||||||
config_and_inputs = tester.prepare_config_and_inputs()
|
config_and_inputs = tester.prepare_config_and_inputs()
|
||||||
output_result = tester.create_gpt2_model(*config_and_inputs)
|
output_result = tester.create_gpt2_model(*config_and_inputs)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
@@ -188,6 +189,14 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
self.assertEqual(obj["vocab_size"], 99)
|
self.assertEqual(obj["vocab_size"], 99)
|
||||||
self.assertEqual(obj["n_embd"], 37)
|
self.assertEqual(obj["n_embd"], 37)
|
||||||
|
|
||||||
|
def test_config_to_json_file(self):
|
||||||
|
config_first = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37)
|
||||||
|
json_file_path = "/tmp/config.json"
|
||||||
|
config_first.to_json_file(json_file_path)
|
||||||
|
config_second = OpenAIGPTConfig.from_json_file(json_file_path)
|
||||||
|
os.remove(json_file_path)
|
||||||
|
self.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||||
|
|
||||||
def run_tester(self, tester):
|
def run_tester(self, tester):
|
||||||
config_and_inputs = tester.prepare_config_and_inputs()
|
config_and_inputs = tester.prepare_config_and_inputs()
|
||||||
output_result = tester.create_openai_model(*config_and_inputs)
|
output_result = tester.create_openai_model(*config_and_inputs)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
@@ -251,6 +252,14 @@ class BertModelTest(unittest.TestCase):
|
|||||||
self.assertEqual(obj["vocab_size"], 99)
|
self.assertEqual(obj["vocab_size"], 99)
|
||||||
self.assertEqual(obj["hidden_size"], 37)
|
self.assertEqual(obj["hidden_size"], 37)
|
||||||
|
|
||||||
|
def test_config_to_json_file(self):
|
||||||
|
config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
|
||||||
|
json_file_path = "/tmp/config.json"
|
||||||
|
config_first.to_json_file(json_file_path)
|
||||||
|
config_second = BertConfig.from_json_file(json_file_path)
|
||||||
|
os.remove(json_file_path)
|
||||||
|
self.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||||
|
|
||||||
def run_tester(self, tester):
|
def run_tester(self, tester):
|
||||||
config_and_inputs = tester.prepare_config_and_inputs()
|
config_and_inputs = tester.prepare_config_and_inputs()
|
||||||
output_result = tester.create_bert_model(*config_and_inputs)
|
output_result = tester.create_bert_model(*config_and_inputs)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
@@ -186,6 +187,14 @@ class TransfoXLModelTest(unittest.TestCase):
|
|||||||
self.assertEqual(obj["n_token"], 96)
|
self.assertEqual(obj["n_token"], 96)
|
||||||
self.assertEqual(obj["d_embed"], 37)
|
self.assertEqual(obj["d_embed"], 37)
|
||||||
|
|
||||||
|
def test_config_to_json_file(self):
|
||||||
|
config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
|
||||||
|
json_file_path = "/tmp/config.json"
|
||||||
|
config_first.to_json_file(json_file_path)
|
||||||
|
config_second = TransfoXLConfig.from_json_file(json_file_path)
|
||||||
|
os.remove(json_file_path)
|
||||||
|
self.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||||
|
|
||||||
def run_tester(self, tester):
|
def run_tester(self, tester):
|
||||||
config_and_inputs = tester.prepare_config_and_inputs()
|
config_and_inputs = tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user