Fix model kwargs (#35875)

* Save state

* Make a failing test

* Better test

* mpt -> done, many more to go

* Rm extranious

* Bamba

* Bert

* big_bird

* biogpt

* bloom

* codegen

* ctrl

* data2vec

* dbrx

* Through up to Dbrx

* electra

* ernie

* falcon

* Fuyu/persimmon

* Include noop kwargs to base models

* Rebase

* Skip musigen

* Refactor/skip mllama

* Revert makefile

* Rm file

* Fix PT failing, need to modify rest of loss funcs to not resize

* Propagate some

* Continue

* More

* More options

* Mostly fixed

* Proved that it's the same

* Bloom is good

* Make ability to override loss func possible

* Fixup

* Clean

* Fix xglm

* Quality tests

* Skip OCR2

* Make specific loss for xglm

* Make order the same/line up 1:1

* xglm

* Skip fx output loss bloom model

* Didn't pass in pad_token_id

* Fix quality
This commit is contained in:
Zach Mueller
2025-02-06 11:35:25 -05:00
committed by GitHub
parent 1590c66430
commit 28f73bc307
48 changed files with 365 additions and 241 deletions

View File

@@ -922,6 +922,42 @@ class ModelTesterMixin:
loss = model(**inputs).loss
loss.backward()
def test_causal_lm_can_accept_kwargs(self):
if not getattr(self.model_tester, "is_training", False):
self.skipTest(reason="ModelTester is not configured to run training tests")
valid_model_class = False
incompatible_models = (
"MusicgenForCausalLM",
"MusicgenMelodyForCausalLM",
"MllamaForCausalLM",
"CpmAntForCausalLM",
"GotOcr2ForConditionalGeneration",
)
for model_class in self.all_model_classes:
if (
model_class.__name__ in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
and model_class.__name__ not in incompatible_models
):
valid_model_class = True
if not valid_model_class:
self.skipTest(reason="No causal lm model classes found")
for model_class in self.all_model_classes:
model_name = model_class.__name__
if model_name in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) and model_name not in incompatible_models:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as tmpdir:
with torch.device(torch_device):
model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
model_eager.save_pretrained(tmpdir)
with torch.device(torch_device):
model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32)
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
inputs_dict["labels"] = inputs_dict["input_ids"]
_ = model(**inputs_dict, return_dict=False)
def test_training_gradient_checkpointing(self):
# Scenario - 1 default behaviour
self.check_training_gradient_checkpointing()
@@ -1236,6 +1272,8 @@ class ModelTesterMixin:
self._create_and_check_torch_fx_tracing(config, inputs_dict)
def test_torch_fx_output_loss(self):
if self.all_model_classes[0].__name__ == "BloomModel":
self.skipTest(reason="Bloom currently has issues, @michaelbenayoun")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)