[Pixtral] Improve docs, rename model (#33491)

* Improve docs, rename model

* Fix style

* Update repo id
This commit is contained in:
NielsRogge
2024-09-25 13:53:12 +02:00
committed by GitHub
parent c6379858f3
commit 06e27e3dc0
9 changed files with 48 additions and 60 deletions

View File

@@ -21,8 +21,8 @@ import requests
from transformers import (
AutoProcessor,
PixtralModel,
PixtralVisionConfig,
PixtralVisionModel,
is_torch_available,
is_vision_available,
)
@@ -46,7 +46,7 @@ if is_vision_available():
from PIL import Image
class PixtralModelTester:
class PixtralVisionModelTester:
def __init__(
self,
parent,
@@ -107,7 +107,7 @@ class PixtralModelTester:
)
def create_and_check_model(self, config, pixel_values):
model = PixtralModel(config=config)
model = PixtralVisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
@@ -120,7 +120,7 @@ class PixtralModelTester:
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_model_with_projection(self, config, pixel_values):
model = PixtralModel(config=config)
model = PixtralVisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
@@ -140,17 +140,17 @@ class PixtralModelTester:
@require_torch
class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `PixtralModel`.
Model tester for `PixtralVisionModel`.
"""
all_model_classes = (PixtralModel,) if is_torch_available() else ()
all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = PixtralModelTester(self)
self.model_tester = PixtralVisionModelTester(self)
self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)
@unittest.skip("model does not support input embeds")
@@ -261,7 +261,7 @@ class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class PixtralModelIntegrationTest(unittest.TestCase):
class PixtralVisionModelIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")
@@ -273,7 +273,7 @@ class PixtralModelIntegrationTest(unittest.TestCase):
@require_bitsandbytes
def test_small_model_integration_test(self):
# Let' s make sure we test the preprocessing to replace what is used
model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
image_file = "https://pixtral-vl.github.io/static/images/view.jpg"