Fix model kwargs (#35875)
* Save state * Make a failing test * Better test * mpt -> done, many more to go * Rm extraneous * Bamba * Bert * big_bird * biogpt * bloom * codegen * ctrl * data2vec * dbrx * Through up to Dbrx * electra * ernie * falcon * Fuyu/persimmon * Include noop kwargs to base models * Rebase * Skip musicgen * Refactor/skip mllama * Revert makefile * Rm file * Fix PT failing, need to modify rest of loss funcs to not resize * Propagate some * Continue * More * More options * Mostly fixed * Proved that it's the same * Bloom is good * Make ability to override loss func possible * Fixup * Clean * Fix xglm * Quality tests * Skip OCR2 * Make specific loss for xglm * Make order the same/line up 1:1 * xglm * Skip fx output loss bloom model * Didn't pass in pad_token_id * Fix quality
This commit is contained in:
@@ -922,6 +922,42 @@ class ModelTesterMixin:
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
def test_causal_lm_can_accept_kwargs(self):
    """Smoke-test that causal LM models accept extra kwargs in their forward call.

    For every class in ``self.all_model_classes`` whose name is listed in
    ``MODEL_FOR_CAUSAL_LM_MAPPING_NAMES`` (and is not explicitly excluded
    below), a model is built from the tester's config, round-tripped through
    ``save_pretrained`` / ``from_pretrained``, and called with ``labels`` and
    ``num_items_in_batch`` added to the inputs. The test makes no assertion on
    the output; it only requires that the call does not raise.

    Skipped when the model tester is not configured for training, or when no
    eligible causal LM class is present.
    """
    if not getattr(self.model_tester, "is_training", False):
        self.skipTest(reason="ModelTester is not configured to run training tests")

    # Classes excluded from this check by name.
    incompatible_models = (
        "MusicgenForCausalLM",
        "MusicgenMelodyForCausalLM",
        "MllamaForCausalLM",
        "CpmAntForCausalLM",
        "GotOcr2ForConditionalGeneration",
    )

    # Resolve the causal LM names once and keep only the eligible ones.
    causal_lm_names = get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
    eligible_names = {
        candidate.__name__
        for candidate in self.all_model_classes
        if candidate.__name__ in causal_lm_names and candidate.__name__ not in incompatible_models
    }
    if not eligible_names:
        self.skipTest(reason="No causal lm model classes found")

    for model_class in self.all_model_classes:
        if model_class.__name__ not in eligible_names:
            continue

        # Fresh config/inputs per eligible class; the inputs dict is mutated below.
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        with tempfile.TemporaryDirectory() as tmpdir:
            with torch.device(torch_device):
                model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)

            model_eager.save_pretrained(tmpdir)

            with torch.device(torch_device):
                model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32)
                # Batch size is used as the item count; input ids double as labels.
                batch_size = inputs_dict["input_ids"].shape[0]
                inputs_dict["num_items_in_batch"] = batch_size
                inputs_dict["labels"] = inputs_dict["input_ids"]
                _ = model(**inputs_dict, return_dict=False)
|
||||
|
||||
def test_training_gradient_checkpointing(self):
|
||||
# Scenario - 1 default behaviour
|
||||
self.check_training_gradient_checkpointing()
|
||||
@@ -1236,6 +1272,8 @@ class ModelTesterMixin:
|
||||
self._create_and_check_torch_fx_tracing(config, inputs_dict)
|
||||
|
||||
def test_torch_fx_output_loss(self):
|
||||
if self.all_model_classes[0].__name__ == "BloomModel":
|
||||
self.skipTest(reason="Bloom currently has issues, @michaelbenayoun")
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user