@@ -156,7 +156,7 @@ class DonutSwinImageClassifierOutput(ModelOutput):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
logits: torch.FloatTensor = None
|
logits: Optional[torch.FloatTensor] = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ class QuarkTest(unittest.TestCase):
|
|||||||
EXPECTED_RELATIVE_DIFFERENCE = 1.66
|
EXPECTED_RELATIVE_DIFFERENCE = 1.66
|
||||||
device_map = None
|
device_map = None
|
||||||
|
|
||||||
@require_read_token
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
"""
|
"""
|
||||||
@@ -76,15 +75,17 @@ class QuarkTest(unittest.TestCase):
|
|||||||
device_map=cls.device_map,
|
device_map=cls.device_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_memory_footprint(self):
|
def test_memory_footprint(self):
|
||||||
mem_quantized = self.quantized_model.get_memory_footprint()
|
mem_quantized = self.quantized_model.get_memory_footprint()
|
||||||
|
|
||||||
self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)
|
self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_device_and_dtype_assignment(self):
|
def test_device_and_dtype_assignment(self):
|
||||||
r"""
|
r"""
|
||||||
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
|
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
|
||||||
Checks also if other models are casted correctly.
|
Checks also if other models are casted correctly .
|
||||||
"""
|
"""
|
||||||
# This should work
|
# This should work
|
||||||
if self.device_map is None:
|
if self.device_map is None:
|
||||||
@@ -94,6 +95,7 @@ class QuarkTest(unittest.TestCase):
|
|||||||
# Tries with a `dtype``
|
# Tries with a `dtype``
|
||||||
self.quantized_model.to(torch.float16)
|
self.quantized_model.to(torch.float16)
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_original_dtype(self):
|
def test_original_dtype(self):
|
||||||
r"""
|
r"""
|
||||||
A simple test to check if the model succesfully stores the original dtype
|
A simple test to check if the model succesfully stores the original dtype
|
||||||
@@ -104,6 +106,7 @@ class QuarkTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))
|
self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def check_inference_correctness(self, model):
|
def check_inference_correctness(self, model):
|
||||||
r"""
|
r"""
|
||||||
Test the generation quality of the quantized model and see that we are matching the expected output.
|
Test the generation quality of the quantized model and see that we are matching the expected output.
|
||||||
@@ -127,6 +130,7 @@ class QuarkTest(unittest.TestCase):
|
|||||||
# Get the generation
|
# Get the generation
|
||||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_generate_quality(self):
|
def test_generate_quality(self):
|
||||||
"""
|
"""
|
||||||
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
|
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user