Chameleon: minor fixes after shipping (#32037)

* fix merging

* make chameleon conditional
This commit is contained in:
Raushan Turganbay
2024-07-18 16:54:07 +05:00
committed by GitHub
parent 765732e92c
commit 673d30b826
7 changed files with 38 additions and 31 deletions

View File

@@ -44,7 +44,7 @@ if is_torch_available():
import torch
from transformers import (
ChameleonForCausalLM,
ChameleonForConditionalGeneration,
ChameleonModel,
ChameleonProcessor,
)
@@ -191,7 +191,7 @@ class ChameleonModelTester:
encoder_hidden_states,
encoder_attention_mask,
):
model = ChameleonForCausalLM(config=config)
model = ChameleonForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
@@ -209,7 +209,7 @@ class ChameleonModelTester:
encoder_attention_mask,
):
config.is_decoder = True
model = ChameleonForCausalLM(config=config)
model = ChameleonForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
@@ -273,12 +273,12 @@ class ChameleonModelTester:
@require_torch
class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (ChameleonModel, ChameleonForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (ChameleonForCausalLM,) if is_torch_available() else ()
all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (ChameleonForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": ChameleonModel,
"text-generation": ChameleonForCausalLM,
"text-generation": ChameleonForConditionalGeneration,
}
if is_torch_available()
else {}
@@ -339,7 +339,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
"""
Overwriting the common test as the test is flaky on tiny models
"""
model = ChameleonForCausalLM.from_pretrained(
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b",
load_in_4bit=True,
device_map={"": 0},
@@ -355,7 +355,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = processor.tokenizer.batch_decode(output_native)
model = ChameleonForCausalLM.from_pretrained(
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b",
load_in_4bit=True,
attn_implementation="flash_attention_2",
@@ -377,7 +377,9 @@ class ChameleonIntegrationTest(unittest.TestCase):
@require_bitsandbytes
@require_read_token
def test_model_7b(self):
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
)
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
image = Image.open(
@@ -397,7 +399,9 @@ class ChameleonIntegrationTest(unittest.TestCase):
@require_bitsandbytes
@require_read_token
def test_model_7b_batched(self):
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
)
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
image = Image.open(
@@ -428,7 +432,9 @@ class ChameleonIntegrationTest(unittest.TestCase):
@require_bitsandbytes
@require_read_token
def test_model_7b_multi_image(self):
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
)
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
image = Image.open(