[RAG] Add Ray implementation for distributed retrieval (#9197)

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* uncomment

* uncomment

* wip

* updates

* add docstring

* updates

* fix arg

* fixes

* add unit tests

* update readme

* update readme

* update finetune script

* update test

* add test

* add ray to test dependencies

* separate ray and ray tune

* formatting

* shutdown ray at end of test

* fix tests

* formatting

* formatting

* even more formatting

* address comments

* formatting

* add files

* Update examples/research_projects/rag/test_distributed_retriever.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address comments

* addressing comments

Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Amog Kamsetty
2020-12-21 01:39:30 -08:00
committed by GitHub
parent f38c4ad302
commit a4b21cdd20
14 changed files with 561 additions and 56 deletions

View File

@@ -219,6 +219,7 @@ from .integrations import ( # isort:skip
is_comet_available,
is_optuna_available,
is_ray_available,
is_ray_tune_available,
is_tensorboard_available,
is_wandb_available,
)

View File

@@ -63,8 +63,16 @@ try:
import ray # noqa: F401
_has_ray = True
try:
# Ray Tune has additional dependencies.
from ray import tune # noqa: F401
_has_ray_tune = True
except (ImportError):
_has_ray_tune = False
except (ImportError):
_has_ray = False
_has_ray_tune = False
try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401
@@ -127,6 +135,10 @@ def is_ray_available():
return _has_ray
def is_ray_tune_available():
return _has_ray_tune
def is_azureml_available():
return _has_azureml
@@ -143,7 +155,7 @@ def hp_params(trial):
if is_optuna_available():
if isinstance(trial, optuna.Trial):
return trial.params
if is_ray_available():
if is_ray_tune_available():
if isinstance(trial, dict):
return trial
@@ -153,7 +165,7 @@ def hp_params(trial):
def default_hp_search_backend():
if is_optuna_available():
return "optuna"
elif is_ray_available():
elif is_ray_tune_available():
return "ray"

View File

@@ -370,9 +370,8 @@ class RagRetriever:
"""
_init_retrieval = True
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
self._init_retrieval = init_retrieval
requires_datasets(self)
requires_faiss(self)
super().__init__()

View File

@@ -37,7 +37,7 @@ from .integrations import ( # isort: split
is_fairscale_available,
is_mlflow_available,
is_optuna_available,
is_ray_available,
is_ray_tune_available,
is_tensorboard_available,
is_wandb_available,
run_hp_search_optuna,
@@ -145,7 +145,7 @@ if is_mlflow_available():
if is_optuna_available():
import optuna
if is_ray_available():
if is_ray_tune_available():
from ray import tune
if is_azureml_available():
@@ -1062,7 +1062,7 @@ class Trainer:
backend = HPSearchBackend(backend)
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
if backend == HPSearchBackend.RAY and not is_ray_available():
if backend == HPSearchBackend.RAY and not is_ray_tune_available():
raise RuntimeError(
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
)

View File

@@ -132,9 +132,9 @@ def default_hp_space_optuna(trial) -> Dict[str, float]:
def default_hp_space_ray(trial) -> Dict[str, float]:
from .integrations import is_ray_available
from .integrations import is_ray_tune_available
assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`"
assert is_ray_tune_available(), "This function needs ray installed: `pip " "install ray[tune]`"
from ray import tune
return {