From 22a69f1d7d520d5fbccbdb163d05db56bf79724c Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 4 Sep 2023 14:17:09 -0400 Subject: [PATCH] Put Falcon back (#25960) * Put Falcon back * Update src/transformers/models/auto/configuration_auto.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update test --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/auto/auto_factory.py | 10 +- .../models/auto/configuration_auto.py | 28 +++- .../models/falcon/configuration_falcon.py | 27 ++++ .../models/falcon/modeling_falcon.py | 33 +++++ tests/models/falcon/test_modeling_falcon.py | 134 +++++++++++++++++- 5 files changed, 228 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index daca460ebb..f9f1abdd5a 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -32,7 +32,12 @@ from ...utils import ( logging, requires_backends, ) -from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings +from .configuration_auto import ( + AutoConfig, + model_type_to_module_name, + replace_list_option_in_docstrings, + sanitize_code_revision, +) logger = logging.get_logger(__name__) @@ -465,6 +470,9 @@ class _BaseAutoModelClass: code_revision = kwargs.pop("code_revision", None) commit_hash = kwargs.pop("_commit_hash", None) + revision = hub_kwargs.pop("revision", None) + hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code) + token = hub_kwargs.pop("token", None) use_auth_token = hub_kwargs.pop("use_auth_token", None) if use_auth_token is not None: diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 0a3effd795..cf93951c15 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -1016,8 +1016,13 @@ class AutoConfig: kwargs["name_or_path"] = pretrained_model_name_or_path trust_remote_code = kwargs.pop("trust_remote_code", None) code_revision = kwargs.pop("code_revision", None) + revision = kwargs.pop("revision", None) - config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) + revision = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code) + + config_dict, unused_kwargs = PretrainedConfig.get_config_dict( + pretrained_model_name_or_path, revision=revision, **kwargs + ) has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"] has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING trust_remote_code = resolve_trust_remote_code( @@ -1064,3 +1069,24 @@ class AutoConfig: "match!" ) CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok) + + +def sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code): + if revision in ["main", None] and not trust_remote_code: + revision_dict = { + "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76", + "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28", + "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232", + "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5", + } + + if isinstance(pretrained_model_name_or_path, str) and pretrained_model_name_or_path.lower() in revision_dict: + revision = revision_dict.get(pretrained_model_name_or_path.lower()) + logger.warning( + "The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the " + f"transformers library implementation. {pretrained_model_name_or_path}'s revision is set to a version that doesn't " + f"leverage remote code ({revision}).\n\nIn order to override this, please set a revision manually or set " + "`trust_remote_code=True`." + ) + + return revision diff --git a/src/transformers/models/falcon/configuration_falcon.py b/src/transformers/models/falcon/configuration_falcon.py index eccec82bf8..101ed0d461 100644 --- a/src/transformers/models/falcon/configuration_falcon.py +++ b/src/transformers/models/falcon/configuration_falcon.py @@ -13,8 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Falcon configuration""" +import os +from typing import Optional, Union + from ...configuration_utils import PretrainedConfig from ...utils import logging +from ..auto.configuration_auto import sanitize_code_revision logger = logging.get_logger(__name__) @@ -189,3 +193,26 @@ class FalconConfig(PretrainedConfig): ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ) -> "PretrainedConfig": + revision = sanitize_code_revision(pretrained_model_name_or_path, revision, kwargs.get("trust_remote_code")) + + return super().from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + **kwargs, + ) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 608ec3a1b0..7a6f5007d8 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -15,6 +15,7 @@ """PyTorch Falcon model.""" import math +import os from typing import Optional, Tuple, Union import torch @@ -32,6 +33,7 @@ from ...modeling_outputs import ( ) from ...modeling_utils import PreTrainedModel from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ..auto.configuration_auto import sanitize_code_revision from .configuration_falcon import FalconConfig @@ -723,6 +725,37 @@ class FalconPreTrainedModel(PreTrainedModel): for layer_past in past_key_value ) + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + **kwargs, + ): + revision = sanitize_code_revision(pretrained_model_name_or_path, revision, kwargs.get("trust_remote_code")) + + return super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + **kwargs, + ) + @add_start_docstrings( "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.", diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index fca4ea21e2..c5f7f7a8f9 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -19,8 +19,9 @@ import unittest from parameterized import parameterized -from transformers import AutoTokenizer, FalconConfig, is_torch_available, set_seed -from transformers.testing_utils import require_torch, slow, torch_device +from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed +from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device +from transformers.utils import logging as transformers_logging from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -500,3 +501,132 @@ class FalconLanguageGenerationTest(unittest.TestCase): outputs_no_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=False) outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True) self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0) + + +# TODO Lysandre: Remove this in version v4.34 +class FalconOverrideTest(unittest.TestCase): + supported_checkpoints = [ + "tiiuae/falcon-7b", + "tiiuae/falcon-7b-instruct", + "tiiuae/falcon-40b", + "tiiuae/falcon-40b-instruct", + ] + + latest_revisions = { + "tiiuae/falcon-7b": "f7796529e36b2d49094450fb038cc7c4c86afa44", + "tiiuae/falcon-7b-instruct": "eb410fb6ffa9028e97adb801f0d6ec46d02f8b07", + "tiiuae/falcon-40b": "561820f7eef0cc56a31ea38af15ca1acb07fab5d", + "tiiuae/falcon-40b-instruct": "ca78eac0ed45bf64445ff0687fabba1598daebf3", + } + + def test_config_without_remote_code(self): + logger_ = transformers_logging.get_logger("transformers.models.auto.configuration_auto") + + for supported_checkpoint in self.supported_checkpoints: + with CaptureLogger(logger_) as cm: + config1 = FalconConfig.from_pretrained(supported_checkpoint, trust_remote_code=False) + config2 = FalconConfig.from_pretrained(supported_checkpoint) + + self.assertIn( + "The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the " + "transformers library implementation.", + cm.out, + ) + + self.assertEqual(config1.to_dict(), config2.to_dict()) + + def test_auto_config_without_remote_code(self): + logger_ = transformers_logging.get_logger("transformers.models.auto.configuration_auto") + + for supported_checkpoint in self.supported_checkpoints: + with CaptureLogger(logger_) as cm: + config1 = AutoConfig.from_pretrained(supported_checkpoint, trust_remote_code=False) + config2 = AutoConfig.from_pretrained(supported_checkpoint) + + self.assertIn( + "The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the " + "transformers library implementation.", + cm.out, + ) + + self.assertEqual(config1.to_dict(), config2.to_dict()) + + def test_config_with_remote_code(self): + for supported_checkpoint in self.supported_checkpoints: + config = FalconConfig.from_pretrained(supported_checkpoint, trust_remote_code=True) + + self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"]) + + def test_auto_config_with_remote_code(self): + for supported_checkpoint in self.supported_checkpoints: + config = AutoConfig.from_pretrained(supported_checkpoint, trust_remote_code=True) + + self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"]) + + def test_config_with_specific_revision(self): + for supported_checkpoint in self.supported_checkpoints: + config = FalconConfig.from_pretrained( + supported_checkpoint, revision=self.latest_revisions[supported_checkpoint], trust_remote_code=True + ) + + self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"]) + + def test_auto_config_with_specific_revision(self): + for supported_checkpoint in self.supported_checkpoints: + config = AutoConfig.from_pretrained( + supported_checkpoint, revision=self.latest_revisions[supported_checkpoint], trust_remote_code=True + ) + + self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"]) + + @tooslow + def test_model_without_remote_code(self): + logger_ = transformers_logging.get_logger("transformers.models.auto.configuration_auto") + for supported_checkpoint in self.supported_checkpoints: + with CaptureLogger(logger_) as cm: + config1 = FalconModel.from_pretrained(supported_checkpoint, trust_remote_code=False).config + config2 = FalconModel.from_pretrained(supported_checkpoint).config + + # trust_remote_code only works with Auto Classes ! + config3 = FalconModel.from_pretrained(supported_checkpoint, trust_remote_code=True).config + + self.assertIn( + "The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the " + "transformers library implementation.", + cm.out, + ) + + self.assertEqual(config1.to_dict(), config2.to_dict()) + self.assertEqual(config1.to_dict(), config3.to_dict()) + + @tooslow + def test_auto_model_without_remote_code(self): + logger_ = transformers_logging.get_logger("transformers.models.auto.configuration_auto") + for supported_checkpoint in self.supported_checkpoints: + with CaptureLogger(logger_) as cm: + config1 = AutoModel.from_pretrained(supported_checkpoint, trust_remote_code=False).config + config2 = AutoModel.from_pretrained(supported_checkpoint).config + + self.assertIn( + "The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the " + "transformers library implementation.", + cm.out, + ) + + self.assertEqual(config1.to_dict(), config2.to_dict()) + + @tooslow + def test_auto_model_with_remote_code(self): + for supported_checkpoint in self.supported_checkpoints: + config = AutoModel.from_pretrained(supported_checkpoint, trust_remote_code=True).config + + self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"]) + + @tooslow + def test_auto_model_with_specific_revision(self): + for supported_checkpoint in self.supported_checkpoints: + config = AutoModel.from_pretrained( + supported_checkpoint, revision=self.latest_revisions[supported_checkpoint], trust_remote_code=True + ).config + + self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"])