[style] consistent nn. and nn.functional: part 3 tests (#12155)
* consistent nn. and nn.functional: p3 templates * restore
This commit is contained in:
@@ -44,6 +44,7 @@ from transformers.testing_utils import (
|
||||
if is_torch_available():
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
@@ -1150,10 +1151,10 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Embedding, AdaptiveEmbedding))
|
||||
model.set_input_embeddings(torch.nn.Embedding(10, 10))
|
||||
self.assertIsInstance(model.get_input_embeddings(), (nn.Embedding, AdaptiveEmbedding))
|
||||
model.set_input_embeddings(nn.Embedding(10, 10))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
|
||||
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
||||
|
||||
def test_correct_missing_keys(self):
|
||||
if not self.test_missing_keys:
|
||||
@@ -1337,7 +1338,7 @@ class ModelTesterMixin:
|
||||
model.eval()
|
||||
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
model = nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user