[bnb] Let's warn users when saving 8-bit models (#20282)

* add warning on 8-bit models

- added tests
- added wrapper

* move to a private attribute

- remove wrapper
- changed `save_pretrained` method

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix suggestions

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2022-11-17 08:16:36 +01:00
committed by GitHub
parent 0a144b8c6b
commit 7d65efec29
2 changed files with 18 additions and 0 deletions

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from transformers import (
@@ -107,6 +108,13 @@ class MixedInt8Test(BaseMixedInt8Test):
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_warns_save_pretrained(self):
r"""
Test whether trying to save a model after converting it in 8-bit will throw a warning.
"""
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
self.model_8bit.save_pretrained(tmpdirname)
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):