Fix flakey test with seed (#20318)
This commit is contained in:
@@ -26,7 +26,7 @@ from unittest import mock
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from accelerate.utils import write_basic_config
|
from accelerate.utils import write_basic_config
|
||||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, is_flaky, run_command, slow, torch_device
|
from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device
|
||||||
from transformers.utils import is_apex_available
|
from transformers.utils import is_apex_available
|
||||||
|
|
||||||
|
|
||||||
@@ -176,7 +176,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))
|
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))
|
||||||
|
|
||||||
@is_flaky()
|
|
||||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
|
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
|
||||||
def test_run_squad_no_trainer(self):
|
def test_run_squad_no_trainer(self):
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
@@ -187,6 +186,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
|
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
|
||||||
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
|
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
|
||||||
--output_dir {tmp_dir}
|
--output_dir {tmp_dir}
|
||||||
|
--seed=42
|
||||||
--max_train_steps=10
|
--max_train_steps=10
|
||||||
--num_warmup_steps=2
|
--num_warmup_steps=2
|
||||||
--learning_rate=2e-4
|
--learning_rate=2e-4
|
||||||
|
|||||||
Reference in New Issue
Block a user