Fix more inefficient PT operations (#37060)
* Fix inefficient operations
* Remove cpu() call
* Reorder detach()
* Reorder detach()
* tolist without detach
* item without detach
* Update src/transformers/models/rag/modeling_rag.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/encodec/test_modeling_encodec.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Use detach().cpu().numpy
* Revert some numpy operations
* More fixes

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -311,7 +311,7 @@ class RagTestMixin:
|
||||
|
||||
out = retriever(
|
||||
input_ids,
|
||||
- question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
+ question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
|
||||
prefix=config.generator.prefix,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -379,7 +379,7 @@ class RagTestMixin:
|
||||
|
||||
out = retriever(
|
||||
input_ids,
|
||||
- question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
+ question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
|
||||
prefix=config.generator.prefix,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -438,7 +438,7 @@ class RagTestMixin:
|
||||
|
||||
out = retriever(
|
||||
input_ids,
|
||||
- question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
+ question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
|
||||
prefix=config.generator.prefix,
|
||||
return_tensors="pt",
|
||||
n_docs=n_docs,
|
||||
@@ -507,7 +507,7 @@ class RagTestMixin:
|
||||
|
||||
out = retriever(
|
||||
input_ids,
|
||||
- question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
+ question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
|
||||
prefix=config.generator.prefix,
|
||||
return_tensors="pt",
|
||||
n_docs=retriever_n_docs,
|
||||
@@ -964,7 +964,7 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
question_hidden_states = rag_sequence.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||
docs_dict = retriever(
|
||||
- input_ids.cpu().detach().numpy(), question_hidden_states.cpu().detach().numpy(), return_tensors="pt"
|
||||
+ input_ids.detach().cpu().numpy(), question_hidden_states.detach().cpu().numpy(), return_tensors="pt"
|
||||
)
|
||||
doc_scores = torch.bmm(
|
||||
question_hidden_states.unsqueeze(1),
|
||||
|
||||
Reference in New Issue
Block a user