Improve mismatched sizes management when loading a pretrained model (#17257)
- Add --ignore_mismatched_sizes argument to classification examples - Expand the error message when loading a model whose head dimensions are different from expected dimensions
This commit is contained in:
@@ -55,6 +55,8 @@ uses special features of those tokenizers. You can check if your favorite model
|
||||
[this table](https://huggingface.co/transformers/index.html#supported-frameworks), if it doesn't you can still use the old version
|
||||
of the script.
|
||||
|
||||
> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
|
||||
|
||||
## Old version of the script
|
||||
|
||||
You can find the old version of the PyTorch script [here](https://github.com/huggingface/transformers/blob/main/examples/legacy/token-classification/run_ner.py).
|
||||
|
||||
@@ -87,6 +87,10 @@ class ModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
ignore_mismatched_sizes: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -364,6 +368,7 @@ def main():
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
# Tokenizer check: this script requires a fast tokenizer.
|
||||
|
||||
@@ -223,6 +223,11 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore_mismatched_sizes",
|
||||
action="store_true",
|
||||
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
@@ -383,6 +388,7 @@ def main():
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
ignore_mismatched_sizes=args.ignore_mismatched_sizes,
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
|
||||
Reference in New Issue
Block a user