New run_clm script (#8105)

* New run_clm script

* Formatting

* More comments

* Remove unused imports

* Apply suggestions from code review

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>

* Address review comments

* Change link to the hub

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2020-10-28 10:38:58 -04:00
committed by GitHub
parent 8065fea870
commit 47dfa65b0c
3 changed files with 379 additions and 2 deletions

View File

@@ -34,6 +34,7 @@ sys.path.extend(SRC_DIRS)
if SRC_DIRS is not None:
import run_clm
import run_generation
import run_glue
import run_language_modeling
@@ -128,6 +129,38 @@ class ExamplesTests(TestCasePlus):
# self.assertGreaterEqual(v, 0.75, f"({k})")
#
def test_run_clm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_clm.py
--model_name_or_path distilgpt2
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--do_train
--do_eval
--block_size 128
--per_device_train_batch_size 5
--per_device_eval_batch_size 5
--num_train_epochs 2
--output_dir {tmp_dir}
--overwrite_output_dir
--prediction_loss_only
""".split()
if torch.cuda.device_count() > 1:
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return
if torch_device != "cuda":
testargs.append("--no_cuda")
with patch.object(sys, "argv", testargs):
result = run_clm.main()
self.assertLess(result["perplexity"], 100)
def test_run_language_modeling(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)