Supporting seq2seq models for bitsandbytes integration (#18579)
* Supporting seq2seq models for `bitsandbytes` integration - `bitsandbytes` integration supports now seq2seq models - check if a model has tied weights as an additional check * small modification - tie the weights before looking at tied weights!
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
from transformers.utils import is_accelerate_available, is_bitsandbytes_available
|
from transformers.utils import is_accelerate_available, is_bitsandbytes_available
|
||||||
|
|
||||||
|
|
||||||
@@ -9,6 +11,7 @@ if is_bitsandbytes_available():
|
|||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils import find_tied_parameters
|
||||||
|
|
||||||
|
|
||||||
def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
|
def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
|
||||||
@@ -132,8 +135,17 @@ def get_key_to_not_convert(model):
|
|||||||
model (`torch.nn.Module`):
|
model (`torch.nn.Module`):
|
||||||
Input model
|
Input model
|
||||||
"""
|
"""
|
||||||
|
# Create a copy of the model and tie the weights, then
|
||||||
|
# check if it contains tied weights
|
||||||
|
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
|
||||||
|
tied_model.tie_weights()
|
||||||
|
has_tied_params = len(find_tied_parameters(tied_model)) > 0
|
||||||
|
|
||||||
|
# Check if it is a base model
|
||||||
|
is_base_model = not hasattr(model, model.base_model_prefix)
|
||||||
|
|
||||||
# Ignore this for base models (BertModel, GPT2Model, etc.)
|
# Ignore this for base models (BertModel, GPT2Model, etc.)
|
||||||
if not hasattr(model, model.base_model_prefix):
|
if (not has_tied_params) and is_base_model:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# otherwise they have an attached head
|
# otherwise they have an attached head
|
||||||
|
|||||||
@@ -15,7 +15,14 @@
|
|||||||
import gc
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
from transformers import (
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
|
AutoModelForSequenceClassification,
|
||||||
|
AutoTokenizer,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
@@ -106,12 +113,21 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
|||||||
super().setUp()
|
super().setUp()
|
||||||
# model_name
|
# model_name
|
||||||
self.model_name = "bigscience/bloom-560m"
|
self.model_name = "bigscience/bloom-560m"
|
||||||
# Models and tokenizer
|
self.seq_to_seq_name = "t5-small"
|
||||||
|
|
||||||
|
# Different types of model
|
||||||
|
|
||||||
self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
||||||
|
# Sequence classification model
|
||||||
self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
|
self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
self.model_name, load_in_8bit=True, device_map="auto"
|
self.model_name, load_in_8bit=True, device_map="auto"
|
||||||
)
|
)
|
||||||
|
# CausalLM model
|
||||||
self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
||||||
|
# Seq2seq model
|
||||||
|
self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||||
|
self.seq_to_seq_name, load_in_8bit=True, device_map="auto"
|
||||||
|
)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
r"""
|
r"""
|
||||||
@@ -121,6 +137,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
|||||||
del self.base_model
|
del self.base_model
|
||||||
del self.sequence_model
|
del self.sequence_model
|
||||||
del self.model_8bit
|
del self.model_8bit
|
||||||
|
del self.seq_to_seq_model
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -138,6 +155,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
|||||||
# Other heads should be nn.Parameter
|
# Other heads should be nn.Parameter
|
||||||
self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter)
|
self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter)
|
||||||
self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
|
self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
|
||||||
|
self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)
|
||||||
|
|
||||||
|
|
||||||
class MixedInt8TestPipeline(BaseMixedInt8Test):
|
class MixedInt8TestPipeline(BaseMixedInt8Test):
|
||||||
|
|||||||
Reference in New Issue
Block a user