[VLMs] fix flash-attention tests (#37603)
* fix one test * fa2 ln test * remove keys from config recursively * fix * fixup
This commit is contained in:
committed by
GitHub
parent
02baa61fab
commit
1cfcbfcab8
@@ -4444,7 +4444,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# once the weights have been quantized
|
||||
# Note that once you have loaded a quantized model, you can't change its dtype so this will
|
||||
# remain a single source of truth
|
||||
config._pre_quantization_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
|
||||
original_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
|
||||
|
||||
def _assign_original_dtype(module):
|
||||
for child in module.children():
|
||||
if isinstance(child, PreTrainedModel):
|
||||
child.config._pre_quantization_dtype = original_dtype
|
||||
_assign_original_dtype(child)
|
||||
|
||||
config._pre_quantization_dtype = original_dtype
|
||||
_assign_original_dtype(model)
|
||||
|
||||
# Prepare the full device map
|
||||
if device_map is not None:
|
||||
|
||||
Reference in New Issue
Block a user