Ray Tune Integration Updates (#12134)

* fix

* fixes

* add back to scheduled tests

* formatting

* Update integrations.py
This commit is contained in:
Amog Kamsetty
2021-06-15 11:11:29 -07:00
committed by GitHub
parent a79585bbf9
commit b9d66f4c4b
3 changed files with 33 additions and 6 deletions

View File

@@ -1307,7 +1307,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
self.n_epochs = args.num_train_epochs
self.batch_size = args.train_batch_size
def test_hyperparameter_search(self):
def ray_hyperparameter_search(self):
class MyTrialShortNamer(TrialShortNamer):
DEFAULTS = {"a": 0, "b": 0}
@@ -1320,7 +1320,13 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
}
def model_init(config):
model_config = RegressionModelConfig(a=config["a"], b=config["b"], double_output=False)
if config is None:
a = 0
b = 0
else:
a = config["a"]
b = config["b"]
model_config = RegressionModelConfig(a=a, b=b, double_output=False)
return RegressionPreTrainedModel(model_config)
@@ -1343,3 +1349,14 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
trainer.hyperparameter_search(
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="ray", n_trials=4
)
def test_hyperparameter_search(self):
self.ray_hyperparameter_search()
def test_hyperparameter_search_ray_client(self):
import ray
from ray.util.client.ray_client_helpers import ray_start_client_server
with ray_start_client_server():
assert ray.util.client.ray.is_connected()
self.ray_hyperparameter_search()