add to_json_file method to configuration classes
This commit is contained in:
@@ -16,6 +16,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import random
|
||||
@@ -176,6 +177,14 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
self.assertEqual(obj["vocab_size"], 99)
|
||||
self.assertEqual(obj["n_embd"], 37)
|
||||
|
||||
def test_config_to_json_file(self):
|
||||
config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
|
||||
json_file_path = "/tmp/config.json"
|
||||
config_first.to_json_file(json_file_path)
|
||||
config_second = GPT2Config.from_json_file(json_file_path)
|
||||
os.remove(json_file_path)
|
||||
self.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||
|
||||
def run_tester(self, tester):
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
output_result = tester.create_gpt2_model(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user