Chameleon: minor fixes after shipping (#32037)
* fix merging
* make chameleon conditional
This commit is contained in:
committed by
GitHub
parent
765732e92c
commit
673d30b826
@@ -44,7 +44,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
ChameleonForCausalLM,
|
||||
ChameleonForConditionalGeneration,
|
||||
ChameleonModel,
|
||||
ChameleonProcessor,
|
||||
)
|
||||
@@ -191,7 +191,7 @@ class ChameleonModelTester:
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = ChameleonForCausalLM(config=config)
|
||||
model = ChameleonForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
@@ -209,7 +209,7 @@ class ChameleonModelTester:
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.is_decoder = True
|
||||
model = ChameleonForCausalLM(config=config)
|
||||
model = ChameleonForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@@ -273,12 +273,12 @@ class ChameleonModelTester:
|
||||
|
||||
@require_torch
|
||||
class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ChameleonModel, ChameleonForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (ChameleonForCausalLM,) if is_torch_available() else ()
|
||||
all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (ChameleonForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": ChameleonModel,
|
||||
"text-generation": ChameleonForCausalLM,
|
||||
"text-generation": ChameleonForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
@@ -339,7 +339,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
"""
|
||||
Overwriting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
model = ChameleonForCausalLM.from_pretrained(
|
||||
model = ChameleonForConditionalGeneration.from_pretrained(
|
||||
"facebook/chameleon-7b",
|
||||
load_in_4bit=True,
|
||||
device_map={"": 0},
|
||||
@@ -355,7 +355,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_native = processor.tokenizer.batch_decode(output_native)
|
||||
|
||||
model = ChameleonForCausalLM.from_pretrained(
|
||||
model = ChameleonForConditionalGeneration.from_pretrained(
|
||||
"facebook/chameleon-7b",
|
||||
load_in_4bit=True,
|
||||
attn_implementation="flash_attention_2",
|
||||
@@ -377,7 +377,9 @@ class ChameleonIntegrationTest(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_model_7b(self):
|
||||
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
|
||||
model = ChameleonForConditionalGeneration.from_pretrained(
|
||||
"facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
|
||||
)
|
||||
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
|
||||
|
||||
image = Image.open(
|
||||
@@ -397,7 +399,9 @@ class ChameleonIntegrationTest(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_model_7b_batched(self):
|
||||
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
|
||||
model = ChameleonForConditionalGeneration.from_pretrained(
|
||||
"facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
|
||||
)
|
||||
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
|
||||
|
||||
image = Image.open(
|
||||
@@ -428,7 +432,9 @@ class ChameleonIntegrationTest(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_model_7b_multi_image(self):
|
||||
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
|
||||
model = ChameleonForConditionalGeneration.from_pretrained(
|
||||
"facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
|
||||
)
|
||||
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
|
||||
|
||||
image = Image.open(
|
||||
|
||||
Reference in New Issue
Block a user