From c9cb7f8a0fbe784665b00bfdca6bfc54ad10d5f5 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 11 Nov 2019 15:12:54 -0500 Subject: [PATCH] Torch 1.1.0 compatibility + FP16 O1 + TF checkpoints Co-authored-by: wassname --- transformers/modeling_albert.py | 4 ++-- transformers/modeling_tf_albert.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/transformers/modeling_albert.py b/transformers/modeling_albert.py index 640af537d0..ff20ca78dc 100644 --- a/transformers/modeling_albert.py +++ b/transformers/modeling_albert.py @@ -203,8 +203,8 @@ class AlbertAttention(BertSelfAttention): # Should find a better way to do this - w = self.dense.weight.T.view(self.num_attention_heads, self.attention_head_size, self.hidden_size) - b = self.dense.bias + w = self.dense.weight.t().view(self.num_attention_heads, self.attention_head_size, self.hidden_size).to(context_layer.dtype) + b = self.dense.bias.to(context_layer.dtype) projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b projected_context_layer_dropout = self.dropout(projected_context_layer) diff --git a/transformers/modeling_tf_albert.py b/transformers/modeling_tf_albert.py index a3f183b192..ee8712eb28 100644 --- a/transformers/modeling_tf_albert.py +++ b/transformers/modeling_tf_albert.py @@ -36,7 +36,14 @@ import logging logger = logging.getLogger(__name__) TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = { - # TODO FILL THAT UP + 'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-tf_model.h5", + 'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-tf_model.h5", + 'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-tf_model.h5", + 'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-tf_model.h5", + 'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-tf_model.h5", + 'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-tf_model.h5", + 'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-tf_model.h5", + 'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-tf_model.h5", }