Fix bnb training test failure (#34414)
* Fix bnb training test: compatibility with OPTSdpaAttention
This commit is contained in:
@@ -29,6 +29,7 @@ from transformers import (
|
|||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
|
from transformers.models.opt.modeling_opt import OPTAttention
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
apply_skip_if_not_implemented,
|
apply_skip_if_not_implemented,
|
||||||
is_bitsandbytes_available,
|
is_bitsandbytes_available,
|
||||||
@@ -565,7 +566,7 @@ class Bnb4BitTestTraining(Base4bitTest):
|
|||||||
|
|
||||||
# Step 2: add adapters
|
# Step 2: add adapters
|
||||||
for _, module in model.named_modules():
|
for _, module in model.named_modules():
|
||||||
if "OPTAttention" in repr(type(module)):
|
if isinstance(module, OPTAttention):
|
||||||
module.q_proj = LoRALayer(module.q_proj, rank=16)
|
module.q_proj = LoRALayer(module.q_proj, rank=16)
|
||||||
module.k_proj = LoRALayer(module.k_proj, rank=16)
|
module.k_proj = LoRALayer(module.k_proj, rank=16)
|
||||||
module.v_proj = LoRALayer(module.v_proj, rank=16)
|
module.v_proj = LoRALayer(module.v_proj, rank=16)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from transformers import (
|
|||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
|
from transformers.models.opt.modeling_opt import OPTAttention
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
apply_skip_if_not_implemented,
|
apply_skip_if_not_implemented,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
@@ -868,7 +869,7 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
|||||||
|
|
||||||
# Step 2: add adapters
|
# Step 2: add adapters
|
||||||
for _, module in model.named_modules():
|
for _, module in model.named_modules():
|
||||||
if "OPTAttention" in repr(type(module)):
|
if isinstance(module, OPTAttention):
|
||||||
module.q_proj = LoRALayer(module.q_proj, rank=16)
|
module.q_proj = LoRALayer(module.q_proj, rank=16)
|
||||||
module.k_proj = LoRALayer(module.k_proj, rank=16)
|
module.k_proj = LoRALayer(module.k_proj, rank=16)
|
||||||
module.v_proj = LoRALayer(module.v_proj, rank=16)
|
module.v_proj = LoRALayer(module.v_proj, rank=16)
|
||||||
|
|||||||
Reference in New Issue
Block a user