config.architectures
This commit is contained in:
committed by
Lysandre Debut
parent
f9bc3f5771
commit
b85c59f997
@@ -82,6 +82,7 @@ class PretrainedConfig(object):
|
|||||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||||
|
|
||||||
# Fine-tuning task arguments
|
# Fine-tuning task arguments
|
||||||
|
self.architectures = kwargs.pop("architectures", None)
|
||||||
self.finetuning_task = kwargs.pop("finetuning_task", None)
|
self.finetuning_task = kwargs.pop("finetuning_task", None)
|
||||||
self.num_labels = kwargs.pop("num_labels", 2)
|
self.num_labels = kwargs.pop("num_labels", 2)
|
||||||
self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
|
self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
|
||||||
|
|||||||
@@ -284,6 +284,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# Only save the model itself if we are using distributed training
|
# Only save the model itself if we are using distributed training
|
||||||
model_to_save = self.module if hasattr(self, "module") else self
|
model_to_save = self.module if hasattr(self, "module") else self
|
||||||
|
|
||||||
|
# Attach architecture to the config
|
||||||
|
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
||||||
|
|
||||||
# Save configuration file
|
# Save configuration file
|
||||||
model_to_save.config.save_pretrained(save_directory)
|
model_to_save.config.save_pretrained(save_directory)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user