[common attributes] Fix previous commit for transfo-xl

This commit is contained in:
Julien Chaumond
2019-11-11 20:03:19 -05:00
parent 2f17464266
commit 2aef2f0bbc
2 changed files with 3 additions and 2 deletions

View File

@@ -72,6 +72,7 @@ if is_torch_available():
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
AdaptiveEmbedding,
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model, from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2DoubleHeadsModel,

View File

@@ -35,7 +35,7 @@ if is_torch_available():
import torch import torch
import numpy as np import numpy as np
from transformers import (PretrainedConfig, PreTrainedModel, from transformers import (AdaptiveEmbedding, PretrainedConfig, PreTrainedModel,
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
else: else:
@@ -470,7 +470,7 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
self.assertIsInstance( self.assertIsInstance(
model.get_input_embeddings(), model.get_input_embeddings(),
torch.nn.Embedding (torch.nn.Embedding, AdaptiveEmbedding)
) )
model.set_input_embeddings(torch.nn.Embedding(10, 10)) model.set_input_embeddings(torch.nn.Embedding(10, 10))
x = model.get_output_embeddings() x = model.get_output_embeddings()