Improve perceiver (#14750)

* First draft

* Improve docstring + clean up tests

* Remove unused code

* Add check in case one doesn't provide a preprocessor
This commit is contained in:
NielsRogge
2021-12-13 18:46:49 +01:00
committed by GitHub
parent 971e36667a
commit e926ea2bdd
3 changed files with 34 additions and 45 deletions

View File

@@ -147,19 +147,14 @@ class PerceiverModelTester:
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
elif model_class.__name__ == "PerceiverForImageClassificationLearned":
config.d_model = 512
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
elif model_class.__name__ == "PerceiverForImageClassificationFourier":
config.d_model = 261
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
elif model_class.__name__ == "PerceiverForImageClassificationConvProcessing":
config.d_model = 322
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
elif model_class.__name__ == "PerceiverForOpticalFlow":
config.d_model = 322
inputs = floats_tensor([self.batch_size, 2, 27, self.train_size[0], self.train_size[1]])
elif model_class.__name__ == "PerceiverForMultimodalAutoencoding":
config.d_model = 409
images = torch.randn(
(self.batch_size, self.num_frames, self.num_channels, self.image_size, self.image_size),
device=torch_device,
@@ -211,8 +206,6 @@ class PerceiverModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels):
# set num_labels
config.num_labels = self.num_labels
model = PerceiverForSequenceClassification(config=config)
model.to(torch_device)
model.eval()
@@ -222,9 +215,6 @@ class PerceiverModelTester:
def create_and_check_for_image_classification_learned(
self, config, inputs, input_mask, sequence_labels, token_labels
):
# set d_model and num_labels
config.d_model = 512
config.num_labels = self.num_labels
model = PerceiverForImageClassificationLearned(config=config)
model.to(torch_device)
model.eval()
@@ -234,9 +224,6 @@ class PerceiverModelTester:
def create_and_check_for_image_classification_fourier(
self, config, inputs, input_mask, sequence_labels, token_labels
):
# set d_model and num_labels
config.d_model = 261
config.num_labels = self.num_labels
model = PerceiverForImageClassificationFourier(config=config)
model.to(torch_device)
model.eval()
@@ -246,9 +233,6 @@ class PerceiverModelTester:
def create_and_check_for_image_classification_conv(
self, config, inputs, input_mask, sequence_labels, token_labels
):
# set d_model and num_labels
config.d_model = 322
config.num_labels = self.num_labels
model = PerceiverForImageClassificationConvProcessing(config=config)
model.to(torch_device)
model.eval()