fix (#6946)
This commit is contained in:
committed by
GitHub
parent
a75e319819
commit
e3990d137a
@@ -120,8 +120,8 @@ class LxmertModelTester:
|
|||||||
|
|
||||||
output_attentions = self.output_attentions
|
output_attentions = self.output_attentions
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], vocab_size=self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], vocab_size=self.vocab_size)
|
||||||
visual_feats = torch.rand(self.batch_size, self.num_visual_features, self.visual_feat_dim)
|
visual_feats = torch.rand(self.batch_size, self.num_visual_features, self.visual_feat_dim, device=torch_device)
|
||||||
bounding_boxes = torch.rand(self.batch_size, self.num_visual_features, 4)
|
bounding_boxes = torch.rand(self.batch_size, self.num_visual_features, 4, device=torch_device)
|
||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_lang_mask:
|
if self.use_lang_mask:
|
||||||
@@ -407,8 +407,8 @@ class LxmertModelTester:
|
|||||||
num_small_labels = int(config.num_qa_labels * 2)
|
num_small_labels = int(config.num_qa_labels * 2)
|
||||||
less_labels_ans = ids_tensor([self.batch_size], num_small_labels)
|
less_labels_ans = ids_tensor([self.batch_size], num_small_labels)
|
||||||
more_labels_ans = ids_tensor([self.batch_size], num_large_labels)
|
more_labels_ans = ids_tensor([self.batch_size], num_large_labels)
|
||||||
model_pretrain = LxmertForPreTraining(config=config)
|
model_pretrain = LxmertForPreTraining(config=config).to(torch_device)
|
||||||
model_qa = LxmertForQuestionAnswering(config=config)
|
model_qa = LxmertForQuestionAnswering(config=config).to(torch_device)
|
||||||
config.num_labels = num_small_labels
|
config.num_labels = num_small_labels
|
||||||
end_labels = config.num_labels
|
end_labels = config.num_labels
|
||||||
|
|
||||||
@@ -560,6 +560,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
model = LxmertModel.from_pretrained(model_name)
|
model = LxmertModel.from_pretrained(model_name)
|
||||||
|
model.to(torch_device)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user