Put Falcon back (#25960)
* Put Falcon back * Update src/transformers/models/auto/configuration_auto.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update test --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -32,7 +32,12 @@ from ...utils import (
|
|||||||
logging,
|
logging,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
)
|
)
|
||||||
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
|
from .configuration_auto import (
|
||||||
|
AutoConfig,
|
||||||
|
model_type_to_module_name,
|
||||||
|
replace_list_option_in_docstrings,
|
||||||
|
sanitize_code_revision,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -465,6 +470,9 @@ class _BaseAutoModelClass:
|
|||||||
code_revision = kwargs.pop("code_revision", None)
|
code_revision = kwargs.pop("code_revision", None)
|
||||||
commit_hash = kwargs.pop("_commit_hash", None)
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
|
||||||
|
revision = hub_kwargs.pop("revision", None)
|
||||||
|
hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code)
|
||||||
|
|
||||||
token = hub_kwargs.pop("token", None)
|
token = hub_kwargs.pop("token", None)
|
||||||
use_auth_token = hub_kwargs.pop("use_auth_token", None)
|
use_auth_token = hub_kwargs.pop("use_auth_token", None)
|
||||||
if use_auth_token is not None:
|
if use_auth_token is not None:
|
||||||
|
|||||||
@@ -1016,8 +1016,13 @@ class AutoConfig:
|
|||||||
kwargs["name_or_path"] = pretrained_model_name_or_path
|
kwargs["name_or_path"] = pretrained_model_name_or_path
|
||||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||||
code_revision = kwargs.pop("code_revision", None)
|
code_revision = kwargs.pop("code_revision", None)
|
||||||
|
revision = kwargs.pop("revision", None)
|
||||||
|
|
||||||
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
revision = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code)
|
||||||
|
|
||||||
|
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(
|
||||||
|
pretrained_model_name_or_path, revision=revision, **kwargs
|
||||||
|
)
|
||||||
has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
|
has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
|
||||||
has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
|
has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
|
||||||
trust_remote_code = resolve_trust_remote_code(
|
trust_remote_code = resolve_trust_remote_code(
|
||||||
@@ -1064,3 +1069,24 @@ class AutoConfig:
|
|||||||
"match!"
|
"match!"
|
||||||
)
|
)
|
||||||
CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
|
CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code):
    """
    Pin the original Falcon checkpoints to a revision that does not leverage remote code.

    The pin is applied only when the caller did not ask for a specific revision (`revision` is `"main"` or
    `None`) and did not opt in to remote code (`trust_remote_code` is falsy). In every other case the incoming
    `revision` is returned untouched.

    Args:
        pretrained_model_name_or_path: Identifier of the checkpoint being loaded. Only `str` values are looked
            up, and the lookup is case-insensitive; any other type is never pinned.
        revision: Revision requested by the caller.
        trust_remote_code: Whether the caller opted in to remote code; `None` is treated like `False`.

    Returns:
        The pinned commit hash for the known Falcon checkpoints, otherwise `revision` unchanged.
    """
    # An explicit revision or an explicit opt-in to remote code is always respected.
    if trust_remote_code or revision not in ["main", None]:
        return revision

    # Non-string identifiers can never match one of the known repository ids below.
    if not isinstance(pretrained_model_name_or_path, str):
        return revision

    revision_dict = {
        "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76",
        "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28",
        "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232",
        "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5",
    }

    # Single case-insensitive lookup instead of a membership test followed by a second `.get()`.
    pinned_revision = revision_dict.get(pretrained_model_name_or_path.lower())
    if pinned_revision is None:
        return revision

    revision = pinned_revision
    logger.warning(
        "The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the "
        f"transformers library implementation. {pretrained_model_name_or_path}'s revision is set to a version that doesn't "
        f"leverage remote code ({revision}).\n\nIn order to override this, please set a revision manually or set "
        "`trust_remote_code=True`."
    )

    return revision
|
||||||
|
|||||||
@@ -13,8 +13,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Falcon configuration"""
|
""" Falcon configuration"""
|
||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
from ..auto.configuration_auto import sanitize_code_revision
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -189,3 +193,26 @@ class FalconConfig(PretrainedConfig):
|
|||||||
)
|
)
|
||||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||||
|
|
||||||
|
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: Optional[Union[str, bool]] = None,
    revision: str = "main",
    **kwargs,
) -> "PretrainedConfig":
    """
    Load the configuration through the parent class after passing `revision` through
    `sanitize_code_revision`.

    All arguments are forwarded unchanged to the parent `from_pretrained`, except `revision`, which is
    replaced by whatever `sanitize_code_revision` returns for this checkpoint.
    """
    # `trust_remote_code` is only read here, not popped, so it is still forwarded through `kwargs`.
    remote_code_flag = kwargs.get("trust_remote_code")
    resolved_revision = sanitize_code_revision(pretrained_model_name_or_path, revision, remote_code_flag)

    loader_options = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "local_files_only": local_files_only,
        "token": token,
        "revision": resolved_revision,
    }
    return super().from_pretrained(pretrained_model_name_or_path, **loader_options, **kwargs)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
"""PyTorch Falcon model."""
|
"""PyTorch Falcon model."""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -32,6 +33,7 @@ from ...modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||||
|
from ..auto.configuration_auto import sanitize_code_revision
|
||||||
from .configuration_falcon import FalconConfig
|
from .configuration_falcon import FalconConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -723,6 +725,37 @@ class FalconPreTrainedModel(PreTrainedModel):
|
|||||||
for layer_past in past_key_value
|
for layer_past in past_key_value
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    *model_args,
    config: Optional[Union[str, os.PathLike]] = None,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    ignore_mismatched_sizes: bool = False,
    force_download: bool = False,
    local_files_only: bool = False,
    token: Optional[Union[str, bool]] = None,
    revision: str = "main",
    use_safetensors: bool = None,
    **kwargs,
):
    """
    Load the model through the parent class after passing `revision` through `sanitize_code_revision`.

    Positional and keyword arguments are forwarded unchanged to the parent `from_pretrained`, except
    `revision`, which is replaced by whatever `sanitize_code_revision` returns for this checkpoint.
    """
    # `trust_remote_code` is only read here, not popped, so it is still forwarded through `kwargs`.
    remote_code_flag = kwargs.get("trust_remote_code")
    resolved_revision = sanitize_code_revision(pretrained_model_name_or_path, revision, remote_code_flag)

    loader_options = {
        "config": config,
        "cache_dir": cache_dir,
        "ignore_mismatched_sizes": ignore_mismatched_sizes,
        "force_download": force_download,
        "local_files_only": local_files_only,
        "token": token,
        "revision": resolved_revision,
        "use_safetensors": use_safetensors,
    }
    return super().from_pretrained(pretrained_model_name_or_path, *model_args, **loader_options, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
|
"The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ import unittest
|
|||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import AutoTokenizer, FalconConfig, is_torch_available, set_seed
|
from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device
|
||||||
|
from transformers.utils import logging as transformers_logging
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -500,3 +501,132 @@ class FalconLanguageGenerationTest(unittest.TestCase):
|
|||||||
outputs_no_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=False)
|
outputs_no_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=False)
|
||||||
outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
|
outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
|
||||||
self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)
|
self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO Lysandre: Remove this in version v4.34
class FalconOverrideTest(unittest.TestCase):
    """
    Tests for the revision override applied to the original Falcon checkpoints.

    Without `trust_remote_code=True` the loaders are expected to log the override warning and to return the
    same configuration whether `trust_remote_code` is `False` or left unset. With `trust_remote_code=True`
    (optionally with an explicit revision) the configuration is expected to report one of the legacy
    `RefinedWeb*` model types.

    Every checkpoint is checked inside its own `subTest`, so a failure names the offending checkpoint and the
    remaining checkpoints are still exercised.
    """

    supported_checkpoints = [
        "tiiuae/falcon-7b",
        "tiiuae/falcon-7b-instruct",
        "tiiuae/falcon-40b",
        "tiiuae/falcon-40b-instruct",
    ]

    latest_revisions = {
        "tiiuae/falcon-7b": "f7796529e36b2d49094450fb038cc7c4c86afa44",
        "tiiuae/falcon-7b-instruct": "eb410fb6ffa9028e97adb801f0d6ec46d02f8b07",
        "tiiuae/falcon-40b": "561820f7eef0cc56a31ea38af15ca1acb07fab5d",
        "tiiuae/falcon-40b-instruct": "ca78eac0ed45bf64445ff0687fabba1598daebf3",
    }

    # Name of the logger whose output is captured when checking for the override warning.
    override_logger_name = "transformers.models.auto.configuration_auto"

    # Leading part of the warning expected whenever the revision is overridden.
    expected_warning = (
        "The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the "
        "transformers library implementation."
    )

    # Model types accepted when the checkpoint is loaded with `trust_remote_code=True`.
    remote_code_model_types = ["RefinedWebModel", "RefinedWeb"]

    def test_config_without_remote_code(self):
        logger_ = transformers_logging.get_logger(self.override_logger_name)

        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                with CaptureLogger(logger_) as cm:
                    config1 = FalconConfig.from_pretrained(supported_checkpoint, trust_remote_code=False)
                    config2 = FalconConfig.from_pretrained(supported_checkpoint)

                self.assertIn(self.expected_warning, cm.out)
                self.assertEqual(config1.to_dict(), config2.to_dict())

    def test_auto_config_without_remote_code(self):
        logger_ = transformers_logging.get_logger(self.override_logger_name)

        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                with CaptureLogger(logger_) as cm:
                    config1 = AutoConfig.from_pretrained(supported_checkpoint, trust_remote_code=False)
                    config2 = AutoConfig.from_pretrained(supported_checkpoint)

                self.assertIn(self.expected_warning, cm.out)
                self.assertEqual(config1.to_dict(), config2.to_dict())

    def test_config_with_remote_code(self):
        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                config = FalconConfig.from_pretrained(supported_checkpoint, trust_remote_code=True)

                self.assertIn(config.model_type, self.remote_code_model_types)

    def test_auto_config_with_remote_code(self):
        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                config = AutoConfig.from_pretrained(supported_checkpoint, trust_remote_code=True)

                self.assertIn(config.model_type, self.remote_code_model_types)

    def test_config_with_specific_revision(self):
        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                config = FalconConfig.from_pretrained(
                    supported_checkpoint, revision=self.latest_revisions[supported_checkpoint], trust_remote_code=True
                )

                self.assertIn(config.model_type, self.remote_code_model_types)

    def test_auto_config_with_specific_revision(self):
        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                config = AutoConfig.from_pretrained(
                    supported_checkpoint, revision=self.latest_revisions[supported_checkpoint], trust_remote_code=True
                )

                self.assertIn(config.model_type, self.remote_code_model_types)

    @tooslow
    def test_model_without_remote_code(self):
        logger_ = transformers_logging.get_logger(self.override_logger_name)

        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                with CaptureLogger(logger_) as cm:
                    config1 = FalconModel.from_pretrained(supported_checkpoint, trust_remote_code=False).config
                    config2 = FalconModel.from_pretrained(supported_checkpoint).config

                # trust_remote_code only works with Auto Classes !
                config3 = FalconModel.from_pretrained(supported_checkpoint, trust_remote_code=True).config

                self.assertIn(self.expected_warning, cm.out)
                self.assertEqual(config1.to_dict(), config2.to_dict())
                self.assertEqual(config1.to_dict(), config3.to_dict())

    @tooslow
    def test_auto_model_without_remote_code(self):
        logger_ = transformers_logging.get_logger(self.override_logger_name)

        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                with CaptureLogger(logger_) as cm:
                    config1 = AutoModel.from_pretrained(supported_checkpoint, trust_remote_code=False).config
                    config2 = AutoModel.from_pretrained(supported_checkpoint).config

                self.assertIn(self.expected_warning, cm.out)
                self.assertEqual(config1.to_dict(), config2.to_dict())

    @tooslow
    def test_auto_model_with_remote_code(self):
        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                config = AutoModel.from_pretrained(supported_checkpoint, trust_remote_code=True).config

                self.assertIn(config.model_type, self.remote_code_model_types)

    @tooslow
    def test_auto_model_with_specific_revision(self):
        for supported_checkpoint in self.supported_checkpoints:
            with self.subTest(checkpoint=supported_checkpoint):
                config = AutoModel.from_pretrained(
                    supported_checkpoint, revision=self.latest_revisions[supported_checkpoint], trust_remote_code=True
                ).config

                self.assertIn(config.model_type, self.remote_code_model_types)
|
||||||
|
|||||||
Reference in New Issue
Block a user