Add parameter --config_overrides for run_mlm_wwm.py (#16961)
* dd parameter --config_overrides for run_mlm_wwm.py * linter
This commit is contained in:
@@ -69,6 +69,13 @@ class ModelArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
||||||
)
|
)
|
||||||
|
config_overrides: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
|
||||||
|
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||||
|
},
|
||||||
|
)
|
||||||
config_name: Optional[str] = field(
|
config_name: Optional[str] = field(
|
||||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
)
|
)
|
||||||
@@ -95,6 +102,12 @@ class ModelArguments:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
|
||||||
|
raise ValueError(
|
||||||
|
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataTrainingArguments:
|
class DataTrainingArguments:
|
||||||
@@ -275,6 +288,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
config = CONFIG_MAPPING[model_args.model_type]()
|
config = CONFIG_MAPPING[model_args.model_type]()
|
||||||
logger.warning("You are instantiating a new config instance from scratch.")
|
logger.warning("You are instantiating a new config instance from scratch.")
|
||||||
|
if model_args.config_overrides is not None:
|
||||||
|
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||||
|
config.update_from_string(model_args.config_overrides)
|
||||||
|
logger.info(f"New config: {config}")
|
||||||
|
|
||||||
tokenizer_kwargs = {
|
tokenizer_kwargs = {
|
||||||
"cache_dir": model_args.cache_dir,
|
"cache_dir": model_args.cache_dir,
|
||||||
|
|||||||
Reference in New Issue
Block a user