[common attributes] Fix previous commit for transfo-xl

This commit is contained in:
Julien Chaumond
2019-11-11 20:03:19 -05:00
parent 2f17464266
commit 2aef2f0bbc
2 changed files with 3 additions and 2 deletions

View File

@@ -35,7 +35,7 @@ if is_torch_available():
import torch
import numpy as np
from transformers import (PretrainedConfig, PreTrainedModel,
from transformers import (AdaptiveEmbedding, PretrainedConfig, PreTrainedModel,
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
@@ -470,7 +470,7 @@ class CommonTestCases:
model = model_class(config)
self.assertIsInstance(
model.get_input_embeddings(),
torch.nn.Embedding
(torch.nn.Embedding, AdaptiveEmbedding)
)
model.set_input_embeddings(torch.nn.Embedding(10, 10))
x = model.get_output_embeddings()