[Pixtral] Improve docs, rename model (#33491)
* Improve docs, rename model * Fix style * Update repo id
This commit is contained in:
@@ -21,8 +21,8 @@ import requests
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
PixtralModel,
|
||||
PixtralVisionConfig,
|
||||
PixtralVisionModel,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
@@ -46,7 +46,7 @@ if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class PixtralModelTester:
|
||||
class PixtralVisionModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
@@ -107,7 +107,7 @@ class PixtralModelTester:
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = PixtralModel(config=config)
|
||||
model = PixtralVisionModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
@@ -120,7 +120,7 @@ class PixtralModelTester:
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_model_with_projection(self, config, pixel_values):
|
||||
model = PixtralModel(config=config)
|
||||
model = PixtralVisionModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
@@ -140,17 +140,17 @@ class PixtralModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `PixtralModel`.
|
||||
Model tester for `PixtralVisionModel`.
|
||||
"""
|
||||
|
||||
all_model_classes = (PixtralModel,) if is_torch_available() else ()
|
||||
all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = PixtralModelTester(self)
|
||||
self.model_tester = PixtralVisionModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)
|
||||
|
||||
@unittest.skip("model does not support input embeds")
|
||||
@@ -261,7 +261,7 @@ class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
class PixtralModelIntegrationTest(unittest.TestCase):
|
||||
class PixtralVisionModelIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")
|
||||
|
||||
@@ -273,7 +273,7 @@ class PixtralModelIntegrationTest(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
|
||||
model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
|
||||
|
||||
prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
|
||||
image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
|
||||
|
||||
Reference in New Issue
Block a user