[RAG] Add Ray implementation for distributed retrieval (#9197)
* wip * wip * wip * wip * wip * wip * wip * wip * uncomment * uncomment * wip * updates * add docstring * updates * fix arg * fixes * add unit tests * update readme * update readme * update finetune script * update test * add test * add ray to test dependencies * separate ray and ray tune * formatting * shutdown ray at end of test * fix tests * formatting * formatting * even more formatting * address comments * formatting * add files * Update examples/research_projects/rag/test_distributed_retriever.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * address comments * addressing comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -219,6 +219,7 @@ from .integrations import ( # isort:skip
|
||||
is_comet_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_ray_tune_available,
|
||||
is_tensorboard_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
|
||||
@@ -63,8 +63,16 @@ try:
|
||||
import ray # noqa: F401
|
||||
|
||||
_has_ray = True
|
||||
try:
|
||||
# Ray Tune has additional dependencies.
|
||||
from ray import tune # noqa: F401
|
||||
|
||||
_has_ray_tune = True
|
||||
except (ImportError):
|
||||
_has_ray_tune = False
|
||||
except (ImportError):
|
||||
_has_ray = False
|
||||
_has_ray_tune = False
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter # noqa: F401
|
||||
@@ -127,6 +135,10 @@ def is_ray_available():
|
||||
return _has_ray
|
||||
|
||||
|
||||
def is_ray_tune_available():
|
||||
return _has_ray_tune
|
||||
|
||||
|
||||
def is_azureml_available():
|
||||
return _has_azureml
|
||||
|
||||
@@ -143,7 +155,7 @@ def hp_params(trial):
|
||||
if is_optuna_available():
|
||||
if isinstance(trial, optuna.Trial):
|
||||
return trial.params
|
||||
if is_ray_available():
|
||||
if is_ray_tune_available():
|
||||
if isinstance(trial, dict):
|
||||
return trial
|
||||
|
||||
@@ -153,7 +165,7 @@ def hp_params(trial):
|
||||
def default_hp_search_backend():
|
||||
if is_optuna_available():
|
||||
return "optuna"
|
||||
elif is_ray_available():
|
||||
elif is_ray_tune_available():
|
||||
return "ray"
|
||||
|
||||
|
||||
|
||||
@@ -370,9 +370,8 @@ class RagRetriever:
|
||||
|
||||
"""
|
||||
|
||||
_init_retrieval = True
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
|
||||
self._init_retrieval = init_retrieval
|
||||
requires_datasets(self)
|
||||
requires_faiss(self)
|
||||
super().__init__()
|
||||
|
||||
@@ -37,7 +37,7 @@ from .integrations import ( # isort: split
|
||||
is_fairscale_available,
|
||||
is_mlflow_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_ray_tune_available,
|
||||
is_tensorboard_available,
|
||||
is_wandb_available,
|
||||
run_hp_search_optuna,
|
||||
@@ -145,7 +145,7 @@ if is_mlflow_available():
|
||||
if is_optuna_available():
|
||||
import optuna
|
||||
|
||||
if is_ray_available():
|
||||
if is_ray_tune_available():
|
||||
from ray import tune
|
||||
|
||||
if is_azureml_available():
|
||||
@@ -1062,7 +1062,7 @@ class Trainer:
|
||||
backend = HPSearchBackend(backend)
|
||||
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
|
||||
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
|
||||
if backend == HPSearchBackend.RAY and not is_ray_available():
|
||||
if backend == HPSearchBackend.RAY and not is_ray_tune_available():
|
||||
raise RuntimeError(
|
||||
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
||||
)
|
||||
|
||||
@@ -132,9 +132,9 @@ def default_hp_space_optuna(trial) -> Dict[str, float]:
|
||||
|
||||
|
||||
def default_hp_space_ray(trial) -> Dict[str, float]:
|
||||
from .integrations import is_ray_available
|
||||
from .integrations import is_ray_tune_available
|
||||
|
||||
assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`"
|
||||
assert is_ray_tune_available(), "This function needs ray installed: `pip " "install ray[tune]`"
|
||||
from ray import tune
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user