Improve mismatched sizes management when loading a pretrained model (#17257)
- Add --ignore_mismatched_sizes argument to classification examples - Expand the error message when loading a model whose head dimensions are different from expected dimensions
This commit is contained in:
@@ -22,7 +22,7 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/
|
||||
|
||||
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
||||
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
|
||||
and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
|
||||
and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
|
||||
(the script might need some tweaks in that case, refer to the comments inside for help).
|
||||
|
||||
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
|
||||
@@ -79,6 +79,8 @@ python run_glue.py \
|
||||
--output_dir /tmp/imdb/
|
||||
```
|
||||
|
||||
> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
|
||||
|
||||
|
||||
### Mixed precision training
|
||||
|
||||
|
||||
@@ -196,6 +196,10 @@ class ModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
ignore_mismatched_sizes: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -364,6 +368,7 @@ def main():
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
# Preprocessing the raw_datasets
|
||||
|
||||
@@ -170,6 +170,11 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore_mismatched_sizes",
|
||||
action="store_true",
|
||||
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
@@ -288,6 +293,7 @@ def main():
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
ignore_mismatched_sizes=args.ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
# Preprocessing the datasets
|
||||
|
||||
@@ -162,6 +162,10 @@ class ModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
ignore_mismatched_sizes: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -291,6 +295,7 @@ def main():
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
# Preprocessing the datasets
|
||||
|
||||
Reference in New Issue
Block a user