[common attributes] Fix previous commit for transfo-xl
This commit is contained in:
@@ -72,6 +72,7 @@ if is_torch_available():
|
|||||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
||||||
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
|
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
|
||||||
|
AdaptiveEmbedding,
|
||||||
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
|
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
|
||||||
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import (PretrainedConfig, PreTrainedModel,
|
from transformers import (AdaptiveEmbedding, PretrainedConfig, PreTrainedModel,
|
||||||
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
else:
|
else:
|
||||||
@@ -470,7 +470,7 @@ class CommonTestCases:
|
|||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
self.assertIsInstance(
|
self.assertIsInstance(
|
||||||
model.get_input_embeddings(),
|
model.get_input_embeddings(),
|
||||||
torch.nn.Embedding
|
(torch.nn.Embedding, AdaptiveEmbedding)
|
||||||
)
|
)
|
||||||
model.set_input_embeddings(torch.nn.Embedding(10, 10))
|
model.set_input_embeddings(torch.nn.Embedding(10, 10))
|
||||||
x = model.get_output_embeddings()
|
x = model.get_output_embeddings()
|
||||||
|
|||||||
Reference in New Issue
Block a user