From a5ca56ff158075351149220319c14dde555a86f5 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 12 Aug 2022 16:15:09 +0200 Subject: [PATCH] Supporting seq2seq models for `bitsandbytes` integration (#18579) * Supporting seq2seq models for `bitsandbytes` integration - `bitsandbytes` integration supports now seq2seq models - check if a model has tied weights as an additional check * small modification - tie the weights before looking at tied weights! --- src/transformers/utils/bitsandbytes.py | 14 +++++++++++++- tests/mixed_int8/test_mixed_int8.py | 22 ++++++++++++++++++++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index ee4e52d421..eca605b2ed 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -1,3 +1,5 @@ +from copy import deepcopy + from transformers.utils import is_accelerate_available, is_bitsandbytes_available @@ -9,6 +11,7 @@ if is_bitsandbytes_available(): if is_accelerate_available(): from accelerate import init_empty_weights + from accelerate.utils import find_tied_parameters def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): @@ -132,8 +135,17 @@ def get_key_to_not_convert(model): model (`torch.nn.Module`): Input model """ + # Create a copy of the model and tie the weights, then + # check if it contains tied weights + tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + tied_model.tie_weights() + has_tied_params = len(find_tied_parameters(tied_model)) > 0 + + # Check if it is a base model + is_base_model = not hasattr(model, model.base_model_prefix) + # Ignore this for base models (BertModel, GPT2Model, etc.) - if not hasattr(model, model.base_model_prefix): + if (not has_tied_params) and is_base_model: return "" # otherwise they have an attached head diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 0cd7ca1641..2911d67748 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -15,7 +15,14 @@ import gc import unittest -from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoTokenizer, + pipeline, +) from transformers.testing_utils import ( is_torch_available, require_accelerate, @@ -106,12 +113,21 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test): super().setUp() # model_name self.model_name = "bigscience/bloom-560m" - # Models and tokenizer + self.seq_to_seq_name = "t5-small" + + # Different types of model + self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + # Sequence classification model self.sequence_model = AutoModelForSequenceClassification.from_pretrained( self.model_name, load_in_8bit=True, device_map="auto" ) + # CausalLM model self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + # Seq2seq model + self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained( + self.seq_to_seq_name, load_in_8bit=True, device_map="auto" + ) def tearDown(self): r""" @@ -121,6 +137,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test): del self.base_model del self.sequence_model del self.model_8bit + del self.seq_to_seq_model gc.collect() torch.cuda.empty_cache() @@ -138,6 +155,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test): # Other heads should be nn.Parameter self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter) self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter) + self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter) class MixedInt8TestPipeline(BaseMixedInt8Test):