Supporting seq2seq models for bitsandbytes integration (#18579)

* Supporting seq2seq models for `bitsandbytes` integration

- `bitsandbytes` integration supports now seq2seq models
- check if a model has tied weights as an additional check

* small modification

- tie the weights before looking at tied weights!
This commit is contained in:
Younes Belkada
2022-08-12 16:15:09 +02:00
committed by GitHub
parent ed1924e801
commit a5ca56ff15
2 changed files with 33 additions and 3 deletions

View File

@@ -15,7 +15,14 @@
import gc
import unittest
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoTokenizer,
pipeline,
)
from transformers.testing_utils import (
is_torch_available,
require_accelerate,
@@ -106,12 +113,21 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
super().setUp()
# model_name
self.model_name = "bigscience/bloom-560m"
# Models and tokenizer
self.seq_to_seq_name = "t5-small"
# Different types of model
self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
# Sequence classification model
self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
self.model_name, load_in_8bit=True, device_map="auto"
)
# CausalLM model
self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
# Seq2seq model
self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
self.seq_to_seq_name, load_in_8bit=True, device_map="auto"
)
def tearDown(self):
r"""
@@ -121,6 +137,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
del self.base_model
del self.sequence_model
del self.model_8bit
del self.seq_to_seq_model
gc.collect()
torch.cuda.empty_cache()
@@ -138,6 +155,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
# Other heads should be nn.Parameter
self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter)
self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)
class MixedInt8TestPipeline(BaseMixedInt8Test):