From 7d65efec29258ce3c97449117ecf758fbbfbd6de Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 17 Nov 2022 08:16:36 +0100 Subject: [PATCH] [bnb] Let's warn users when saving 8-bit models (#20282) * add warning on 8-bit models - added tests - added wrapper * move to a private attribute - remove wrapper - changed `save_pretrained` method * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix suggestions Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_utils.py | 10 ++++++++++ tests/mixed_int8/test_mixed_int8.py | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index aa5592c7d9..57294682a8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1538,6 +1538,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix kwargs: Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ + # Checks if the model has been loaded in 8-bit + if getattr(self, "is_loaded_in_8bit", False): + warnings.warn( + "You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected" + " behaviors. ", + UserWarning, + ) + if "save_config" in kwargs: warnings.warn( "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead." @@ -2340,6 +2348,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix load_in_8bit=load_in_8bit, ) + cls.is_loaded_in_8bit = load_in_8bit + # make sure token embedding weights are still tied if needed model.tie_weights() diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 2911d67748..a459ffa84d 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import gc +import tempfile import unittest from transformers import ( @@ -107,6 +108,13 @@ class MixedInt8Test(BaseMixedInt8Test): self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + def test_warns_save_pretrained(self): + r""" + Test whether trying to save a model after converting it in 8-bit will throw a warning. + """ + with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname: + self.model_8bit.save_pretrained(tmpdirname) + class MixedInt8ModelClassesTest(BaseMixedInt8Test): def setUp(self):