[test] add test for --config_overrides (#14466)
* add test for --config_overrides * remove unneeded parts of the test
This commit is contained in:
@@ -25,7 +25,7 @@ import torch
|
||||
|
||||
from transformers import Wav2Vec2ForPreTraining
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
|
||||
from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device
|
||||
|
||||
|
||||
SRC_DIRS = [
|
||||
@@ -157,6 +157,31 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["perplexity"], 100)
|
||||
|
||||
def test_run_clm_config_overrides(self):
|
||||
# test that config_overrides works, despite the misleading dumps of default un-updated
|
||||
# config via tokenizer
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_clm.py
|
||||
--model_type gpt2
|
||||
--tokenizer_name gpt2
|
||||
--train_file ./tests/fixtures/sample_text.txt
|
||||
--output_dir {tmp_dir}
|
||||
--config_overrides n_embd=10,n_head=2
|
||||
""".split()
|
||||
|
||||
if torch_device != "cuda":
|
||||
testargs.append("--no_cuda")
|
||||
|
||||
logger = run_clm.logger
|
||||
with patch.object(sys, "argv", testargs):
|
||||
with CaptureLogger(logger) as cl:
|
||||
run_clm.main()
|
||||
|
||||
self.assertIn('"n_embd": 10', cl.out)
|
||||
self.assertIn('"n_head": 2', cl.out)
|
||||
|
||||
def test_run_mlm(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
Reference in New Issue
Block a user