[Examples] create model with custom config on the fly (#11798)

* create custom model on the flight

* better wording

* add update_from_string

* cleanup

* cleanup

* Update src/transformers/configuration_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* more bool options

* style

* fix logger

* add test

* add the doc

* assert on conflict of options

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-05-25 10:40:49 -07:00
committed by GitHub
parent 6287c929c1
commit 1b6530104d
4 changed files with 94 additions and 3 deletions

View File

@@ -21,7 +21,7 @@ import unittest
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import BertConfig
from transformers import BertConfig, GPT2Config
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
@@ -138,3 +138,21 @@ class ConfigPushToHubTester(unittest.TestCase):
for k, v in config.__dict__.items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
class ConfigTestUtils(unittest.TestCase):
def test_config_from_string(self):
c = GPT2Config()
# attempt to modify each of int/float/bool/str config records and verify they were updated
n_embd = c.n_embd + 1 # int
resid_pdrop = c.resid_pdrop + 1.0 # float
scale_attn_weights = not c.scale_attn_weights # bool
summary_type = c.summary_type + "foo" # str
c.update_from_string(
f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
)
self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")