Fix more inefficient PT operations (#37060)

* Fix inefficient operations

* Remove cpu() call

* Reorder detach()

* Reorder detach()

* tolist without detach

* item without detach

* Update src/transformers/models/rag/modeling_rag.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/encodec/test_modeling_encodec.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Use detach().cpu().numpy

* Revert some numpy operations

* More fixes

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
cyyever
2025-03-31 23:31:24 +08:00
committed by GitHub
parent a1e389e637
commit 786d9c5ed9
54 changed files with 106 additions and 104 deletions

View File

@@ -204,7 +204,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
time_to_first_token = end - start
logger.info(f"completed first compile generation in: {time_to_first_token}s")
cache_position += 1
all_generated_tokens += next_token.clone().detach().cpu().tolist()
all_generated_tokens += next_token.tolist()
cache_position = torch.tensor([seq_length], device=device)
### First compile, decoding
@@ -217,7 +217,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
time_to_second_token = end - start
logger.info(f"completed second compile generation in: {time_to_second_token}s")
cache_position += 1
all_generated_tokens += next_token.clone().detach().cpu().tolist()
all_generated_tokens += next_token.tolist()
### Second compile, decoding
start = perf_counter()
@@ -229,13 +229,13 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
time_to_third_token = end - start
logger.info(f"completed third compile forward in: {time_to_third_token}s")
cache_position += 1
all_generated_tokens += next_token.clone().detach().cpu().tolist()
all_generated_tokens += next_token.tolist()
### Using cuda graphs decoding
start = perf_counter()
for _ in range(1, num_tokens_to_generate):
all_generated_tokens += next_token.clone().detach().cpu().tolist()
all_generated_tokens += next_token.tolist()
next_token = decode_one_token(
model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
)