Fix more inefficient PT operations (#37060)
* Fix inefficient operations
* Remove cpu() call
* Reorder detach()
* Reorder detach()
* tolist without detach
* item without detach
* Update src/transformers/models/rag/modeling_rag.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update tests/models/encodec/test_modeling_encodec.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Use detach().cpu().numpy
* Revert some numpy operations
* More fixes

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -2378,14 +2378,14 @@ class AttentionMaskTester(unittest.TestCase):
|
||||
num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)
|
||||
|
||||
if 0 not in mask_2d:
|
||||
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
|
||||
assert (mask_4d != 0).sum().item() == num_tokens_masked
|
||||
if 0 in mask_2d:
|
||||
# at least causal mask + maybe more
|
||||
assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
|
||||
assert (mask_4d != 0).sum().item() >= num_tokens_masked
|
||||
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
|
||||
elif not mask_converter.is_causal and context is None:
|
||||
if 0 not in mask_2d:
|
||||
assert (mask_4d != 0).sum().cpu().item() == 0
|
||||
assert (mask_4d != 0).sum().item() == 0
|
||||
if 0 in mask_2d:
|
||||
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
|
||||
elif mask_converter.is_causal and context is not None:
|
||||
@@ -2394,10 +2394,10 @@ class AttentionMaskTester(unittest.TestCase):
|
||||
num_tokens_masked = bsz * num_tokens_masked
|
||||
|
||||
if 0 not in mask_2d:
|
||||
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
|
||||
assert (mask_4d != 0).sum().item() == num_tokens_masked
|
||||
if 0 in mask_2d:
|
||||
# at least causal mask + maybe more
|
||||
assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
|
||||
assert (mask_4d != 0).sum().item() >= num_tokens_masked
|
||||
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
|
||||
|
||||
def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
|
||||
@@ -2415,15 +2415,15 @@ class AttentionMaskTester(unittest.TestCase):
|
||||
# k * (k-1) / 2 tokens are masked in triangular masks
|
||||
num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)
|
||||
|
||||
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
|
||||
assert (mask_4d != 0).sum().item() == num_tokens_masked
|
||||
elif not mask_converter.is_causal and context is None:
|
||||
assert (mask_4d != 0).sum().cpu().item() == 0
|
||||
assert (mask_4d != 0).sum().item() == 0
|
||||
elif mask_converter.is_causal and context is not None:
|
||||
# k * (k-1) / 2 tokens are masked in triangular masks (plus any context-window masking)
|
||||
num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
|
||||
num_tokens_masked = bsz * num_tokens_masked
|
||||
|
||||
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
|
||||
assert (mask_4d != 0).sum().item() == num_tokens_masked
|
||||
|
||||
def compute_num_context_mask(self, kv_len, context, q_len):
|
||||
# This function computes the # of attention tokens that are added for
|
||||
|
||||
Reference in New Issue
Block a user