[VLMs] fix flash-attention tests (#37603)

* fix one test

* fa2 ln test

* remove keys from config recursively

* fix

* fixup
This commit is contained in:
Raushan Turganbay
2025-04-24 11:48:11 +02:00
committed by GitHub
parent 02baa61fab
commit 1cfcbfcab8
17 changed files with 52 additions and 83 deletions

View File

@@ -4444,7 +4444,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
# remain a single source of truth
config._pre_quantization_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
original_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
def _assign_original_dtype(module):
for child in module.children():
if isinstance(child, PreTrainedModel):
child.config._pre_quantization_dtype = original_dtype
_assign_original_dtype(child)
config._pre_quantization_dtype = original_dtype
_assign_original_dtype(model)
# Prepare the full device map
if device_map is not None: