Fix more inefficient PT operations (#37060)

* Fix inefficient operations

* Remove cpu() call

* Reorder detach()

* Reorder detach()

* tolist without detach

* item without detach

* Update src/transformers/models/rag/modeling_rag.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/encodec/test_modeling_encodec.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Use detach().cpu().numpy

* Revert some numpy operations

* More fixes

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
cyyever
2025-03-31 23:31:24 +08:00
committed by GitHub
parent a1e389e637
commit 786d9c5ed9
54 changed files with 106 additions and 104 deletions

View File

@@ -311,7 +311,7 @@ class RagTestMixin:
out = retriever(
input_ids,
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
prefix=config.generator.prefix,
return_tensors="pt",
)
@@ -379,7 +379,7 @@ class RagTestMixin:
out = retriever(
input_ids,
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
prefix=config.generator.prefix,
return_tensors="pt",
)
@@ -438,7 +438,7 @@ class RagTestMixin:
out = retriever(
input_ids,
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
prefix=config.generator.prefix,
return_tensors="pt",
n_docs=n_docs,
@@ -507,7 +507,7 @@ class RagTestMixin:
out = retriever(
input_ids,
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
prefix=config.generator.prefix,
return_tensors="pt",
n_docs=retriever_n_docs,
@@ -964,7 +964,7 @@ class RagModelIntegrationTests(unittest.TestCase):
question_hidden_states = rag_sequence.question_encoder(input_ids, attention_mask=attention_mask)[0]
docs_dict = retriever(
input_ids.cpu().detach().numpy(), question_hidden_states.cpu().detach().numpy(), return_tensors="pt"
input_ids.detach().cpu().numpy(), question_hidden_states.detach().cpu().numpy(), return_tensors="pt"
)
doc_scores = torch.bmm(
question_hidden_states.unsqueeze(1),