From 5ba2dbd9b1058c824909c188b1b90a2e362ba96d Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 12 Dec 2022 15:37:43 +0100 Subject: [PATCH] Fix `AutoModelTest.test_model_from_pretrained` (#20730) Co-authored-by: ydshieh --- tests/models/auto/test_modeling_auto.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index ca45ae78ec..c59abe4cd4 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -22,7 +22,7 @@ from pathlib import Path import pytest -from transformers import BertConfig, GPT2Model, is_torch_available +from transformers import BertConfig, GPT2Model, is_safetensors_available, is_torch_available from transformers.models.auto.configuration_auto import CONFIG_MAPPING from transformers.testing_utils import ( DUMMY_UNKNOWN_IDENTIFIER, @@ -102,7 +102,10 @@ class AutoModelTest(unittest.TestCase): self.assertIsInstance(model, BertModel) self.assertEqual(len(loading_info["missing_keys"]), 0) - self.assertEqual(len(loading_info["unexpected_keys"]), 8) + # When using PyTorch checkpoint, the expected value is `8`. With `safetensors` checkpoint (if it is + # installed), the expected value becomes `7`. + EXPECTED_NUM_OF_UNEXPECTED_KEYS = 7 if is_safetensors_available() else 8 + self.assertEqual(len(loading_info["unexpected_keys"]), EXPECTED_NUM_OF_UNEXPECTED_KEYS) self.assertEqual(len(loading_info["mismatched_keys"]), 0) self.assertEqual(len(loading_info["error_msgs"]), 0)