[test] add test for --config_overrides (#14466)
* add test for --config_overrides * remove unneeded parts of the test
This commit is contained in:
@@ -324,6 +324,7 @@ def main():
|
|||||||
if model_args.config_overrides is not None:
|
if model_args.config_overrides is not None:
|
||||||
logger.info(f"Overriding config: {model_args.config_overrides}")
|
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||||
config.update_from_string(model_args.config_overrides)
|
config.update_from_string(model_args.config_overrides)
|
||||||
|
logger.info(f"New config: {config}")
|
||||||
|
|
||||||
tokenizer_kwargs = {
|
tokenizer_kwargs = {
|
||||||
"cache_dir": model_args.cache_dir,
|
"cache_dir": model_args.cache_dir,
|
||||||
|
|||||||
@@ -326,6 +326,7 @@ def main():
|
|||||||
if model_args.config_overrides is not None:
|
if model_args.config_overrides is not None:
|
||||||
logger.info(f"Overriding config: {model_args.config_overrides}")
|
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||||
config.update_from_string(model_args.config_overrides)
|
config.update_from_string(model_args.config_overrides)
|
||||||
|
logger.info(f"New config: {config}")
|
||||||
|
|
||||||
tokenizer_kwargs = {
|
tokenizer_kwargs = {
|
||||||
"cache_dir": model_args.cache_dir,
|
"cache_dir": model_args.cache_dir,
|
||||||
|
|||||||
@@ -318,6 +318,7 @@ def main():
|
|||||||
if model_args.config_overrides is not None:
|
if model_args.config_overrides is not None:
|
||||||
logger.info(f"Overriding config: {model_args.config_overrides}")
|
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||||
config.update_from_string(model_args.config_overrides)
|
config.update_from_string(model_args.config_overrides)
|
||||||
|
logger.info(f"New config: {config}")
|
||||||
|
|
||||||
tokenizer_kwargs = {
|
tokenizer_kwargs = {
|
||||||
"cache_dir": model_args.cache_dir,
|
"cache_dir": model_args.cache_dir,
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import torch
|
|||||||
|
|
||||||
from transformers import Wav2Vec2ForPreTraining
|
from transformers import Wav2Vec2ForPreTraining
|
||||||
from transformers.file_utils import is_apex_available
|
from transformers.file_utils import is_apex_available
|
||||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
|
from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
SRC_DIRS = [
|
SRC_DIRS = [
|
||||||
@@ -157,6 +157,31 @@ class ExamplesTests(TestCasePlus):
|
|||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
self.assertLess(result["perplexity"], 100)
|
self.assertLess(result["perplexity"], 100)
|
||||||
|
|
||||||
|
def test_run_clm_config_overrides(self):
|
||||||
|
# test that config_overrides works, despite the misleading dumps of default un-updated
|
||||||
|
# config via tokenizer
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_clm.py
|
||||||
|
--model_type gpt2
|
||||||
|
--tokenizer_name gpt2
|
||||||
|
--train_file ./tests/fixtures/sample_text.txt
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--config_overrides n_embd=10,n_head=2
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
if torch_device != "cuda":
|
||||||
|
testargs.append("--no_cuda")
|
||||||
|
|
||||||
|
logger = run_clm.logger
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
run_clm.main()
|
||||||
|
|
||||||
|
self.assertIn('"n_embd": 10', cl.out)
|
||||||
|
self.assertIn('"n_head": 2', cl.out)
|
||||||
|
|
||||||
def test_run_mlm(self):
|
def test_run_mlm(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|||||||
Reference in New Issue
Block a user