Fix bnb training test failure (#34414)
* Fix bnb training test: compatibility with OPTSdpaAttention
This commit is contained in:
@@ -29,6 +29,7 @@ from transformers import (
|
||||
BitsAndBytesConfig,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.models.opt.modeling_opt import OPTAttention
|
||||
from transformers.testing_utils import (
|
||||
apply_skip_if_not_implemented,
|
||||
is_bitsandbytes_available,
|
||||
@@ -565,7 +566,7 @@ class Bnb4BitTestTraining(Base4bitTest):
|
||||
|
||||
# Step 2: add adapters
|
||||
for _, module in model.named_modules():
|
||||
if "OPTAttention" in repr(type(module)):
|
||||
if isinstance(module, OPTAttention):
|
||||
module.q_proj = LoRALayer(module.q_proj, rank=16)
|
||||
module.k_proj = LoRALayer(module.k_proj, rank=16)
|
||||
module.v_proj = LoRALayer(module.v_proj, rank=16)
|
||||
|
||||
@@ -29,6 +29,7 @@ from transformers import (
|
||||
BitsAndBytesConfig,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.models.opt.modeling_opt import OPTAttention
|
||||
from transformers.testing_utils import (
|
||||
apply_skip_if_not_implemented,
|
||||
is_accelerate_available,
|
||||
@@ -868,7 +869,7 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
||||
|
||||
# Step 2: add adapters
|
||||
for _, module in model.named_modules():
|
||||
if "OPTAttention" in repr(type(module)):
|
||||
if isinstance(module, OPTAttention):
|
||||
module.q_proj = LoRALayer(module.q_proj, rank=16)
|
||||
module.k_proj = LoRALayer(module.k_proj, rank=16)
|
||||
module.v_proj = LoRALayer(module.v_proj, rank=16)
|
||||
|
||||
Reference in New Issue
Block a user