Use Python 3.9 syntax in tests (#37343)

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-08 20:12:08 +08:00
committed by GitHub
parent 0fc683d1cd
commit 1e6b546ea6
666 changed files with 265 additions and 946 deletions

View File

@@ -430,8 +430,7 @@ class PipelineUtilsTest(unittest.TestCase):
from transformers.pipelines.pt_utils import PipelineIterator
def dummy_dataset():
for i in range(4):
yield i
yield from range(4)
def add(number, extra=0):
return number + extra
@@ -480,8 +479,7 @@ class PipelineUtilsTest(unittest.TestCase):
from transformers.pipelines.pt_utils import PipelineChunkIterator
def preprocess_chunk(n: int):
for i in range(n):
yield i
yield from range(n)
dataset = [2, 3]

View File

@@ -14,7 +14,6 @@
import tempfile
import unittest
from typing import Dict
import datasets
import numpy as np
@@ -65,14 +64,14 @@ def hashimage(image: Image) -> str:
return m.hexdigest()[:10]
def mask_to_test_readable(mask: Image) -> Dict:
def mask_to_test_readable(mask: Image) -> dict:
npimg = np.array(mask)
white_pixels = (npimg == 255).sum()
shape = npimg.shape
return {"hash": hashimage(mask), "white_pixels": white_pixels, "shape": shape}
def mask_to_test_readable_only_shape(mask: Image) -> Dict:
def mask_to_test_readable_only_shape(mask: Image) -> dict:
npimg = np.array(mask)
shape = npimg.shape
return {"shape": shape}

View File

@@ -13,7 +13,6 @@
# limitations under the License.
import unittest
from typing import Dict
import numpy as np
from huggingface_hub.utils import insecure_hashlib
@@ -50,7 +49,7 @@ def hashimage(image: Image) -> str:
return m.hexdigest()[:10]
def mask_to_test_readable(mask: Image) -> Dict:
def mask_to_test_readable(mask: Image) -> dict:
npimg = np.array(mask)
shape = npimg.shape
return {"hash": hashimage(mask), "shape": shape}
@@ -60,11 +59,9 @@ def mask_to_test_readable(mask: Image) -> Dict:
@require_vision
@require_torch
class MaskGenerationPipelineTests(unittest.TestCase):
model_mapping = dict(
(list(MODEL_FOR_MASK_GENERATION_MAPPING.items()) if MODEL_FOR_MASK_GENERATION_MAPPING else [])
)
model_mapping = dict(list(MODEL_FOR_MASK_GENERATION_MAPPING.items()) if MODEL_FOR_MASK_GENERATION_MAPPING else [])
tf_model_mapping = dict(
(list(TF_MODEL_FOR_MASK_GENERATION_MAPPING.items()) if TF_MODEL_FOR_MASK_GENERATION_MAPPING else [])
list(TF_MODEL_FOR_MASK_GENERATION_MAPPING.items()) if TF_MODEL_FOR_MASK_GENERATION_MAPPING else []
)
def get_test_pipeline(

View File

@@ -131,7 +131,7 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
model_kwargs={"torch_dtype": torch.float16},
device=torch_device,
)
self.assertEqual(vqa_pipeline.model.device, torch.device("{}:0".format(torch_device)))
self.assertEqual(vqa_pipeline.model.device, torch.device(f"{torch_device}:0"))
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16)
@@ -173,7 +173,7 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
model_kwargs={"torch_dtype": torch.float16},
device=torch_device,
)
self.assertEqual(vqa_pipeline.model.device, torch.device("{}:0".format(torch_device)))
self.assertEqual(vqa_pipeline.model.device, torch.device(f"{torch_device}:0"))
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"