[bnb] Let's warn users when saving 8-bit models (#20282)
* add warning on 8-bit models - added tests - added wrapper * move to a private attribute - remove wrapper - changed `save_pretrained` method * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix suggestions Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1538,6 +1538,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
kwargs:
|
kwargs:
|
||||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
"""
|
"""
|
||||||
|
# Checks if the model has been loaded in 8-bit
|
||||||
|
if getattr(self, "is_loaded_in_8bit", False):
|
||||||
|
warnings.warn(
|
||||||
|
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
|
||||||
|
" behaviors. ",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
if "save_config" in kwargs:
|
if "save_config" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
|
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
|
||||||
@@ -2340,6 +2348,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
load_in_8bit=load_in_8bit,
|
load_in_8bit=load_in_8bit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cls.is_loaded_in_8bit = load_in_8bit
|
||||||
|
|
||||||
# make sure token embedding weights are still tied if needed
|
# make sure token embedding weights are still tied if needed
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import gc
|
import gc
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -107,6 +108,13 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
|
|
||||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
def test_warns_save_pretrained(self):
|
||||||
|
r"""
|
||||||
|
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||||
|
"""
|
||||||
|
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
self.model_8bit.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
|
||||||
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user