Fix require_read_token (#37422)

* nit

* fix

* fix
This commit is contained in:
Mohamed Mekkouri
2025-04-10 17:01:40 +02:00
committed by GitHub
parent bde41d69b4
commit 9c0c323e12
2 changed files with 7 additions and 3 deletions

View File

@@ -156,7 +156,7 @@ class DonutSwinImageClassifierOutput(ModelOutput):
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None logits: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None

View File

@@ -57,7 +57,6 @@ class QuarkTest(unittest.TestCase):
EXPECTED_RELATIVE_DIFFERENCE = 1.66 EXPECTED_RELATIVE_DIFFERENCE = 1.66
device_map = None device_map = None
@require_read_token
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
""" """
@@ -76,11 +75,13 @@ class QuarkTest(unittest.TestCase):
device_map=cls.device_map, device_map=cls.device_map,
) )
@require_read_token
def test_memory_footprint(self): def test_memory_footprint(self):
mem_quantized = self.quantized_model.get_memory_footprint() mem_quantized = self.quantized_model.get_memory_footprint()
self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE) self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)
@require_read_token
def test_device_and_dtype_assignment(self): def test_device_and_dtype_assignment(self):
r""" r"""
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error. Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
@@ -94,6 +95,7 @@ class QuarkTest(unittest.TestCase):
# Tries with a `dtype` # Tries with a `dtype`
self.quantized_model.to(torch.float16) self.quantized_model.to(torch.float16)
@require_read_token
def test_original_dtype(self): def test_original_dtype(self):
r""" r"""
A simple test to check if the model successfully stores the original dtype A simple test to check if the model successfully stores the original dtype
@@ -104,6 +106,7 @@ class QuarkTest(unittest.TestCase):
self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear)) self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))
@require_read_token
def check_inference_correctness(self, model): def check_inference_correctness(self, model):
r""" r"""
Test the generation quality of the quantized model and see that we are matching the expected output. Test the generation quality of the quantized model and see that we are matching the expected output.
@@ -127,6 +130,7 @@ class QuarkTest(unittest.TestCase):
# Get the generation # Get the generation
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_read_token
def test_generate_quality(self): def test_generate_quality(self):
""" """
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens Simple test to check the quality of the model by comparing the generated tokens with the expected tokens