[Examples] create model with custom config on the fly (#11798)
* create custom model on the flight * better wording * add update_from_string * cleanup * cleanup * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * more bool options * style * fix logger * add test * add the doc * assert on conflict of options Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -161,3 +161,21 @@ concatenates all texts and then splits them in blocks of the same length).
|
||||
|
||||
**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
|
||||
sure all your batches have the same length.
|
||||
|
||||
|
||||
## Creating a model on the fly
|
||||
|
||||
When training a model from scratch, configuration values may be overridden with the help of `--config_overrides`:
|
||||
|
||||
|
||||
```bash
|
||||
python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides="n_embd=1024,n_head=16,n_layer=48,n_positions=102" \
|
||||
[...]
|
||||
```
|
||||
|
||||
At the moment this is only available in `run_clm.py` but eventually should be copied to all other LM examples.
|
||||
|
||||
This feature can also be used to activate gradient checkpointing by passing:
|
||||
```
|
||||
--config_overrides "gradient_checkpointing=true,use_cache=False"
|
||||
```
|
||||
|
||||
@@ -75,6 +75,13 @@ class ModelArguments:
|
||||
default=None,
|
||||
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
||||
)
|
||||
config_overrides: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
|
||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||
},
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
@@ -101,6 +108,12 @@ class ModelArguments:
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
|
||||
raise ValueError(
|
||||
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
@@ -279,6 +292,9 @@ def main():
|
||||
else:
|
||||
config = CONFIG_MAPPING[model_args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
if model_args.config_overrides is not None:
|
||||
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||
config.update_from_string(model_args.config_overrides)
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
@@ -306,8 +322,9 @@ def main():
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
|
||||
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
|
||||
@@ -667,7 +667,45 @@ class PretrainedConfig(PushToHubMixin):
|
||||
Updates attributes of this class with attributes from ``config_dict``.
|
||||
|
||||
Args:
|
||||
config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that shall be updated for this class.
|
||||
config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
|
||||
"""
|
||||
for key, value in config_dict.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def update_from_string(self, update_str: str):
|
||||
"""
|
||||
Updates attributes of this class with attributes from ``update_str``.
|
||||
|
||||
The expected format is ints, floats and strings as is, and for booleans use ``true`` or ``false``. For example:
|
||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||
|
||||
The keys to change have to already exist in the config object.
|
||||
|
||||
Args:
|
||||
update_str (:obj:`str`): String with attributes that should be updated for this class.
|
||||
|
||||
"""
|
||||
|
||||
d = dict(x.split("=") for x in update_str.split(","))
|
||||
for k, v in d.items():
|
||||
if not hasattr(self, k):
|
||||
raise ValueError(f"key {k} isn't in the original config dict")
|
||||
|
||||
old_v = getattr(self, k)
|
||||
if isinstance(old_v, bool):
|
||||
if v.lower() in ["true", "1", "y", "yes"]:
|
||||
v = True
|
||||
elif v.lower() in ["false", "0", "n", "no"]:
|
||||
v = False
|
||||
else:
|
||||
raise ValueError(f"can't derive true or false from {v} (key {k})")
|
||||
elif isinstance(old_v, int):
|
||||
v = int(v)
|
||||
elif isinstance(old_v, float):
|
||||
v = float(v)
|
||||
elif not isinstance(old_v, str):
|
||||
raise ValueError(
|
||||
f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
|
||||
)
|
||||
|
||||
setattr(self, k, v)
|
||||
|
||||
@@ -21,7 +21,7 @@ import unittest
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import BertConfig
|
||||
from transformers import BertConfig, GPT2Config
|
||||
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
|
||||
|
||||
|
||||
@@ -138,3 +138,21 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
for k, v in config.__dict__.items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
|
||||
class ConfigTestUtils(unittest.TestCase):
|
||||
def test_config_from_string(self):
|
||||
c = GPT2Config()
|
||||
|
||||
# attempt to modify each of int/float/bool/str config records and verify they were updated
|
||||
n_embd = c.n_embd + 1 # int
|
||||
resid_pdrop = c.resid_pdrop + 1.0 # float
|
||||
scale_attn_weights = not c.scale_attn_weights # bool
|
||||
summary_type = c.summary_type + "foo" # str
|
||||
c.update_from_string(
|
||||
f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
|
||||
)
|
||||
self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
|
||||
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
|
||||
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
|
||||
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
|
||||
|
||||
Reference in New Issue
Block a user