Add a template for examples and apply it for mlm and plm examples (#8153)
* Add a template for example scripts and apply it to mlm * Formatting * Fix test * Add plm script * Add a template for example scripts and apply it to mlm * Formatting * Fix test * Add plm script * Add a template for example scripts and apply it to mlm * Formatting * Fix test * Add plm script * Styling
This commit is contained in:
@@ -37,7 +37,7 @@ if SRC_DIRS is not None:
|
||||
import run_clm
|
||||
import run_generation
|
||||
import run_glue
|
||||
import run_language_modeling
|
||||
import run_mlm
|
||||
import run_pl_glue
|
||||
import run_squad
|
||||
|
||||
@@ -160,31 +160,29 @@ class ExamplesTests(TestCasePlus):
|
||||
result = run_clm.main()
|
||||
self.assertLess(result["perplexity"], 100)
|
||||
|
||||
def test_run_language_modeling(self):
|
||||
def test_run_mlm(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_language_modeling.py
|
||||
run_mlm.py
|
||||
--model_name_or_path distilroberta-base
|
||||
--model_type roberta
|
||||
--mlm
|
||||
--line_by_line
|
||||
--train_data_file ./tests/fixtures/sample_text.txt
|
||||
--eval_data_file ./tests/fixtures/sample_text.txt
|
||||
--train_file ./tests/fixtures/sample_text.txt
|
||||
--validation_file ./tests/fixtures/sample_text.txt
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--do_train
|
||||
--do_eval
|
||||
--prediction_loss_only
|
||||
--num_train_epochs=1
|
||||
""".split()
|
||||
""".split()
|
||||
|
||||
if torch_device != "cuda":
|
||||
testargs.append("--no_cuda")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_language_modeling.main()
|
||||
result = run_mlm.main()
|
||||
self.assertLess(result["perplexity"], 42)
|
||||
|
||||
def test_run_squad(self):
|
||||
|
||||
Reference in New Issue
Block a user