New run_clm script (#8105)
* New run_clm script * Formatting * More comments * Remove unused imports * Apply suggestions from code review Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> * Address review comments * Change link to the hub Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
@@ -34,6 +34,7 @@ sys.path.extend(SRC_DIRS)
|
||||
|
||||
|
||||
if SRC_DIRS is not None:
|
||||
import run_clm
|
||||
import run_generation
|
||||
import run_glue
|
||||
import run_language_modeling
|
||||
@@ -128,6 +129,38 @@ class ExamplesTests(TestCasePlus):
|
||||
# self.assertGreaterEqual(v, 0.75, f"({k})")
|
||||
#
|
||||
|
||||
def test_run_clm(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_clm.py
|
||||
--model_name_or_path distilgpt2
|
||||
--train_file ./tests/fixtures/sample_text.txt
|
||||
--validation_file ./tests/fixtures/sample_text.txt
|
||||
--do_train
|
||||
--do_eval
|
||||
--block_size 128
|
||||
--per_device_train_batch_size 5
|
||||
--per_device_eval_batch_size 5
|
||||
--num_train_epochs 2
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--prediction_loss_only
|
||||
""".split()
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
|
||||
return
|
||||
|
||||
if torch_device != "cuda":
|
||||
testargs.append("--no_cuda")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_clm.main()
|
||||
self.assertLess(result["perplexity"], 100)
|
||||
|
||||
def test_run_language_modeling(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
Reference in New Issue
Block a user