Put Falcon back (#25960)
* Put Falcon back * Update src/transformers/models/auto/configuration_auto.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update test --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -19,8 +19,9 @@ import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, FalconConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device
|
||||
from transformers.utils import logging as transformers_logging
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -500,3 +501,132 @@ class FalconLanguageGenerationTest(unittest.TestCase):
|
||||
outputs_no_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=False)
|
||||
outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
|
||||
self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)
|
||||
|
||||
|
||||
# TODO Lysandre: Remove this in version v4.34
|
||||
class FalconOverrideTest(unittest.TestCase):
|
||||
supported_checkpoints = [
|
||||
"tiiuae/falcon-7b",
|
||||
"tiiuae/falcon-7b-instruct",
|
||||
"tiiuae/falcon-40b",
|
||||
"tiiuae/falcon-40b-instruct",
|
||||
]
|
||||
|
||||
latest_revisions = {
|
||||
"tiiuae/falcon-7b": "f7796529e36b2d49094450fb038cc7c4c86afa44",
|
||||
"tiiuae/falcon-7b-instruct": "eb410fb6ffa9028e97adb801f0d6ec46d02f8b07",
|
||||
"tiiuae/falcon-40b": "561820f7eef0cc56a31ea38af15ca1acb07fab5d",
|
||||
"tiiuae/falcon-40b-instruct": "ca78eac0ed45bf64445ff0687fabba1598daebf3",
|
||||
}
|
||||
|
||||
def test_config_without_remote_code(self):
|
||||
logger_ = transformers_logging.get_logger("transformers.models.auto.configuration_auto")
|
||||
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
with CaptureLogger(logger_) as cm:
|
||||
config1 = FalconConfig.from_pretrained(supported_checkpoint, trust_remote_code=False)
|
||||
config2 = FalconConfig.from_pretrained(supported_checkpoint)
|
||||
|
||||
self.assertIn(
|
||||
"The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the "
|
||||
"transformers library implementation.",
|
||||
cm.out,
|
||||
)
|
||||
|
||||
self.assertEqual(config1.to_dict(), config2.to_dict())
|
||||
|
||||
def test_auto_config_without_remote_code(self):
|
||||
logger_ = transformers_logging.get_logger("transformers.models.auto.configuration_auto")
|
||||
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
with CaptureLogger(logger_) as cm:
|
||||
config1 = AutoConfig.from_pretrained(supported_checkpoint, trust_remote_code=False)
|
||||
config2 = AutoConfig.from_pretrained(supported_checkpoint)
|
||||
|
||||
self.assertIn(
|
||||
"The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the "
|
||||
"transformers library implementation.",
|
||||
cm.out,
|
||||
)
|
||||
|
||||
self.assertEqual(config1.to_dict(), config2.to_dict())
|
||||
|
||||
def test_config_with_remote_code(self):
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
config = FalconConfig.from_pretrained(supported_checkpoint, trust_remote_code=True)
|
||||
|
||||
self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"])
|
||||
|
||||
def test_auto_config_with_remote_code(self):
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
config = AutoConfig.from_pretrained(supported_checkpoint, trust_remote_code=True)
|
||||
|
||||
self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"])
|
||||
|
||||
def test_config_with_specific_revision(self):
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
config = FalconConfig.from_pretrained(
|
||||
supported_checkpoint, revision=self.latest_revisions[supported_checkpoint], trust_remote_code=True
|
||||
)
|
||||
|
||||
self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"])
|
||||
|
||||
def test_auto_config_with_specific_revision(self):
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
config = AutoConfig.from_pretrained(
|
||||
supported_checkpoint, revision=self.latest_revisions[supported_checkpoint], trust_remote_code=True
|
||||
)
|
||||
|
||||
self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"])
|
||||
|
||||
@tooslow
|
||||
def test_model_without_remote_code(self):
|
||||
logger_ = transformers_logging.get_logger("transformers.models.auto.configuration_auto")
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
with CaptureLogger(logger_) as cm:
|
||||
config1 = FalconModel.from_pretrained(supported_checkpoint, trust_remote_code=False).config
|
||||
config2 = FalconModel.from_pretrained(supported_checkpoint).config
|
||||
|
||||
# trust_remote_code only works with Auto Classes !
|
||||
config3 = FalconModel.from_pretrained(supported_checkpoint, trust_remote_code=True).config
|
||||
|
||||
self.assertIn(
|
||||
"The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the "
|
||||
"transformers library implementation.",
|
||||
cm.out,
|
||||
)
|
||||
|
||||
self.assertEqual(config1.to_dict(), config2.to_dict())
|
||||
self.assertEqual(config1.to_dict(), config3.to_dict())
|
||||
|
||||
@tooslow
|
||||
def test_auto_model_without_remote_code(self):
|
||||
logger_ = transformers_logging.get_logger("transformers.models.auto.configuration_auto")
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
with CaptureLogger(logger_) as cm:
|
||||
config1 = AutoModel.from_pretrained(supported_checkpoint, trust_remote_code=False).config
|
||||
config2 = AutoModel.from_pretrained(supported_checkpoint).config
|
||||
|
||||
self.assertIn(
|
||||
"The Falcon model was initialized without `trust_remote_code=True`, and will therefore leverage the "
|
||||
"transformers library implementation.",
|
||||
cm.out,
|
||||
)
|
||||
|
||||
self.assertEqual(config1.to_dict(), config2.to_dict())
|
||||
|
||||
@tooslow
|
||||
def test_auto_model_with_remote_code(self):
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
config = AutoModel.from_pretrained(supported_checkpoint, trust_remote_code=True).config
|
||||
|
||||
self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"])
|
||||
|
||||
@tooslow
|
||||
def test_auto_model_with_specific_revision(self):
|
||||
for supported_checkpoint in self.supported_checkpoints:
|
||||
config = AutoModel.from_pretrained(
|
||||
supported_checkpoint, revision=self.latest_revisions[supported_checkpoint], trust_remote_code=True
|
||||
).config
|
||||
|
||||
self.assertIn(config.model_type, ["RefinedWebModel", "RefinedWeb"])
|
||||
|
||||
Reference in New Issue
Block a user