Support PeftModel signature inspect (#27865)
* Support PeftModel signature inspect * Use get_base_model() to get the base model --------- Co-authored-by: shujunhua1 <shujunhua1@jd.com>
This commit is contained in:
@@ -695,7 +695,10 @@ class Trainer:
|
|||||||
def _set_signature_columns_if_needed(self):
|
def _set_signature_columns_if_needed(self):
|
||||||
if self._signature_columns is None:
|
if self._signature_columns is None:
|
||||||
# Inspect model forward signature to keep only the arguments it accepts.
|
# Inspect model forward signature to keep only the arguments it accepts.
|
||||||
signature = inspect.signature(self.model.forward)
|
model_to_inspect = self.model
|
||||||
|
if is_peft_available() and isinstance(self.model, PeftModel):
|
||||||
|
model_to_inspect = self.model.get_base_model()
|
||||||
|
signature = inspect.signature(model_to_inspect.forward)
|
||||||
self._signature_columns = list(signature.parameters.keys())
|
self._signature_columns = list(signature.parameters.keys())
|
||||||
# Labels may be named label or label_ids, the default data collator handles that.
|
# Labels may be named label or label_ids, the default data collator handles that.
|
||||||
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
|
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
|
||||||
|
|||||||
Reference in New Issue
Block a user