[Generate] Remove attention_mask and integrate model_main_input_name (#14856)
* up * save * correct * up * correct more * up * up * up * up * up * correct * fix tf * fix * remove tokenizer
This commit is contained in:
committed by
GitHub
parent
86b40073e9
commit
fe4197ab11
@@ -64,14 +64,7 @@ class EncoderDecoderMixin:
|
||||
pass
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained_configs(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
|
||||
):
|
||||
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||
@@ -84,7 +77,6 @@ class EncoderDecoderMixin:
|
||||
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
@@ -94,14 +86,7 @@ class EncoderDecoderMixin:
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -111,7 +96,6 @@ class EncoderDecoderMixin:
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
@@ -122,7 +106,6 @@ class EncoderDecoderMixin:
|
||||
encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
@@ -134,7 +117,6 @@ class EncoderDecoderMixin:
|
||||
def check_encoder_decoder_model_from_pretrained(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
@@ -148,7 +130,6 @@ class EncoderDecoderMixin:
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
@@ -160,14 +141,7 @@ class EncoderDecoderMixin:
|
||||
)
|
||||
|
||||
def check_save_and_load(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -176,7 +150,6 @@ class EncoderDecoderMixin:
|
||||
with torch.no_grad():
|
||||
outputs = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
@@ -190,7 +163,6 @@ class EncoderDecoderMixin:
|
||||
|
||||
after_outputs = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
@@ -200,14 +172,7 @@ class EncoderDecoderMixin:
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def check_save_and_load_encoder_decoder_model(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -216,7 +181,6 @@ class EncoderDecoderMixin:
|
||||
with torch.no_grad():
|
||||
outputs = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
@@ -233,7 +197,6 @@ class EncoderDecoderMixin:
|
||||
|
||||
after_outputs = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
@@ -245,7 +208,6 @@ class EncoderDecoderMixin:
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
@@ -261,7 +223,6 @@ class EncoderDecoderMixin:
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
@@ -382,13 +343,10 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
]
|
||||
)
|
||||
# for DEiT, the sequence length is equal to the number of patches + 2 (for the [CLS] and distillation tokens)
|
||||
seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 2
|
||||
attention_mask = random_attention_mask([batch_size, seq_len])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
@@ -398,7 +356,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
@@ -414,7 +371,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
@@ -463,7 +419,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
encoder_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
||||
config, pixel_values, _ = encoder_config_and_inputs
|
||||
input_mask = None # TODO add once attention_mask is supported for vision models
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
@@ -481,7 +436,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
return {
|
||||
"config": config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_token_type_ids": decoder_token_type_ids,
|
||||
@@ -509,13 +463,10 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
]
|
||||
)
|
||||
# for ViT, the sequence length is equal to the number of patches + 1 (for the [CLS] token)
|
||||
seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 1
|
||||
attention_mask = random_attention_mask([batch_size, seq_len])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
@@ -534,7 +485,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
||||
|
||||
config, pixel_values, _ = encoder_config_and_inputs
|
||||
input_mask = None # TODO add once attention_mask is supported for vision models
|
||||
|
||||
(
|
||||
decoder_config,
|
||||
@@ -553,7 +503,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
return {
|
||||
"config": config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_token_type_ids": decoder_token_type_ids,
|
||||
@@ -580,7 +529,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
|
||||
config, pixel_values, _ = encoder_config_and_inputs
|
||||
input_mask = None # TODO add once attention_mask is supported for vision models
|
||||
(decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
@@ -590,7 +538,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
return {
|
||||
"config": config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
|
||||
Reference in New Issue
Block a user