[bnb] Let's warn users when saving 8-bit models (#20282)
* add warning on 8-bit models - added tests - added wrapper * move to a private attribute - remove wrapper - changed `save_pretrained` method * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix suggestions Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
@@ -107,6 +108,13 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_warns_save_pretrained(self):
|
||||
r"""
|
||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||
"""
|
||||
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
|
||||
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user