[Examples] create model with custom config on the fly (#11798)
* create custom model on the flight * better wording * add update_from_string * cleanup * cleanup * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * more bool options * style * fix logger * add test * add the doc * assert on conflict of options Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -161,3 +161,21 @@ concatenates all texts and then splits them in blocks of the same length).
|
|||||||
|
|
||||||
**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
|
**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
|
||||||
sure all your batches have the same length.
|
sure all your batches have the same length.
|
||||||
|
|
||||||
|
|
||||||
|
## Creating a model on the fly
|
||||||
|
|
||||||
|
When training a model from scratch, configuration values may be overridden with the help of `--config_overrides`:
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides="n_embd=1024,n_head=16,n_layer=48,n_positions=102" \
|
||||||
|
[...]
|
||||||
|
```
|
||||||
|
|
||||||
|
At the moment this is only available in `run_clm.py` but eventually should be copied to all other LM examples.
|
||||||
|
|
||||||
|
This feature can also be used to activate gradient checkpointing by passing:
|
||||||
|
```
|
||||||
|
--config_overrides "gradient_checkpointing=true,use_cache=False"
|
||||||
|
```
|
||||||
|
|||||||
@@ -75,6 +75,13 @@ class ModelArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
||||||
)
|
)
|
||||||
|
config_overrides: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
|
||||||
|
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||||
|
},
|
||||||
|
)
|
||||||
config_name: Optional[str] = field(
|
config_name: Optional[str] = field(
|
||||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
)
|
)
|
||||||
@@ -101,6 +108,12 @@ class ModelArguments:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
|
||||||
|
raise ValueError(
|
||||||
|
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataTrainingArguments:
|
class DataTrainingArguments:
|
||||||
@@ -279,6 +292,9 @@ def main():
|
|||||||
else:
|
else:
|
||||||
config = CONFIG_MAPPING[model_args.model_type]()
|
config = CONFIG_MAPPING[model_args.model_type]()
|
||||||
logger.warning("You are instantiating a new config instance from scratch.")
|
logger.warning("You are instantiating a new config instance from scratch.")
|
||||||
|
if model_args.config_overrides is not None:
|
||||||
|
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||||
|
config.update_from_string(model_args.config_overrides)
|
||||||
|
|
||||||
tokenizer_kwargs = {
|
tokenizer_kwargs = {
|
||||||
"cache_dir": model_args.cache_dir,
|
"cache_dir": model_args.cache_dir,
|
||||||
@@ -306,8 +322,9 @@ def main():
|
|||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Training new model from scratch")
|
|
||||||
model = AutoModelForCausalLM.from_config(config)
|
model = AutoModelForCausalLM.from_config(config)
|
||||||
|
n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
|
||||||
|
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
|
||||||
|
|
||||||
model.resize_token_embeddings(len(tokenizer))
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
|||||||
@@ -667,7 +667,45 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
Updates attributes of this class with attributes from ``config_dict``.
|
Updates attributes of this class with attributes from ``config_dict``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that shall be updated for this class.
|
config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
|
||||||
"""
|
"""
|
||||||
for key, value in config_dict.items():
|
for key, value in config_dict.items():
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def update_from_string(self, update_str: str):
|
||||||
|
"""
|
||||||
|
Updates attributes of this class with attributes from ``update_str``.
|
||||||
|
|
||||||
|
The expected format is ints, floats and strings as is, and for booleans use ``true`` or ``false``. For example:
|
||||||
|
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||||
|
|
||||||
|
The keys to change have to already exist in the config object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_str (:obj:`str`): String with attributes that should be updated for this class.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
d = dict(x.split("=") for x in update_str.split(","))
|
||||||
|
for k, v in d.items():
|
||||||
|
if not hasattr(self, k):
|
||||||
|
raise ValueError(f"key {k} isn't in the original config dict")
|
||||||
|
|
||||||
|
old_v = getattr(self, k)
|
||||||
|
if isinstance(old_v, bool):
|
||||||
|
if v.lower() in ["true", "1", "y", "yes"]:
|
||||||
|
v = True
|
||||||
|
elif v.lower() in ["false", "0", "n", "no"]:
|
||||||
|
v = False
|
||||||
|
else:
|
||||||
|
raise ValueError(f"can't derive true or false from {v} (key {k})")
|
||||||
|
elif isinstance(old_v, int):
|
||||||
|
v = int(v)
|
||||||
|
elif isinstance(old_v, float):
|
||||||
|
v = float(v)
|
||||||
|
elif not isinstance(old_v, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(self, k, v)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import unittest
|
|||||||
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import BertConfig
|
from transformers import BertConfig, GPT2Config
|
||||||
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
|
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
|
||||||
|
|
||||||
|
|
||||||
@@ -138,3 +138,21 @@ class ConfigPushToHubTester(unittest.TestCase):
|
|||||||
for k, v in config.__dict__.items():
|
for k, v in config.__dict__.items():
|
||||||
if k != "transformers_version":
|
if k != "transformers_version":
|
||||||
self.assertEqual(v, getattr(new_config, k))
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigTestUtils(unittest.TestCase):
|
||||||
|
def test_config_from_string(self):
|
||||||
|
c = GPT2Config()
|
||||||
|
|
||||||
|
# attempt to modify each of int/float/bool/str config records and verify they were updated
|
||||||
|
n_embd = c.n_embd + 1 # int
|
||||||
|
resid_pdrop = c.resid_pdrop + 1.0 # float
|
||||||
|
scale_attn_weights = not c.scale_attn_weights # bool
|
||||||
|
summary_type = c.summary_type + "foo" # str
|
||||||
|
c.update_from_string(
|
||||||
|
f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
|
||||||
|
)
|
||||||
|
self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
|
||||||
|
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
|
||||||
|
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
|
||||||
|
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
|
||||||
|
|||||||
Reference in New Issue
Block a user