Fix doc errors and typos across the board (#8139)
* Fix doc errors and typos across the board * Fix a typo * Fix the CI * Fix more typos * Fix CI * More fixes * Fix CI * More fixes * More fixes
This commit is contained in:
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
"""
|
||||
A distributed retriever built on top of the ``torch.distributed`` communication package. During training all workers
|
||||
initalize their own instance of the retriever, however, only the main worker loads the index into memory. The index is stored
|
||||
initialize their own instance of the retriever, however, only the main worker loads the index into memory. The index is stored
|
||||
in cpu memory. The index will also work well in a non-distributed setup.
|
||||
|
||||
Args:
|
||||
@@ -45,7 +45,7 @@ class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
|
||||
def init_retrieval(self, distributed_port: int):
|
||||
"""
|
||||
Retriever initalization function, needs to be called from the training process. The function sets some common parameters
|
||||
Retriever initialization function, needs to be called from the training process. The function sets some common parameters
|
||||
and environment variables. On top of that, (only) the main process in the process group loads the index into memory.
|
||||
|
||||
Args:
|
||||
@@ -56,7 +56,7 @@ class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
|
||||
logger.info("initializing retrieval")
|
||||
|
||||
# initializing a separate process group for retrievel as the default
|
||||
# initializing a separate process group for retrieval as the default
|
||||
# nccl backend doesn't support gather/scatter operations while gloo
|
||||
# is too slow to replace nccl for the core gpu communication
|
||||
if dist.is_initialized():
|
||||
@@ -101,7 +101,7 @@ class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Ouput:
|
||||
Output:
|
||||
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
|
||||
The retrieval embeddings of the retrieved docs per query.
|
||||
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
|
||||
|
||||
@@ -176,7 +176,7 @@ def get_args():
|
||||
choices=["e2e", "retrieval"],
|
||||
default="e2e",
|
||||
type=str,
|
||||
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calulates precision@k.",
|
||||
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
|
||||
)
|
||||
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
|
||||
parser.add_argument(
|
||||
@@ -206,7 +206,7 @@ def get_args():
|
||||
"--predictions_path",
|
||||
type=str,
|
||||
default="predictions.txt",
|
||||
help="Name of the predictions file, to be stored in the checkpoints directry",
|
||||
help="Name of the predictions file, to be stored in the checkpoints directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
|
||||
@@ -26,7 +26,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def split_text(text: str, n=100, character=" ") -> List[str]:
|
||||
"""Split the text every ``n``-th occurence of ``character``"""
|
||||
"""Split the text every ``n``-th occurrence of ``character``"""
|
||||
text = text.split(character)
|
||||
return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user