Merge branch 'master' into resumable_http
This commit is contained in:
@@ -53,7 +53,7 @@ class PreTrainedModel(nn.Module):
|
||||
r""" Base class for all models.
|
||||
|
||||
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
||||
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
|
||||
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
|
||||
|
||||
Class attributes (overridden by derived classes):
|
||||
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||
@@ -83,6 +83,94 @@ class PreTrainedModel(nn.Module):
|
||||
# Save config in model
|
||||
self.config = config
|
||||
|
||||
@property
|
||||
def base_model(self):
|
||||
return getattr(self, self.base_model_prefix, self)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
""" Get model's input embeddings
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
return base_model.get_input_embeddings()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
""" Set model's input embeddings
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
base_model.set_input_embeddings(value)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_embeddings(self):
|
||||
""" Get model's output embeddings
|
||||
Return None if the model doesn't have output embeddings
|
||||
"""
|
||||
return None # Overwrite for models with output embeddings
|
||||
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
if output_embeddings is not None:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
||||
""" Tie or clone module weights depending of weither we are using TorchScript or not
|
||||
"""
|
||||
if self.config.torchscript:
|
||||
output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
|
||||
else:
|
||||
output_embeddings.weight = input_embeddings.weight
|
||||
|
||||
if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None:
|
||||
output_embeddings.bias.data = torch.nn.functional.pad(
|
||||
output_embeddings.bias.data,
|
||||
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
|
||||
'constant',
|
||||
0
|
||||
)
|
||||
if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'):
|
||||
output_embeddings.out_features = input_embeddings.num_embeddings
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens=None):
|
||||
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
|
||||
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
|
||||
|
||||
Arguments:
|
||||
|
||||
new_num_tokens: (`optional`) int:
|
||||
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
|
||||
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
|
||||
|
||||
Return: ``torch.nn.Embeddings``
|
||||
Pointer to the input tokens Embeddings Module of the model
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
|
||||
if new_num_tokens is None:
|
||||
return model_embeds
|
||||
|
||||
# Update base model and current model config
|
||||
self.config.vocab_size = new_num_tokens
|
||||
base_model.vocab_size = new_num_tokens
|
||||
|
||||
# Tie weights again if needed
|
||||
if hasattr(self, 'tie_weights'):
|
||||
self.tie_weights()
|
||||
|
||||
return model_embeds
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.get_input_embeddings()
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.set_input_embeddings(new_embeddings)
|
||||
return self.get_input_embeddings()
|
||||
|
||||
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
|
||||
""" Build a resized Embedding Module from a provided token Embedding Module.
|
||||
Increasing the size will add newly initialized vectors at the end
|
||||
@@ -117,50 +205,6 @@ class PreTrainedModel(nn.Module):
|
||||
|
||||
return new_embeddings
|
||||
|
||||
def _tie_or_clone_weights(self, first_module, second_module):
|
||||
""" Tie or clone module weights depending of weither we are using TorchScript or not
|
||||
"""
|
||||
if self.config.torchscript:
|
||||
first_module.weight = nn.Parameter(second_module.weight.clone())
|
||||
else:
|
||||
first_module.weight = second_module.weight
|
||||
|
||||
if hasattr(first_module, 'bias') and first_module.bias is not None:
|
||||
first_module.bias.data = torch.nn.functional.pad(
|
||||
first_module.bias.data,
|
||||
(0, first_module.weight.shape[0] - first_module.bias.shape[0]),
|
||||
'constant',
|
||||
0
|
||||
)
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens=None):
|
||||
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
|
||||
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
|
||||
|
||||
Arguments:
|
||||
|
||||
new_num_tokens: (`optional`) int:
|
||||
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
|
||||
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
|
||||
|
||||
Return: ``torch.nn.Embeddings``
|
||||
Pointer to the input tokens Embeddings Module of the model
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
|
||||
if new_num_tokens is None:
|
||||
return model_embeds
|
||||
|
||||
# Update base model and current model config
|
||||
self.config.vocab_size = new_num_tokens
|
||||
base_model.vocab_size = new_num_tokens
|
||||
|
||||
# Tie weights again if needed
|
||||
if hasattr(self, 'tie_weights'):
|
||||
self.tie_weights()
|
||||
|
||||
return model_embeds
|
||||
|
||||
def init_weights(self):
|
||||
""" Initialize and prunes weights if needed. """
|
||||
# Initialize weights
|
||||
@@ -170,6 +214,9 @@ class PreTrainedModel(nn.Module):
|
||||
if self.config.pruned_heads:
|
||||
self.prune_heads(self.config.pruned_heads)
|
||||
|
||||
# Tie weights if needed
|
||||
self.tie_weights()
|
||||
|
||||
def prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the base model.
|
||||
|
||||
@@ -178,14 +225,12 @@ class PreTrainedModel(nn.Module):
|
||||
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
|
||||
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
|
||||
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
|
||||
for layer, heads in heads_to_prune.items():
|
||||
union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
|
||||
self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
|
||||
|
||||
base_model._prune_heads(heads_to_prune)
|
||||
self.base_model._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a model and its configuration file to a directory, so that it
|
||||
@@ -193,7 +238,7 @@ class PreTrainedModel(nn.Module):
|
||||
"""
|
||||
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
|
||||
|
||||
# Only save the model it-self if we are using distributed training
|
||||
# Only save the model itself if we are using distributed training
|
||||
model_to_save = self.module if hasattr(self, 'module') else self
|
||||
|
||||
# Save configuration file
|
||||
@@ -273,6 +318,10 @@ class PreTrainedModel(nn.Module):
|
||||
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
|
||||
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
|
||||
"https://github.com/google-research/google-research/issues/119 for more information.")
|
||||
|
||||
config = kwargs.pop('config', None)
|
||||
state_dict = kwargs.pop('state_dict', None)
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
@@ -289,6 +338,7 @@ class PreTrainedModel(nn.Module):
|
||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
@@ -389,6 +439,8 @@ class PreTrainedModel(nn.Module):
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module, prefix=''):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
|
||||
Reference in New Issue
Block a user