[style] consistent nn. and nn.functional: part 3 tests (#12155)
* consistent nn. and nn.functional: p3 templates * restore
This commit is contained in:
@@ -30,6 +30,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
|
||||
from transformers.models.fsmt.modeling_fsmt import (
|
||||
@@ -160,10 +161,10 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Embedding))
|
||||
model.set_input_embeddings(torch.nn.Embedding(10, 10))
|
||||
self.assertIsInstance(model.get_input_embeddings(), (nn.Embedding))
|
||||
model.set_input_embeddings(nn.Embedding(10, 10))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, torch.nn.modules.sparse.Embedding))
|
||||
self.assertTrue(x is None or isinstance(x, nn.modules.sparse.Embedding))
|
||||
|
||||
def test_initialization_more(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user