Improve perceiver (#14750)
* First draft * Improve docstring + clean up tests * Remove unused code * Add check in case one doesn't provide a preprocessor
This commit is contained in:
@@ -42,7 +42,8 @@ class PerceiverConfig(PretrainedConfig):
|
|||||||
d_latents (:obj:`int`, `optional`, defaults to 1280):
|
d_latents (:obj:`int`, `optional`, defaults to 1280):
|
||||||
Dimension of the latent embeddings.
|
Dimension of the latent embeddings.
|
||||||
d_model (:obj:`int`, `optional`, defaults to 768):
|
d_model (:obj:`int`, `optional`, defaults to 768):
|
||||||
Dimension of the inputs.
|
Dimension of the inputs. Should only be provided in case [`PerceiverTextPreprocessor`] is used or no
|
||||||
|
preprocessor is provided.
|
||||||
num_blocks (:obj:`int`, `optional`, defaults to 1):
|
num_blocks (:obj:`int`, `optional`, defaults to 1):
|
||||||
Number of blocks in the Transformer encoder.
|
Number of blocks in the Transformer encoder.
|
||||||
num_self_attends_per_block (:obj:`int`, `optional`, defaults to 26):
|
num_self_attends_per_block (:obj:`int`, `optional`, defaults to 26):
|
||||||
|
|||||||
@@ -499,7 +499,7 @@ class PerceiverLayer(nn.Module):
|
|||||||
class PerceiverEncoder(nn.Module):
|
class PerceiverEncoder(nn.Module):
|
||||||
"""The Perceiver Encoder: a scalable, fully attentional encoder."""
|
"""The Perceiver Encoder: a scalable, fully attentional encoder."""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config, kv_dim=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@@ -523,7 +523,7 @@ class PerceiverEncoder(nn.Module):
|
|||||||
v_channels=config.v_channels,
|
v_channels=config.v_channels,
|
||||||
num_heads=config.num_cross_attention_heads,
|
num_heads=config.num_cross_attention_heads,
|
||||||
q_dim=config.d_latents,
|
q_dim=config.d_latents,
|
||||||
kv_dim=config.d_model,
|
kv_dim=kv_dim,
|
||||||
widening_factor=config.cross_attention_widening_factor,
|
widening_factor=config.cross_attention_widening_factor,
|
||||||
use_query_residual=config.use_query_residual,
|
use_query_residual=config.use_query_residual,
|
||||||
)
|
)
|
||||||
@@ -734,7 +734,9 @@ class PerceiverModel(PerceiverPreTrainedModel):
|
|||||||
self.input_preprocessor = input_preprocessor
|
self.input_preprocessor = input_preprocessor
|
||||||
self.output_postprocessor = output_postprocessor
|
self.output_postprocessor = output_postprocessor
|
||||||
self.embeddings = PerceiverEmbeddings(config)
|
self.embeddings = PerceiverEmbeddings(config)
|
||||||
self.encoder = PerceiverEncoder(config)
|
self.encoder = PerceiverEncoder(
|
||||||
|
config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
|
||||||
|
)
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
@@ -782,16 +784,13 @@ class PerceiverModel(PerceiverPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
modality_sizes = None
|
modality_sizes = None
|
||||||
inputs_without_pos = None
|
inputs_without_pos = None
|
||||||
|
if inputs.size()[-1] != self.config.d_model:
|
||||||
|
raise ValueError(
|
||||||
|
f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
|
||||||
|
"Make sure to set config.d_model appropriately."
|
||||||
|
)
|
||||||
|
|
||||||
if inputs.size()[-1] != self.config.d_model:
|
batch_size, seq_length, _ = inputs.size()
|
||||||
raise ValueError(
|
|
||||||
f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
|
|
||||||
"Please update config.d_model appropriately."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
input_shape = inputs.size()
|
|
||||||
|
|
||||||
batch_size, seq_length, _ = input_shape
|
|
||||||
device = inputs.device
|
device = inputs.device
|
||||||
|
|
||||||
# If no attention mask is provided, make them all ones
|
# If no attention mask is provided, make them all ones
|
||||||
@@ -874,20 +873,22 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
text_preprocessor = PerceiverTextPreprocessor(config)
|
||||||
|
|
||||||
trainable_position_encoding_kwargs_decoder = dict(
|
trainable_position_encoding_kwargs_decoder = dict(
|
||||||
num_channels=config.d_model, index_dims=config.max_position_embeddings
|
num_channels=text_preprocessor.num_channels, index_dims=config.max_position_embeddings
|
||||||
)
|
)
|
||||||
|
|
||||||
self.perceiver = PerceiverModel(
|
self.perceiver = PerceiverModel(
|
||||||
config,
|
config,
|
||||||
input_preprocessor=PerceiverTextPreprocessor(config),
|
input_preprocessor=text_preprocessor,
|
||||||
decoder=PerceiverBasicDecoder(
|
decoder=PerceiverBasicDecoder(
|
||||||
config,
|
config,
|
||||||
output_num_channels=config.d_latents,
|
output_num_channels=config.d_latents,
|
||||||
output_index_dims=config.max_position_embeddings, # we need to define the seq_len of the inputs beforehand
|
output_index_dims=config.max_position_embeddings, # we need to define the seq_len of the inputs beforehand
|
||||||
num_channels=config.d_model,
|
num_channels=text_preprocessor.num_channels,
|
||||||
qk_channels=8 * 32,
|
qk_channels=8 * 32,
|
||||||
v_channels=config.d_model,
|
v_channels=text_preprocessor.num_channels,
|
||||||
num_heads=8,
|
num_heads=8,
|
||||||
use_query_residual=False,
|
use_query_residual=False,
|
||||||
final_project=False,
|
final_project=False,
|
||||||
@@ -1502,22 +1503,24 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
|
|||||||
concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False
|
concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
image_preprocessor = PerceiverImagePreprocessor(
|
||||||
|
config,
|
||||||
|
prep_type="patches",
|
||||||
|
spatial_downsample=1,
|
||||||
|
conv_after_patching=True,
|
||||||
|
conv_after_patching_in_channels=54,
|
||||||
|
temporal_downsample=2,
|
||||||
|
position_encoding_type="fourier",
|
||||||
|
# position_encoding_kwargs
|
||||||
|
fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
|
||||||
|
)
|
||||||
|
|
||||||
self.perceiver = PerceiverModel(
|
self.perceiver = PerceiverModel(
|
||||||
config,
|
config,
|
||||||
input_preprocessor=PerceiverImagePreprocessor(
|
input_preprocessor=image_preprocessor,
|
||||||
config,
|
|
||||||
prep_type="patches",
|
|
||||||
spatial_downsample=1,
|
|
||||||
conv_after_patching=True,
|
|
||||||
conv_after_patching_in_channels=54,
|
|
||||||
temporal_downsample=2,
|
|
||||||
position_encoding_type="fourier",
|
|
||||||
# position_encoding_kwargs
|
|
||||||
fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
|
|
||||||
),
|
|
||||||
decoder=PerceiverOpticalFlowDecoder(
|
decoder=PerceiverOpticalFlowDecoder(
|
||||||
config,
|
config,
|
||||||
num_channels=config.d_model,
|
num_channels=image_preprocessor.num_channels,
|
||||||
output_image_shape=config.train_size,
|
output_image_shape=config.train_size,
|
||||||
rescale_factor=100.0,
|
rescale_factor=100.0,
|
||||||
# decoder kwargs
|
# decoder kwargs
|
||||||
@@ -2631,6 +2634,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
|
self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
|
||||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
|
||||||
|
|
||||||
|
|||||||
@@ -147,19 +147,14 @@ class PerceiverModelTester:
|
|||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
elif model_class.__name__ == "PerceiverForImageClassificationLearned":
|
elif model_class.__name__ == "PerceiverForImageClassificationLearned":
|
||||||
config.d_model = 512
|
|
||||||
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
elif model_class.__name__ == "PerceiverForImageClassificationFourier":
|
elif model_class.__name__ == "PerceiverForImageClassificationFourier":
|
||||||
config.d_model = 261
|
|
||||||
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
elif model_class.__name__ == "PerceiverForImageClassificationConvProcessing":
|
elif model_class.__name__ == "PerceiverForImageClassificationConvProcessing":
|
||||||
config.d_model = 322
|
|
||||||
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
elif model_class.__name__ == "PerceiverForOpticalFlow":
|
elif model_class.__name__ == "PerceiverForOpticalFlow":
|
||||||
config.d_model = 322
|
|
||||||
inputs = floats_tensor([self.batch_size, 2, 27, self.train_size[0], self.train_size[1]])
|
inputs = floats_tensor([self.batch_size, 2, 27, self.train_size[0], self.train_size[1]])
|
||||||
elif model_class.__name__ == "PerceiverForMultimodalAutoencoding":
|
elif model_class.__name__ == "PerceiverForMultimodalAutoencoding":
|
||||||
config.d_model = 409
|
|
||||||
images = torch.randn(
|
images = torch.randn(
|
||||||
(self.batch_size, self.num_frames, self.num_channels, self.image_size, self.image_size),
|
(self.batch_size, self.num_frames, self.num_channels, self.image_size, self.image_size),
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
@@ -211,8 +206,6 @@ class PerceiverModelTester:
|
|||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels):
|
def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels):
|
||||||
# set num_labels
|
|
||||||
config.num_labels = self.num_labels
|
|
||||||
model = PerceiverForSequenceClassification(config=config)
|
model = PerceiverForSequenceClassification(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -222,9 +215,6 @@ class PerceiverModelTester:
|
|||||||
def create_and_check_for_image_classification_learned(
|
def create_and_check_for_image_classification_learned(
|
||||||
self, config, inputs, input_mask, sequence_labels, token_labels
|
self, config, inputs, input_mask, sequence_labels, token_labels
|
||||||
):
|
):
|
||||||
# set d_model and num_labels
|
|
||||||
config.d_model = 512
|
|
||||||
config.num_labels = self.num_labels
|
|
||||||
model = PerceiverForImageClassificationLearned(config=config)
|
model = PerceiverForImageClassificationLearned(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -234,9 +224,6 @@ class PerceiverModelTester:
|
|||||||
def create_and_check_for_image_classification_fourier(
|
def create_and_check_for_image_classification_fourier(
|
||||||
self, config, inputs, input_mask, sequence_labels, token_labels
|
self, config, inputs, input_mask, sequence_labels, token_labels
|
||||||
):
|
):
|
||||||
# set d_model and num_labels
|
|
||||||
config.d_model = 261
|
|
||||||
config.num_labels = self.num_labels
|
|
||||||
model = PerceiverForImageClassificationFourier(config=config)
|
model = PerceiverForImageClassificationFourier(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -246,9 +233,6 @@ class PerceiverModelTester:
|
|||||||
def create_and_check_for_image_classification_conv(
|
def create_and_check_for_image_classification_conv(
|
||||||
self, config, inputs, input_mask, sequence_labels, token_labels
|
self, config, inputs, input_mask, sequence_labels, token_labels
|
||||||
):
|
):
|
||||||
# set d_model and num_labels
|
|
||||||
config.d_model = 322
|
|
||||||
config.num_labels = self.num_labels
|
|
||||||
model = PerceiverForImageClassificationConvProcessing(config=config)
|
model = PerceiverForImageClassificationConvProcessing(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|||||||
Reference in New Issue
Block a user