From 9761aa48452712711d6b2ff04902b8a37ff294b3 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Apr 2019 14:12:08 +0200 Subject: [PATCH] add to_json_file method to configuration classes --- pytorch_pretrained_bert/modeling.py | 5 +++++ pytorch_pretrained_bert/modeling_gpt2.py | 5 +++++ pytorch_pretrained_bert/modeling_openai.py | 5 +++++ pytorch_pretrained_bert/modeling_transfo_xl.py | 5 +++++ tests/modeling_gpt2_test.py | 9 +++++++++ tests/modeling_openai_test.py | 9 +++++++++ tests/modeling_test.py | 9 +++++++++ tests/modeling_transfo_xl_test.py | 9 +++++++++ 8 files changed, 56 insertions(+) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 2736e34d7f..6a71cbeea6 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -220,6 +220,11 @@ class BertConfig(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) + try: from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm except ImportError: diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index 7b00ce7730..fce564e9ea 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -180,6 +180,11 @@ class GPT2Config(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) + class Conv1D(nn.Module): def __init__(self, nf, nx): diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index b6252d097f..33bb4472a5 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -225,6 +225,11 @@ class OpenAIGPTConfig(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) + class Conv1D(nn.Module): def __init__(self, nf, rf, nx): diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index ac895a03a7..0ba986f5b4 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -316,6 +316,11 @@ class TransfoXLConfig(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) + class PositionalEmbedding(nn.Module): def __init__(self, demb): diff --git a/tests/modeling_gpt2_test.py b/tests/modeling_gpt2_test.py index 12a539c44b..d542422060 100644 --- a/tests/modeling_gpt2_test.py +++ b/tests/modeling_gpt2_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import unittest import json import random @@ -176,6 +177,14 @@ class GPT2ModelTest(unittest.TestCase): self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["n_embd"], 37) + def test_config_to_json_file(self): + config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37) + json_file_path = "/tmp/config.json" + config_first.to_json_file(json_file_path) + config_second = GPT2Config.from_json_file(json_file_path) + os.remove(json_file_path) + self.assertEqual(config_second.to_dict(), config_first.to_dict()) + def run_tester(self, tester): config_and_inputs = tester.prepare_config_and_inputs() output_result = tester.create_gpt2_model(*config_and_inputs) diff --git a/tests/modeling_openai_test.py b/tests/modeling_openai_test.py index 1cc8b7d5dc..db03bf792e 100644 --- a/tests/modeling_openai_test.py +++ b/tests/modeling_openai_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import unittest import json import random @@ -188,6 +189,14 @@ class OpenAIGPTModelTest(unittest.TestCase): self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["n_embd"], 37) + def test_config_to_json_file(self): + config_first = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37) + json_file_path = "/tmp/config.json" + config_first.to_json_file(json_file_path) + config_second = OpenAIGPTConfig.from_json_file(json_file_path) + os.remove(json_file_path) + self.assertEqual(config_second.to_dict(), config_first.to_dict()) + def run_tester(self, tester): config_and_inputs = tester.prepare_config_and_inputs() output_result = tester.create_openai_model(*config_and_inputs) diff --git a/tests/modeling_test.py b/tests/modeling_test.py index c7a031cfb0..02d7a13fda 100644 --- a/tests/modeling_test.py +++ b/tests/modeling_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import unittest import json import random @@ -251,6 +252,14 @@ class BertModelTest(unittest.TestCase): self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["hidden_size"], 37) + def test_config_to_json_file(self): + config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37) + json_file_path = "/tmp/config.json" + config_first.to_json_file(json_file_path) + config_second = BertConfig.from_json_file(json_file_path) + os.remove(json_file_path) + self.assertEqual(config_second.to_dict(), config_first.to_dict()) + def run_tester(self, tester): config_and_inputs = tester.prepare_config_and_inputs() output_result = tester.create_bert_model(*config_and_inputs) diff --git a/tests/modeling_transfo_xl_test.py b/tests/modeling_transfo_xl_test.py index 291d5d9d2a..a59d90b205 100644 --- a/tests/modeling_transfo_xl_test.py +++ b/tests/modeling_transfo_xl_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import unittest import json import random @@ -186,6 +187,14 @@ class TransfoXLModelTest(unittest.TestCase): self.assertEqual(obj["n_token"], 96) self.assertEqual(obj["d_embed"], 37) + def test_config_to_json_file(self): + config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37) + json_file_path = "/tmp/config.json" + config_first.to_json_file(json_file_path) + config_second = TransfoXLConfig.from_json_file(json_file_path) + os.remove(json_file_path) + self.assertEqual(config_second.to_dict(), config_first.to_dict()) + def run_tester(self, tester): config_and_inputs = tester.prepare_config_and_inputs()