Torch 1.1.0 compatibility + FP16 O1 + TF checkpoints
Co-authored-by: wassname
This commit is contained in:
@@ -203,8 +203,8 @@ class AlbertAttention(BertSelfAttention):
|
||||
|
||||
|
||||
# Should find a better way to do this
|
||||
w = self.dense.weight.T.view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
|
||||
b = self.dense.bias
|
||||
w = self.dense.weight.t().view(self.num_attention_heads, self.attention_head_size, self.hidden_size).to(context_layer.dtype)
|
||||
b = self.dense.bias.to(context_layer.dtype)
|
||||
|
||||
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
|
||||
projected_context_layer_dropout = self.dropout(projected_context_layer)
|
||||
|
||||
@@ -36,7 +36,14 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
# TODO FILL THAT UP
|
||||
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-tf_model.h5",
|
||||
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-tf_model.h5",
|
||||
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-tf_model.h5",
|
||||
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-tf_model.h5",
|
||||
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-tf_model.h5",
|
||||
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-tf_model.h5",
|
||||
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-tf_model.h5",
|
||||
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-tf_model.h5",
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user