Ray Tune Integration Updates (#12134)
* fix * fixes * add back to scheduled tests * formatting * Update integrations.py
This commit is contained in:
@@ -1307,7 +1307,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
|
||||
self.n_epochs = args.num_train_epochs
|
||||
self.batch_size = args.train_batch_size
|
||||
|
||||
def test_hyperparameter_search(self):
|
||||
def ray_hyperparameter_search(self):
|
||||
class MyTrialShortNamer(TrialShortNamer):
|
||||
DEFAULTS = {"a": 0, "b": 0}
|
||||
|
||||
@@ -1320,7 +1320,13 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
|
||||
}
|
||||
|
||||
def model_init(config):
|
||||
model_config = RegressionModelConfig(a=config["a"], b=config["b"], double_output=False)
|
||||
if config is None:
|
||||
a = 0
|
||||
b = 0
|
||||
else:
|
||||
a = config["a"]
|
||||
b = config["b"]
|
||||
model_config = RegressionModelConfig(a=a, b=b, double_output=False)
|
||||
|
||||
return RegressionPreTrainedModel(model_config)
|
||||
|
||||
@@ -1343,3 +1349,14 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
|
||||
trainer.hyperparameter_search(
|
||||
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="ray", n_trials=4
|
||||
)
|
||||
|
||||
def test_hyperparameter_search(self):
|
||||
self.ray_hyperparameter_search()
|
||||
|
||||
def test_hyperparameter_search_ray_client(self):
|
||||
import ray
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
with ray_start_client_server():
|
||||
assert ray.util.client.ray.is_connected()
|
||||
self.ray_hyperparameter_search()
|
||||
|
||||
Reference in New Issue
Block a user