Update form pretrained to make TP a first class citizen (#36335)

* clean code

* oups

* fix merge

* yups

* fix if

* now you can play

* fix shape issue

* try non blocking

* fix

* updates

* up

* updates

* fix most of thetests

* update

* update

* small updates

* up

* fix the remaining bug?

* update

* rename when you read from the file

* buffer issues

* current status

* cleanup

* properly allocate dumb memory

* update a small bug

* fix colwise rep issue

* fix keep in float 32 that was keeping everything in float 32

* typo

* more fixes with keep_in_fp32_modules as we use to serach on it

* fix ROPE dtype for TP

* remove what's breaking the tests

* updates

* update and fixes

* small cleanup after merging

* allocate 2x to be safe

* style, auto

* update

* yup nit

* fix

* remove slow as fuck torch api :(

* work

* fixup

* update

* brting the fix back

* fix and update

* fixes

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* updates because some suggestions were wrong 👀

* update?

* fuck this bloated function

* typo

* fix the dumb prefix thing once and forall

* fixes here and there

* updates

* remove prints

* fix strict cases

* styel

* properly fix keys on load!

* update

* fix base model prefix issue

* style

* update

* fix all?

* remoce 1 print

* fix the final etsts

* fixup

* last nits

* fix the detach issue which cause a 2x slowdown

* fixup

* small fixes

* ultra nit

* fix

* fix

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Arthur
2025-02-26 20:12:38 +01:00
committed by GitHub
parent 981c276a02
commit 1603018e7a
36 changed files with 442 additions and 340 deletions

View File

@@ -525,12 +525,13 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
# TODO @ARTHURZUCKER FIX THIS
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
# LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
# model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
# self.assertEqual(model.language_model.dtype, torch.float32)
# self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
# self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError):
@@ -540,6 +541,7 @@ class ModelUtilsTest(TestCasePlus):
)
@require_torch
@unittest.skip("Broken by @arthurzucker because the fix was not correct. Knowing the context is super hard")
def test_model_from_pretrained_meta_device(self):
def is_on_meta(model_id, dtype):
with torch.device("meta"):