[bnb] Let's warn users when saving 8-bit models (#20282)
* add warning on 8-bit models - added tests - added wrapper * move to a private attribute - remove wrapper - changed `save_pretrained` method * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix suggestions Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1538,6 +1538,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
# Checks if the model has been loaded in 8-bit
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
warnings.warn(
|
||||
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
|
||||
" behaviors. ",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
if "save_config" in kwargs:
|
||||
warnings.warn(
|
||||
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
|
||||
@@ -2340,6 +2348,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
load_in_8bit=load_in_8bit,
|
||||
)
|
||||
|
||||
cls.is_loaded_in_8bit = load_in_8bit
|
||||
|
||||
# make sure token embedding weights are still tied if needed
|
||||
model.tie_weights()
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
@@ -107,6 +108,13 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_warns_save_pretrained(self):
|
||||
r"""
|
||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||
"""
|
||||
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
|
||||
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user