From 1b6530104d5b0a6c14c7db07252a74c03ac8e163 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 25 May 2021 10:40:49 -0700 Subject: [PATCH] [Examples] create model with custom config on the fly (#11798) * create custom model on the flight * better wording * add update_from_string * cleanup * cleanup * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * more bool options * style * fix logger * add test * add the doc * assert on conflict of options Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- examples/pytorch/language-modeling/README.md | 18 +++++++++ examples/pytorch/language-modeling/run_clm.py | 19 ++++++++- src/transformers/configuration_utils.py | 40 ++++++++++++++++++- tests/test_configuration_common.py | 20 +++++++++- 4 files changed, 94 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/language-modeling/README.md b/examples/pytorch/language-modeling/README.md index a479fd6716..7340986c0e 100644 --- a/examples/pytorch/language-modeling/README.md +++ b/examples/pytorch/language-modeling/README.md @@ -161,3 +161,21 @@ concatenates all texts and then splits them in blocks of the same length). **Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make sure all your batches have the same length. + + +## Creating a model on the fly + +When training a model from scratch, configuration values may be overridden with the help of `--config_overrides`: + + +```bash +python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides="n_embd=1024,n_head=16,n_layer=48,n_positions=102" \ +[...] +``` + +At the moment this is only available in `run_clm.py` but eventually should be copied to all other LM examples. + +This feature can also be used to activate gradient checkpointing by passing: +``` +--config_overrides "gradient_checkpointing=true,use_cache=False" +``` diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index c3bf39ffce..0c95e7c423 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -75,6 +75,13 @@ class ModelArguments: default=None, metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + }, + ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -101,6 +108,12 @@ class ModelArguments: }, ) + def __post_init__(self): + if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): + raise ValueError( + "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + ) + @dataclass class DataTrainingArguments: @@ -279,6 +292,9 @@ def main(): else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) tokenizer_kwargs = { "cache_dir": model_args.cache_dir, @@ -306,8 +322,9 @@ def main(): use_auth_token=True if model_args.use_auth_token else None, ) else: - logger.info("Training new model from scratch") model = AutoModelForCausalLM.from_config(config) + n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) + logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") model.resize_token_embeddings(len(tokenizer)) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 6553d3f42e..4f88eb4e2c 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -667,7 +667,45 @@ class PretrainedConfig(PushToHubMixin): Updates attributes of this class with attributes from ``config_dict``. Args: - config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that shall be updated for this class. + config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that should be updated for this class. """ for key, value in config_dict.items(): setattr(self, key, value) + + def update_from_string(self, update_str: str): + """ + Updates attributes of this class with attributes from ``update_str``. + + The expected format is ints, floats and strings as is, and for booleans use ``true`` or ``false``. For example: + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + + The keys to change have to already exist in the config object. + + Args: + update_str (:obj:`str`): String with attributes that should be updated for this class. + + """ + + d = dict(x.split("=") for x in update_str.split(",")) + for k, v in d.items(): + if not hasattr(self, k): + raise ValueError(f"key {k} isn't in the original config dict") + + old_v = getattr(self, k) + if isinstance(old_v, bool): + if v.lower() in ["true", "1", "y", "yes"]: + v = True + elif v.lower() in ["false", "0", "n", "no"]: + v = False + else: + raise ValueError(f"can't derive true or false from {v} (key {k})") + elif isinstance(old_v, int): + v = int(v) + elif isinstance(old_v, float): + v = float(v) + elif not isinstance(old_v, str): + raise ValueError( + f"You can only update int, float, bool or string values in the config, got {v} for key {k}" + ) + + setattr(self, k, v) diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 596c73e989..84c86d1161 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -21,7 +21,7 @@ import unittest from huggingface_hub import HfApi from requests.exceptions import HTTPError -from transformers import BertConfig +from transformers import BertConfig, GPT2Config from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test @@ -138,3 +138,21 @@ class ConfigPushToHubTester(unittest.TestCase): for k, v in config.__dict__.items(): if k != "transformers_version": self.assertEqual(v, getattr(new_config, k)) + + +class ConfigTestUtils(unittest.TestCase): + def test_config_from_string(self): + c = GPT2Config() + + # attempt to modify each of int/float/bool/str config records and verify they were updated + n_embd = c.n_embd + 1 # int + resid_pdrop = c.resid_pdrop + 1.0 # float + scale_attn_weights = not c.scale_attn_weights # bool + summary_type = c.summary_type + "foo" # str + c.update_from_string( + f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}" + ) + self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd") + self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop") + self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights") + self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")