fix failing test_sdpa_can_dispatch_on_flash (#39259)
* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -993,16 +993,23 @@ def check_model_inputs(func):
|
|||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(self, *args, **kwargs):
|
def wrapper(self, *args, **kwargs):
|
||||||
use_cache = kwargs.get("use_cache", getattr(self.config, "use_cache", False))
|
use_cache = kwargs.get("use_cache", None)
|
||||||
return_dict = kwargs.pop("return_dict", getattr(self.config, "return_dict", True))
|
if use_cache is None:
|
||||||
all_args = kwargs.copy()
|
use_cache = getattr(self.config, "use_cache", False)
|
||||||
|
|
||||||
|
return_dict = kwargs.pop("return_dict", None)
|
||||||
|
if return_dict is None:
|
||||||
|
return_dict = getattr(self.config, "return_dict", True)
|
||||||
|
|
||||||
if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
|
if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||||
)
|
)
|
||||||
kwargs["use_cache"] = False
|
use_cache = False
|
||||||
|
|
||||||
|
kwargs["use_cache"] = use_cache
|
||||||
|
|
||||||
|
all_args = kwargs.copy()
|
||||||
if "kwargs" in all_args:
|
if "kwargs" in all_args:
|
||||||
for k, v in all_args["kwargs"].items():
|
for k, v in all_args["kwargs"].items():
|
||||||
all_args[k] = v
|
all_args[k] = v
|
||||||
|
|||||||
@@ -311,7 +311,7 @@ class T5GemmaModelTester:
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
labels=lm_labels,
|
labels=lm_labels,
|
||||||
)
|
)
|
||||||
self.parent.assertEqual(len(outputs), 4)
|
self.parent.assertEqual(len(outputs), 5)
|
||||||
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
self.parent.assertEqual(outputs["loss"].size(), ())
|
self.parent.assertEqual(outputs["loss"].size(), ())
|
||||||
|
|
||||||
@@ -1067,7 +1067,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
|
|
||||||
for i in range(num_decoder_layers):
|
for i in range(num_decoder_layers):
|
||||||
if is_legacy_cache:
|
if is_legacy_cache:
|
||||||
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
|
self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
|
||||||
|
|
||||||
# Self attention
|
# Self attention
|
||||||
self_attention_layer_key_cache = (
|
self_attention_layer_key_cache = (
|
||||||
@@ -1687,7 +1687,7 @@ class TestAsymmetricT5Gemma(unittest.TestCase):
|
|||||||
labels=lm_labels,
|
labels=lm_labels,
|
||||||
)
|
)
|
||||||
# outputs = model(*inputs)
|
# outputs = model(*inputs)
|
||||||
assert len(outputs) == 4
|
assert len(outputs) == 5
|
||||||
assert outputs["logits"].size() == (tester.batch_size, tester.seq_length, tester.vocab_size)
|
assert outputs["logits"].size() == (tester.batch_size, tester.seq_length, tester.vocab_size)
|
||||||
assert outputs["loss"].size() == ()
|
assert outputs["loss"].size() == ()
|
||||||
return model.model
|
return model.model
|
||||||
|
|||||||
Reference in New Issue
Block a user