From 2aef2f0bbcd3b192af18718684615019a7777a9b Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 11 Nov 2019 20:03:19 -0500 Subject: [PATCH] [common attributes] Fix previous commit for transfo-xl --- transformers/__init__.py | 1 + transformers/tests/modeling_common_test.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/transformers/__init__.py b/transformers/__init__.py index 53f3c39dc7..d922f52a1d 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -72,6 +72,7 @@ if is_torch_available(): OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, + AdaptiveEmbedding, load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, diff --git a/transformers/tests/modeling_common_test.py b/transformers/tests/modeling_common_test.py index 777e62459b..baf1531403 100644 --- a/transformers/tests/modeling_common_test.py +++ b/transformers/tests/modeling_common_test.py @@ -35,7 +35,7 @@ if is_torch_available(): import torch import numpy as np - from transformers import (PretrainedConfig, PreTrainedModel, + from transformers import (AdaptiveEmbedding, PretrainedConfig, PreTrainedModel, BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) else: @@ -470,7 +470,7 @@ class CommonTestCases: model = model_class(config) self.assertIsInstance( model.get_input_embeddings(), - torch.nn.Embedding + (torch.nn.Embedding, AdaptiveEmbedding) ) model.set_input_embeddings(torch.nn.Embedding(10, 10)) x = model.get_output_embeddings()