Fix rag finetuning + add finetuning test (#8585)
* replace init_ddp_connection for index init * style * add finetune test * add test data * move generate tensors to device * add test on EM metric * style * allow multi process test * keep gloo process group for retrieval * add multi-gpu test * use custom accelerator * clean test finetune * minor * style * style * typo * use python call instead of imported main fumction * return_dict fix in modeling_rag * use float32 in retrieval * store as float32 as well in the custom knowledge dataset example * style * rename to finetune_rag * style * update readme * rename utils and callbacks to utils_rag and callbacks_rag * fix test * patrick's comments * generate dummy data in the finetue test script * remove dummy data files * style
This commit is contained in:
@@ -556,7 +556,9 @@ class RagModel(RagPreTrainedModel):
|
||||
if encoder_outputs is None:
|
||||
|
||||
if has_to_retrieve:
|
||||
question_enc_outputs = self.question_encoder(input_ids, attention_mask=attention_mask)
|
||||
question_enc_outputs = self.question_encoder(
|
||||
input_ids, attention_mask=attention_mask, return_dict=True
|
||||
)
|
||||
question_encoder_last_hidden_state = question_enc_outputs[0] # hidden states of question encoder
|
||||
|
||||
retriever_outputs = self.retriever(
|
||||
@@ -616,6 +618,7 @@ class RagModel(RagPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
if not has_to_retrieve:
|
||||
|
||||
@@ -196,7 +196,7 @@ class HFIndexBase(Index):
|
||||
self.dataset = dataset
|
||||
self._index_initialized = index_initialized
|
||||
self._check_dataset_format(with_index=index_initialized)
|
||||
dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
|
||||
dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
|
||||
|
||||
def _check_dataset_format(self, with_index: bool):
|
||||
if not isinstance(self.dataset, Dataset):
|
||||
|
||||
Reference in New Issue
Block a user