From 9de62cfbceff92f9b0444a8652a88699c5979f23 Mon Sep 17 00:00:00 2001 From: Kumar Abhishek Date: Mon, 14 Jun 2021 05:12:22 -0700 Subject: [PATCH] [lm examples] Replicate --config_overrides addition to other LM examples (#12135) * [lm examples] Replicate --config_overrides addition to other LM examples * Removing no trainer files changes * Update README Co-authored-by: Kumar Abhishek --- examples/pytorch/language-modeling/README.md | 2 +- examples/pytorch/language-modeling/run_mlm.py | 16 ++++++++++++++++ examples/pytorch/language-modeling/run_plm.py | 16 ++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/language-modeling/README.md b/examples/pytorch/language-modeling/README.md index 7340986c0e..23989d7ed1 100644 --- a/examples/pytorch/language-modeling/README.md +++ b/examples/pytorch/language-modeling/README.md @@ -173,7 +173,7 @@ python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides=" [...] ``` -At the moment this is only available in `run_clm.py` but eventually should be copied to all other LM examples. +This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`. This feature can also be used to activate gradient checkpointing by passing: ``` diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 7612e05226..da687aea1f 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -72,6 +72,13 @@ class ModelArguments: default=None, metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + }, + ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -98,6 +105,12 @@ class ModelArguments: }, ) + def __post_init__(self): + if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): + raise ValueError( + "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + ) + @dataclass class DataTrainingArguments: @@ -283,6 +296,9 @@ def main(): else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) tokenizer_kwargs = { "cache_dir": model_args.cache_dir, diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py index aa30de041b..b4cf5f5323 100755 --- a/examples/pytorch/language-modeling/run_plm.py +++ b/examples/pytorch/language-modeling/run_plm.py @@ -65,6 +65,13 @@ class ModelArguments: config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + }, + ) tokenizer_name: Optional[str] = field( default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} ) @@ -88,6 +95,12 @@ class ModelArguments: }, ) + def __post_init__(self): + if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): + raise ValueError( + "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + ) + @dataclass class DataTrainingArguments: @@ -280,6 +293,9 @@ def main(): else: config = XLNetConfig() logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) tokenizer_kwargs = { "cache_dir": model_args.cache_dir,