[common attributes] Fix previous commit for transfo-xl
This commit is contained in:
@@ -35,7 +35,7 @@ if is_torch_available():
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import (PretrainedConfig, PreTrainedModel,
|
||||
from transformers import (AdaptiveEmbedding, PretrainedConfig, PreTrainedModel,
|
||||
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
else:
|
||||
@@ -470,7 +470,7 @@ class CommonTestCases:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(
|
||||
model.get_input_embeddings(),
|
||||
torch.nn.Embedding
|
||||
(torch.nn.Embedding, AdaptiveEmbedding)
|
||||
)
|
||||
model.set_input_embeddings(torch.nn.Embedding(10, 10))
|
||||
x = model.get_output_embeddings()
|
||||
|
||||
Reference in New Issue
Block a user