From 1be8d56ec6f7113810adc716255d371e78e8a1af Mon Sep 17 00:00:00 2001 From: conan1024hao <50416856+conan1024hao@users.noreply.github.com> Date: Thu, 28 Apr 2022 23:44:55 +0900 Subject: [PATCH] Add parameter --config_overrides for run_mlm_wwm.py (#16961) * dd parameter --config_overrides for run_mlm_wwm.py * linter --- .../research_projects/mlm_wwm/run_mlm_wwm.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/research_projects/mlm_wwm/run_mlm_wwm.py b/examples/research_projects/mlm_wwm/run_mlm_wwm.py index f528dbd46c..51c05ab0b3 100644 --- a/examples/research_projects/mlm_wwm/run_mlm_wwm.py +++ b/examples/research_projects/mlm_wwm/run_mlm_wwm.py @@ -69,6 +69,13 @@ class ModelArguments: default=None, metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + }, + ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -95,6 +102,12 @@ class ModelArguments: }, ) + def __post_init__(self): + if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): + raise ValueError( + "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + ) + @dataclass class DataTrainingArguments: @@ -275,6 +288,10 @@ def main(): else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) + logger.info(f"New config: {config}") tokenizer_kwargs = { "cache_dir": model_args.cache_dir,