From 11f65d41581c5d12a96dc7856b7c458ddf47d855 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 22 Nov 2021 08:33:43 -0800 Subject: [PATCH] [test] add test for --config_overrides (#14466) * add test for --config_overrides * remove unneeded parts of the test --- examples/pytorch/language-modeling/run_clm.py | 1 + examples/pytorch/language-modeling/run_mlm.py | 1 + examples/pytorch/language-modeling/run_plm.py | 1 + examples/pytorch/test_examples.py | 27 ++++++++++++++++++- 4 files changed, 29 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index fdbf8e2095..444df1b809 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -324,6 +324,7 @@ def main(): if model_args.config_overrides is not None: logger.info(f"Overriding config: {model_args.config_overrides}") config.update_from_string(model_args.config_overrides) + logger.info(f"New config: {config}") tokenizer_kwargs = { "cache_dir": model_args.cache_dir, diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 5cd3edd3ba..a1b5b7aca3 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -326,6 +326,7 @@ def main(): if model_args.config_overrides is not None: logger.info(f"Overriding config: {model_args.config_overrides}") config.update_from_string(model_args.config_overrides) + logger.info(f"New config: {config}") tokenizer_kwargs = { "cache_dir": model_args.cache_dir, diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py index 4428755236..840bfa9ad6 100755 --- a/examples/pytorch/language-modeling/run_plm.py +++ b/examples/pytorch/language-modeling/run_plm.py @@ -318,6 +318,7 @@ def main(): if model_args.config_overrides is not None: logger.info(f"Overriding config: {model_args.config_overrides}") config.update_from_string(model_args.config_overrides) + logger.info(f"New config: {config}") tokenizer_kwargs = { "cache_dir": model_args.cache_dir, diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index de045630d8..1a1c2ea06a 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -25,7 +25,7 @@ import torch from transformers import Wav2Vec2ForPreTraining from transformers.file_utils import is_apex_available -from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device +from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device SRC_DIRS = [ @@ -157,6 +157,31 @@ class ExamplesTests(TestCasePlus): result = get_results(tmp_dir) self.assertLess(result["perplexity"], 100) + def test_run_clm_config_overrides(self): + # test that config_overrides works, despite the misleading dumps of default un-updated + # config via tokenizer + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_clm.py + --model_type gpt2 + --tokenizer_name gpt2 + --train_file ./tests/fixtures/sample_text.txt + --output_dir {tmp_dir} + --config_overrides n_embd=10,n_head=2 + """.split() + + if torch_device != "cuda": + testargs.append("--no_cuda") + + logger = run_clm.logger + with patch.object(sys, "argv", testargs): + with CaptureLogger(logger) as cl: + run_clm.main() + + self.assertIn('"n_embd": 10', cl.out) + self.assertIn('"n_head": 2', cl.out) + def test_run_mlm(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler)