From e926ea2bdde094905d15bb512d4d18667948b24f Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Mon, 13 Dec 2021 18:46:49 +0100 Subject: [PATCH] Improve perceiver (#14750) * First draft * Improve docstring + clean up tests * Remove unused code * Add check in case one doesn't provide a preprocessor --- .../perceiver/configuration_perceiver.py | 3 +- .../models/perceiver/modeling_perceiver.py | 60 ++++++++++--------- tests/test_modeling_perceiver.py | 16 ----- 3 files changed, 34 insertions(+), 45 deletions(-) diff --git a/src/transformers/models/perceiver/configuration_perceiver.py b/src/transformers/models/perceiver/configuration_perceiver.py index 16bbb6994d..849f2413de 100644 --- a/src/transformers/models/perceiver/configuration_perceiver.py +++ b/src/transformers/models/perceiver/configuration_perceiver.py @@ -42,7 +42,8 @@ class PerceiverConfig(PretrainedConfig): d_latents (:obj:`int`, `optional`, defaults to 1280): Dimension of the latent embeddings. d_model (:obj:`int`, `optional`, defaults to 768): - Dimension of the inputs. + Dimension of the inputs. Should only be provided in case [`PerceiverTextPreprocessor`] is used or no + preprocessor is provided. num_blocks (:obj:`int`, `optional`, defaults to 1): Number of blocks in the Transformer encoder. num_self_attends_per_block (:obj:`int`, `optional`, defaults to 26): diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index 5ab02adb9e..c0d3a69b35 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -499,7 +499,7 @@ class PerceiverLayer(nn.Module): class PerceiverEncoder(nn.Module): """The Perceiver Encoder: a scalable, fully attentional encoder.""" - def __init__(self, config): + def __init__(self, config, kv_dim=None): super().__init__() self.config = config @@ -523,7 +523,7 @@ class PerceiverEncoder(nn.Module): v_channels=config.v_channels, num_heads=config.num_cross_attention_heads, q_dim=config.d_latents, - kv_dim=config.d_model, + kv_dim=kv_dim, widening_factor=config.cross_attention_widening_factor, use_query_residual=config.use_query_residual, ) @@ -734,7 +734,9 @@ class PerceiverModel(PerceiverPreTrainedModel): self.input_preprocessor = input_preprocessor self.output_postprocessor = output_postprocessor self.embeddings = PerceiverEmbeddings(config) - self.encoder = PerceiverEncoder(config) + self.encoder = PerceiverEncoder( + config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model + ) self.decoder = decoder # Initialize weights and apply final processing @@ -782,16 +784,13 @@ class PerceiverModel(PerceiverPreTrainedModel): else: modality_sizes = None inputs_without_pos = None + if inputs.size()[-1] != self.config.d_model: + raise ValueError( + f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. " + "Make sure to set config.d_model appropriately." + ) - if inputs.size()[-1] != self.config.d_model: - raise ValueError( - f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. " - "Please update config.d_model appropriately." - ) - else: - input_shape = inputs.size() - - batch_size, seq_length, _ = input_shape + batch_size, seq_length, _ = inputs.size() device = inputs.device # If no attention mask is provided, make them all ones @@ -874,20 +873,22 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel): def __init__(self, config): super().__init__(config) + text_preprocessor = PerceiverTextPreprocessor(config) + trainable_position_encoding_kwargs_decoder = dict( - num_channels=config.d_model, index_dims=config.max_position_embeddings + num_channels=text_preprocessor.num_channels, index_dims=config.max_position_embeddings ) self.perceiver = PerceiverModel( config, - input_preprocessor=PerceiverTextPreprocessor(config), + input_preprocessor=text_preprocessor, decoder=PerceiverBasicDecoder( config, output_num_channels=config.d_latents, output_index_dims=config.max_position_embeddings, # we need to define the seq_len of the inputs beforehand - num_channels=config.d_model, + num_channels=text_preprocessor.num_channels, qk_channels=8 * 32, - v_channels=config.d_model, + v_channels=text_preprocessor.num_channels, num_heads=8, use_query_residual=False, final_project=False, @@ -1502,22 +1503,24 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel): concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False ) + image_preprocessor = PerceiverImagePreprocessor( + config, + prep_type="patches", + spatial_downsample=1, + conv_after_patching=True, + conv_after_patching_in_channels=54, + temporal_downsample=2, + position_encoding_type="fourier", + # position_encoding_kwargs + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ) + self.perceiver = PerceiverModel( config, - input_preprocessor=PerceiverImagePreprocessor( - config, - prep_type="patches", - spatial_downsample=1, - conv_after_patching=True, - conv_after_patching_in_channels=54, - temporal_downsample=2, - position_encoding_type="fourier", - # position_encoding_kwargs - fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, - ), + input_preprocessor=image_preprocessor, decoder=PerceiverOpticalFlowDecoder( config, - num_channels=config.d_model, + num_channels=image_preprocessor.num_channels, output_image_shape=config.train_size, rescale_factor=100.0, # decoder kwargs @@ -2631,6 +2634,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor): def __init__(self, config): super().__init__() + self.config = config self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model) diff --git a/tests/test_modeling_perceiver.py b/tests/test_modeling_perceiver.py index f0a8b1b181..d6fba44c58 100644 --- a/tests/test_modeling_perceiver.py +++ b/tests/test_modeling_perceiver.py @@ -147,19 +147,14 @@ class PerceiverModelTester: if self.use_input_mask: input_mask = random_attention_mask([self.batch_size, self.seq_length]) elif model_class.__name__ == "PerceiverForImageClassificationLearned": - config.d_model = 512 inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) elif model_class.__name__ == "PerceiverForImageClassificationFourier": - config.d_model = 261 inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) elif model_class.__name__ == "PerceiverForImageClassificationConvProcessing": - config.d_model = 322 inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) elif model_class.__name__ == "PerceiverForOpticalFlow": - config.d_model = 322 inputs = floats_tensor([self.batch_size, 2, 27, self.train_size[0], self.train_size[1]]) elif model_class.__name__ == "PerceiverForMultimodalAutoencoding": - config.d_model = 409 images = torch.randn( (self.batch_size, self.num_frames, self.num_channels, self.image_size, self.image_size), device=torch_device, @@ -211,8 +206,6 @@ class PerceiverModelTester: self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels): - # set num_labels - config.num_labels = self.num_labels model = PerceiverForSequenceClassification(config=config) model.to(torch_device) model.eval() @@ -222,9 +215,6 @@ class PerceiverModelTester: def create_and_check_for_image_classification_learned( self, config, inputs, input_mask, sequence_labels, token_labels ): - # set d_model and num_labels - config.d_model = 512 - config.num_labels = self.num_labels model = PerceiverForImageClassificationLearned(config=config) model.to(torch_device) model.eval() @@ -234,9 +224,6 @@ class PerceiverModelTester: def create_and_check_for_image_classification_fourier( self, config, inputs, input_mask, sequence_labels, token_labels ): - # set d_model and num_labels - config.d_model = 261 - config.num_labels = self.num_labels model = PerceiverForImageClassificationFourier(config=config) model.to(torch_device) model.eval() @@ -246,9 +233,6 @@ class PerceiverModelTester: def create_and_check_for_image_classification_conv( self, config, inputs, input_mask, sequence_labels, token_labels ): - # set d_model and num_labels - config.d_model = 322 - config.num_labels = self.num_labels model = PerceiverForImageClassificationConvProcessing(config=config) model.to(torch_device) model.eval()