Improve perceiver (#14750)
* First draft * Improve docstring + clean up tests * Remove unused code * Add check in case one doesn't provide a preprocessor
This commit is contained in:
@@ -147,19 +147,14 @@ class PerceiverModelTester:
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
elif model_class.__name__ == "PerceiverForImageClassificationLearned":
|
||||
config.d_model = 512
|
||||
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
elif model_class.__name__ == "PerceiverForImageClassificationFourier":
|
||||
config.d_model = 261
|
||||
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
elif model_class.__name__ == "PerceiverForImageClassificationConvProcessing":
|
||||
config.d_model = 322
|
||||
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
elif model_class.__name__ == "PerceiverForOpticalFlow":
|
||||
config.d_model = 322
|
||||
inputs = floats_tensor([self.batch_size, 2, 27, self.train_size[0], self.train_size[1]])
|
||||
elif model_class.__name__ == "PerceiverForMultimodalAutoencoding":
|
||||
config.d_model = 409
|
||||
images = torch.randn(
|
||||
(self.batch_size, self.num_frames, self.num_channels, self.image_size, self.image_size),
|
||||
device=torch_device,
|
||||
@@ -211,8 +206,6 @@ class PerceiverModelTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels):
|
||||
# set num_labels
|
||||
config.num_labels = self.num_labels
|
||||
model = PerceiverForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -222,9 +215,6 @@ class PerceiverModelTester:
|
||||
def create_and_check_for_image_classification_learned(
|
||||
self, config, inputs, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
# set d_model and num_labels
|
||||
config.d_model = 512
|
||||
config.num_labels = self.num_labels
|
||||
model = PerceiverForImageClassificationLearned(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -234,9 +224,6 @@ class PerceiverModelTester:
|
||||
def create_and_check_for_image_classification_fourier(
|
||||
self, config, inputs, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
# set d_model and num_labels
|
||||
config.d_model = 261
|
||||
config.num_labels = self.num_labels
|
||||
model = PerceiverForImageClassificationFourier(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -246,9 +233,6 @@ class PerceiverModelTester:
|
||||
def create_and_check_for_image_classification_conv(
|
||||
self, config, inputs, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
# set d_model and num_labels
|
||||
config.d_model = 322
|
||||
config.num_labels = self.num_labels
|
||||
model = PerceiverForImageClassificationConvProcessing(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user