Fix rag finetuning + add finetuning test (#8585)

* replace init_ddp_connection for index init

* style

* add finetune test

* add test data

* move generate tensors to device

* add test on EM metric

* style

* allow multi process test

* keep gloo process group for retrieval

* add multi-gpu test

* use custom accelerator

* clean test finetune

* minor

* style

* style

* typo

* use python call instead of imported main fumction

* return_dict fix in modeling_rag

* use float32 in retrieval

* store as float32 as well in the custom knowledge dataset example

* style

* rename to finetune_rag

* style

* update readme

* rename utils and callbacks to utils_rag and callbacks_rag

* fix test

* patrick's comments

* generate dummy data in the finetue test script

* remove dummy data files

* style
This commit is contained in:
Quentin Lhoest
2020-11-20 19:05:03 +01:00
committed by GitHub
parent 63e91f5fde
commit 8062fa63c5
11 changed files with 200 additions and 56 deletions

View File

@@ -556,7 +556,9 @@ class RagModel(RagPreTrainedModel):
if encoder_outputs is None:
if has_to_retrieve:
question_enc_outputs = self.question_encoder(input_ids, attention_mask=attention_mask)
question_enc_outputs = self.question_encoder(
input_ids, attention_mask=attention_mask, return_dict=True
)
question_encoder_last_hidden_state = question_enc_outputs[0] # hidden states of question encoder
retriever_outputs = self.retriever(
@@ -616,6 +618,7 @@ class RagModel(RagPreTrainedModel):
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
return_dict=True,
)
if not has_to_retrieve:

View File

@@ -196,7 +196,7 @@ class HFIndexBase(Index):
self.dataset = dataset
self._index_initialized = index_initialized
self._check_dataset_format(with_index=index_initialized)
dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
def _check_dataset_format(self, with_index: bool):
if not isinstance(self.dataset, Dataset):