From 28d0048218ad7bce69510b16024510afba0daed2 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 31 May 2022 10:02:55 +0200 Subject: [PATCH] Fx support for multiple model architectures (#17393) * Support for Bart and LayoutLM, and partial support for XLNet * Support for mbart * A lot of new models supported * Support for other models * LayoutLM fix * Use strings instead of classes --- src/transformers/models/bart/modeling_bart.py | 6 +- .../modeling_bigbird_pegasus.py | 2 +- .../models/blenderbot/modeling_blenderbot.py | 6 +- .../modeling_blenderbot_small.py | 6 +- src/transformers/models/clip/modeling_clip.py | 8 +- .../models/layoutlm/modeling_layoutlm.py | 2 +- .../models/m2m_100/modeling_m2m_100.py | 6 +- .../models/marian/modeling_marian.py | 6 +- .../models/mbart/modeling_mbart.py | 6 +- src/transformers/models/opt/modeling_opt.py | 6 +- .../models/pegasus/modeling_pegasus.py | 6 +- .../models/plbart/modeling_plbart.py | 6 +- .../speech_to_text/modeling_speech_to_text.py | 6 +- .../modeling_speech_to_text_2.py | 6 +- .../models/trocr/modeling_trocr.py | 6 +- src/transformers/models/xglm/modeling_xglm.py | 8 +- .../models/xlnet/modeling_xlnet.py | 2 +- src/transformers/utils/fx.py | 282 +++++++++++++----- tests/models/bart/test_modeling_bart.py | 2 + .../blenderbot/test_modeling_blenderbot.py | 1 + .../test_modeling_blenderbot_small.py | 1 + tests/models/clip/test_modeling_clip.py | 4 +- .../models/layoutlm/test_modeling_layoutlm.py | 1 + tests/models/m2m_100/test_modeling_m2m_100.py | 1 + tests/models/marian/test_modeling_marian.py | 1 + tests/models/mbart/test_modeling_mbart.py | 1 + tests/models/opt/test_modeling_opt.py | 1 + tests/models/pegasus/test_modeling_pegasus.py | 1 + tests/models/plbart/test_modeling_plbart.py | 1 + .../test_modeling_speech_to_text.py | 106 ++++++- .../test_modeling_speech_to_text_2.py | 1 + tests/models/swin/test_modeling_swin.py | 15 +- tests/models/t5/test_modeling_t5.py | 2 +- tests/models/trocr/test_modeling_trocr.py 
| 1 + tests/models/xglm/test_modeling_xglm.py | 123 +++++++- tests/models/xlnet/test_modeling_xlnet.py | 1 + tests/test_modeling_common.py | 21 +- 37 files changed, 515 insertions(+), 146 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7ebb143e22..595f719ba0 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -93,7 +93,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -114,7 +114,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) class BartLearnedPositionalEmbedding(nn.Embedding): @@ -911,7 +911,7 @@ class BartDecoder(BartPretrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index c7a84695a7..0f7bc7f599 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2112,7 +2112,7 @@ class 
BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 612685dbb4..574f4c8731 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -83,7 +83,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -105,7 +105,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) class BlenderbotLearnedPositionalEmbedding(nn.Embedding): @@ -850,7 +850,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 9b32fccc1f..c65409be55 100755 --- 
a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # Copied from transformers.models.blenderbot.modeling_blenderbot.BlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall @@ -846,7 +846,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 25137e268d..5c34c658b5 100755 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -57,7 +57,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # contrastive loss function, adapted from 
@@ -674,7 +674,7 @@ class CLIPTextTransformer(nn.Module): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(bsz, seq_len, seq_len) - mask.fill_(float("-inf")) + mask.fill_(torch.tensor(float("-inf"))) mask.triu_(1) # zero out the lower diagonal mask = mask.unsqueeze(1) # expand mask return mask @@ -1042,8 +1042,8 @@ class CLIPModel(CLIPPreTrainedModel): text_embeds = self.text_projection(text_embeds) # normalized features - image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) - text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) # cosine similarity as logits logit_scale = self.logit_scale.exp() diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 25c9db5d57..2a48ba5f4f 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -800,7 +800,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel): token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) if bbox is None: - bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device) + bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device) extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 1dc7f6144c..5ced761677 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -101,7 +101,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): @@ -998,7 +998,7 @@ class M2M100Decoder(M2M100PreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None and combined_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index b8f82275a8..04a6c2d83f 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -103,7 +103,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) class MarianSinusoidalPositionalEmbedding(nn.Embedding): @@ -856,7 +856,7 @@ class MarianDecoder(MarianPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 72ee66a45b..d7f4958a8d 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -119,7 +119,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart @@ -909,7 +909,7 @@ class MBartDecoder(MBartPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index df005e356a..8de0d1c3c2 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -61,7 +61,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -82,7 +82,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) class OPTLearnedPositionalEmbedding(nn.Embedding): @@ -513,7 +513,7 @@ class OPTDecoder(OPTPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 99ff97b269..51620bbf36 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus @@ -876,7 +876,7 @@ class PegasusDecoder(PegasusPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 8f341e6399..97bf620a9c 100755 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -116,7 +116,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart @@ -883,7 +883,7 @@ class PLBartDecoder(PLBartPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index a358b13c1f..623a6b5910 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -69,7 +69,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -91,7 +91,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) class Conv1dSubsampler(nn.Module): @@ -888,7 +888,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index 5c0ea65fcc..d90c4c87b6 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -49,7 +49,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -71,7 +71,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2 @@ -495,7 +495,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 52e4801832..3a960bb86f 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -72,7 +72,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR @@ -524,7 +524,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 4047958d4f..fa2a8c6eb6 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -142,7 +142,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): @@ -577,7 +577,7 @@ class XGLMModel(XGLMPreTrainedModel): if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -712,7 +712,7 @@ class XGLMModel(XGLMPreTrainedModel): hidden_states = inputs_embeds + positions - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training) # decoder layers all_hidden_states = () if output_hidden_states else None diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 3226773e7f..4a299a5a65 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -1056,7 +1056,6 @@ class XLNetModel(XLNetPreTrainedModel): fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len) pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz) - pos_emb = pos_emb.to(self.device) return pos_emb @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, 
sequence_length")) @@ -1206,6 +1205,7 @@ class XLNetModel(XLNetPreTrainedModel): # Positional encoding pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz) + pos_emb = pos_emb.to(output_h.device) pos_emb = self.dropout(pos_emb) # Prepare head mask if needed diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 9516253789..bbf32f9c6f 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -29,27 +29,23 @@ from torch import nn from torch.fx import Graph, GraphModule, Proxy, Tracer from torch.fx.proxy import ParameterProxy -from .. import ( - CONFIG_MAPPING, - MODEL_FOR_CAUSAL_LM_MAPPING, - MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, - MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, - MODEL_FOR_MASKED_LM_MAPPING, - MODEL_FOR_MULTIPLE_CHOICE_MAPPING, - MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, - MODEL_FOR_PRETRAINING_MAPPING, - MODEL_FOR_QUESTION_ANSWERING_MAPPING, - MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - MODEL_MAPPING, - GPT2DoubleHeadsModel, - PretrainedConfig, - PreTrainedModel, - XLNetForQuestionAnswering, - logging, -) +from .. 
import PretrainedConfig, PreTrainedModel, logging from ..models.auto import get_values +from ..models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, + MODEL_FOR_MASKED_LM_MAPPING_NAMES, + MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES, + MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES, + MODEL_FOR_PRETRAINING_MAPPING_NAMES, + MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, + MODEL_MAPPING_NAMES, +) from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available from ..utils.versions import importlib_metadata @@ -57,25 +53,25 @@ from ..utils.versions import importlib_metadata logger = logging.get_logger(__name__) -def _generate_supported_model_classes( +def _generate_supported_model_class_names( model_name: Type[PretrainedConfig], supported_tasks: Optional[Union[str, List[str]]] = None, -) -> List[Type[PreTrainedModel]]: +) -> List[str]: - model_config_class = CONFIG_MAPPING[model_name] task_mapping = { - "default": MODEL_MAPPING, - "pretraining": MODEL_FOR_PRETRAINING_MAPPING, - "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, - "masked-lm": MODEL_FOR_MASKED_LM_MAPPING, - "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING, - "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING, - "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING, - "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, - "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + "default": MODEL_MAPPING_NAMES, + "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES, + 
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES, + "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES, + "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, + "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES, + "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, + "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, + "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, + "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, + "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, } if supported_tasks is None: @@ -83,55 +79,78 @@ def _generate_supported_model_classes( if isinstance(supported_tasks, str): supported_tasks = [supported_tasks] - model_classes = [] + model_class_names = [] for task in supported_tasks: - model_class = task_mapping[task].get(model_config_class, None) - if model_class: - model_classes.append(model_class) + class_name = task_mapping[task].get(model_name, None) + if class_name: + model_class_names.append(class_name) - return model_classes + return model_class_names _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ "albert", + "bart", "bert", + "blenderbot", + "blenderbot-small", + "clip", "distilbert", - "mobilebert", "electra", - "megatron-bert", "gpt2", - "gptj", "gpt_neo", - "t5", + "gptj", + "layoutlm", + "m2m_100", + "marian", + "mbart", + "megatron-bert", + "mobilebert", + "mt5", + "opt", + "pegasus", + "plbart", "roberta", - "vit", + "speech_to_text", + "speech_to_text_2", "swin", + "t5", + "trocr", + "vit", + "xglm", # TODO: add support for them as it should be quite easy to do so (small blocking issues). 
- # "layoutlm", # "xlnet", ] _REGULAR_SUPPORTED_MODELS = [] for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS: if isinstance(item, dict): - _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item)) + _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item)) else: - _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item)) + _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item)) _SPECIAL_SUPPORTED_MODELS = [ - GPT2DoubleHeadsModel, + "CLIPTextModel", + "CLIPVisionModel", + "GPT2DoubleHeadsModel", + "Speech2Text2Decoder", + "TrOCRDecoder", # TODO: add support for them as it should be quite easy to do so (small blocking issues). # XLNetForQuestionAnswering, ] -_SUPPORTED_MODELS = tuple( - sorted(list(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)), key=lambda c: c.__name__) -) +_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS))) def torch_nn_embedding(self, input): return torch.empty(*input.shape, self.weight.shape[-1], device="meta") +def torch_nn_functional_embedding( + input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False +): + return torch.empty(*input.shape, weight.shape[-1], device="meta") + + def torch_nn_layernorm(self, input): return input @@ -176,6 +195,12 @@ def torch_arange(*args, **kwargs): start, end = args else: start, end, step = args + if isinstance(start, float): + start = int(start) + if isinstance(end, float): + end = int(end) + if isinstance(step, float): + step = int(step) step = kwargs.get("step", step) dtype = kwargs.get("dtype") return torch.empty((end - start) // step, dtype=dtype, device="meta") @@ -265,6 +290,14 @@ def torch_matmul(input, other, *, out=None): return torch.empty(*shape, device="meta") +def torch_bmm(input, mat2, *, out=None): + if out is not None: + raise ValueError("Don't support in-place bmm for MetaTensor analysis") + batch_size, n, m =
input.shape + _, _, p = mat2.shape + return torch.empty(batch_size, n, p, device="meta") + + def torch_einsum(equation, *operands): # TODO: infer shape without performing the computation, this might be quite hard. concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands) @@ -285,13 +318,39 @@ def torch_index_select(input, dim, index, *, out=None): def torch_tensor_index_select(self, dim, index): - return torch_tensor_index_select(self, dim, index) + return torch_index_select(self, dim, index) def torch_roll(input, shifts, dims=None): return input +def torch_flip(input, dims): + return input + + +def torch_tensor_flip(self, dims): + return self + + +def torch_nn_conv1d(self, input): + l_in = input.shape[-1] + shape = None + padding = self.padding + if padding == "valid": + padding = (0, 0) + if padding == "same": + shape = list(input.shape) + if shape is None: + shape = list(input.shape) + l_out = math.floor( + (l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + shape[-1] = l_out + shape[-2] = self.out_channels + return torch.empty(shape, device="meta") + + def torch_nn_conv2d(self, input): h_in, w_in = input.shape[-2:] shape = None @@ -325,6 +384,21 @@ def torch_tensor_unsqueeze(self, dim): return torch_unsqueeze(self, dim) +def torch_unique_consecutive(input, **kwargs): + output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs) + if isinstance(output, torch.Tensor): + return output.to("meta") + else: + return tuple(map(lambda x: x.to("meta"), output)) + + +def torch_nn_functional_one_hot(tensor, num_classes=-1): + if num_classes < 0: + raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis") + shape = list(tensor.shape) + [num_classes] + return torch.empty(shape, device="meta") + + def torch_nn_mseloss(self, input, target): if self.reduction == "none": shape = target.shape @@ -350,14 +424,27 @@ def
torch_nn_bcewithlogitsloss(self, input, target): def operator_getitem(a, b): + def to_concrete(t): + if isinstance(t, torch.Tensor): + concrete = torch.ones_like(t, device="cpu") + if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]: + concrete = concrete.to(torch.int64) + return concrete + return t + if isinstance(a, torch.Tensor): # TODO: infer shape without performing the computation. + if isinstance(b, tuple): + b = tuple(map(to_concrete, b)) + else: + b = to_concrete(b) return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta") return operator.getitem(a, b) _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { torch.nn.Embedding: torch_nn_embedding, + torch.nn.functional.embedding: torch_nn_functional_embedding, torch.nn.LayerNorm: torch_nn_layernorm, torch.nn.Linear: torch_nn_linear, torch.relu: torch_relu, @@ -372,15 +459,20 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { torch.mul: torch_mul, torch.Tensor.mul: torch_tensor_mul, torch.matmul: torch_matmul, + torch.bmm: torch_bmm, torch.einsum: torch_einsum, torch.Tensor.repeat: torch_tensor_repeat, torch.roll: torch_roll, - # TODO: those might not be needed. 
- # torch.index_select: torch_index_select, - # torch.Tensor.index_select: torch_tensor_index_select, + torch.flip: torch_flip, + torch.Tensor.flip: torch_tensor_flip, + torch.index_select: torch_index_select, + torch.Tensor.index_select: torch_tensor_index_select, + torch.nn.Conv1d: torch_nn_conv1d, torch.nn.Conv2d: torch_nn_conv2d, torch.unsqueeze: torch_unsqueeze, torch.Tensor.unsqueeze: torch_tensor_unsqueeze, + torch.unique_consecutive: torch_unique_consecutive, + torch.nn.functional.one_hot: torch_nn_functional_one_hot, torch.nn.MSELoss: torch_nn_mseloss, torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss, torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss, @@ -513,7 +605,7 @@ class HFTracer(Tracer): # Feature flag for proxying accesses to buffer values proxy_buffer_attributes: bool = True allow_insert_stateless_mods: bool = True - _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"] + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"] def __init__(self, autowrap_modules=(math,), autowrap_functions=()): @@ -532,22 +624,22 @@ class HFTracer(Tracer): """Generates dummy input for model inference recording.""" # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored # from pickle, or from the "__class__" attribute in the general case. 
- model_class = getattr(model, "class_for_deserialization", model.__class__) + model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__ device = model.device inputs_dict = {} if input_name in ["labels", "start_positions", "end_positions"]: batch_size = shape[0] - if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): + if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES): inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) - elif model_class in [ - *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING), - XLNetForQuestionAnswering, + elif model_class_name in [ + *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES), + "XLNetForQuestionAnswering", ]: inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) - elif model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): + elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES): if not hasattr(model.config, "problem_type") or model.config.problem_type is None: raise ValueError( "Could not retrieve the problem type for the sequence classification task, please set " @@ -571,32 +663,49 @@ class HFTracer(Tracer): ) inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device) - elif model_class in [ - *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING), - *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING), + elif model_class_name in [ + *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES), + *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES), ]: inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) - elif model_class in [ - *get_values(MODEL_FOR_PRETRAINING_MAPPING), - *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), - *get_values(MODEL_FOR_CAUSAL_LM_MAPPING), - 
*get_values(MODEL_FOR_MASKED_LM_MAPPING), - *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING), - GPT2DoubleHeadsModel, + elif model_class_name in [ + *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES), + *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES), + *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES), + *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES), + *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES), + "GPT2DoubleHeadsModel", ]: inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) else: - raise NotImplementedError(f"{model_class} not supported yet.") + raise NotImplementedError(f"{model_class_name} not supported yet.") elif "pixel_values" in input_name: batch_size = shape[0] - image_size = model.config.image_size + image_size = getattr(model.config, "image_size", None) + if image_size is None: + if hasattr(model.config, "vision_config"): + image_size = model.config.vision_config.image_size + elif hasattr(model.config, "encoder"): + image_size = model.config.encoder.image_size + else: + raise AttributeError('Could not find the "image_size" field in the model config') + + # If no num_channels is in the config, use some arbitrary value. 
+ num_channels = getattr(model.config, "num_channels", 3) if not isinstance(image_size, collections.abc.Iterable): image_size = (image_size, image_size) height, width = image_size inputs_dict[input_name] = torch.zeros( - batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device + batch_size, num_channels, height, width, dtype=torch.float32, device=device ) - + elif "bbox" in input_name: + inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device) + elif "input_features" in input_name: + inputs_dict[input_name] = torch.zeros( + *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device + ) + elif "inputs" in input_name: + inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device) elif "mask" in input_name or "ids" in input_name: inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) else: @@ -628,6 +737,8 @@ class HFTracer(Tracer): if kind == "call_function": meta_target = _MANUAL_META_OVERRIDES.get(target, target) meta_out = meta_target(*args_metas, **kwargs_metas) + if isinstance(meta_out, torch.Tensor): + meta_out = meta_out.to(device="meta") elif kind == "call_method": method = getattr(args_metas[0].__class__, target) meta_target = _MANUAL_META_OVERRIDES.get(method, method) @@ -731,7 +842,7 @@ class HFTracer(Tracer): sequence_length = _generate_random_int() shape = [batch_size, sequence_length] - if root.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): + if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES): num_choices = _generate_random_int(low=2, high=5) shape.insert(1, num_choices) @@ -870,11 +981,22 @@ def symbolic_trace( if input_names is None: input_names = model.dummy_inputs.keys() + input_names = list(input_names) + sig = inspect.signature(model.forward) + + if not (set(input_names) <= set(sig.parameters.keys())): + formatted_input_names = input_names[0] if len(input_names) == 1 else ", 
".join(input_names) + formatted_allowed_input_names = ", ".join(sig.parameters.keys()) + raise ValueError( + f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:" + f" {formatted_allowed_input_names}" + ) + concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} - if not isinstance(model, _SUPPORTED_MODELS): - supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) + if model.__class__.__name__ not in _SUPPORTED_MODELS: + supported_model_names = ", ".join(_SUPPORTED_MODELS) raise NotImplementedError( f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" ) diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 279204b574..01c4967516 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -413,6 +413,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True + fx_compatible = True test_pruning = False test_missing_keys = False @@ -1386,6 +1387,7 @@ class BartStandaloneDecoderModelTester: class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else () all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else () + fx_comptatible = True test_pruning = False is_encoder_decoder = False diff --git a/tests/models/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py index e4dbf474d1..ec626d05e8 100644 --- a/tests/models/blenderbot/test_modeling_blenderbot.py +++ b/tests/models/blenderbot/test_modeling_blenderbot.py @@ -218,6 +218,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, 
unittest.Test all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True + fx_compatible = True test_pruning = False test_missing_keys = False diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py index 8bc6304e79..47503b9c7f 100644 --- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py @@ -213,6 +213,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True + fx_compatible = True test_pruning = False test_missing_keys = False diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 02b982e4ef..ab05f9adf1 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -152,7 +152,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase): """ all_model_classes = (CLIPVisionModel,) if is_torch_available() else () - + fx_compatible = True test_pruning = False test_resize_embeddings = False test_head_masking = False @@ -303,6 +303,7 @@ class CLIPTextModelTester: class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (CLIPTextModel,) if is_torch_available() else () + fx_compatible = True test_pruning = False test_head_masking = False @@ -388,6 +389,7 @@ class CLIPModelTester: @require_torch class CLIPModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (CLIPModel,) if is_torch_available() else () + fx_compatible = True 
test_head_masking = False test_pruning = False test_resize_embeddings = False diff --git a/tests/models/layoutlm/test_modeling_layoutlm.py b/tests/models/layoutlm/test_modeling_layoutlm.py index ec2190598e..e2d949611d 100644 --- a/tests/models/layoutlm/test_modeling_layoutlm.py +++ b/tests/models/layoutlm/test_modeling_layoutlm.py @@ -215,6 +215,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else None ) + fx_compatible = True def setUp(self): self.model_tester = LayoutLMModelTester(self) diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index 7685e98886..0d5bdc3ca3 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -231,6 +231,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase ) all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True + fx_compatible = True test_pruning = False test_missing_keys = False diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 9c119f936a..1039c4a51d 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -230,6 +230,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else () all_generative_model_classes = (MarianMTModel,) if is_torch_available() else () is_encoder_decoder = True + fx_compatible = True test_pruning = False test_missing_keys = False diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index 48b9f57a56..6a8eeed9fb 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -224,6 +224,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, 
unittest.TestCase) ) all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True + fx_compatible = True test_pruning = False test_missing_keys = False diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 9a64878765..7fb59c6a1f 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -178,6 +178,7 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else () all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else () is_encoder_decoder = False + fx_compatible = True test_pruning = False test_missing_keys = False diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index a05e34e57c..d5e9d22df1 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True + fx_compatible = True test_resize_position_embeddings = True test_pruning = False test_missing_keys = False diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index 073db546bf..171531503d 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -219,6 +219,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase ) all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True + fx_compatible = True test_pruning = False test_missing_keys = 
False diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index 35fa96f1c7..a1a625a9b4 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -17,6 +17,7 @@ import copy import inspect import os +import pickle import tempfile import unittest @@ -30,7 +31,7 @@ from transformers.testing_utils import ( slow, torch_device, ) -from transformers.utils import cached_property +from transformers.utils import cached_property, is_torch_fx_available from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -43,6 +44,9 @@ if is_torch_available(): from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder +if is_torch_fx_available(): + from transformers.utils.fx import symbolic_trace + def prepare_speech_to_text_inputs_dict( config, @@ -271,6 +275,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True + fx_compatible = True test_pruning = False test_missing_keys = False @@ -715,6 +720,105 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes self.assertTrue(models_equal) + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): + if not is_torch_fx_available() or not self.fx_compatible: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.return_dict = False + + for model_class in self.all_model_classes: + model = 
model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) + + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = [ + "input_ids", + "attention_mask", + "decoder_input_ids", + "decoder_attention_mask", + "input_features", + ] + if labels is not None: + input_names.append("labels") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + else: + input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"] + + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + + except RuntimeError as e: + self.fail(f"Couldn't trace module: {e}") + + def flatten_output(output): + flatten = [] + for x in output: + if isinstance(x, (tuple, list)): + flatten += flatten_output(x) + elif not isinstance(x, torch.Tensor): + continue + else: + flatten.append(x) + return flatten + + model_output = flatten_output(model_output) + traced_output = flatten_output(traced_output) + num_outputs = len(model_output) + + for 
i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], traced_output[i]), + f"traced {i}th output doesn't match model {i}th output for {model_class}", + ) + + # Test that the model can be serialized and restored properly + with tempfile.TemporaryDirectory() as tmp_dir_name: + pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") + try: + with open(pkl_file_name, "wb") as f: + pickle.dump(traced_model, f) + with open(pkl_file_name, "rb") as f: + loaded = pickle.load(f) + except Exception as e: + self.fail(f"Couldn't serialize / deserialize the traced model: {e}") + + loaded_output = loaded(**filtered_inputs) + loaded_output = flatten_output(loaded_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], loaded_output[i]), + f"serialized model {i}th output doesn't match model {i}th output for {model_class}", + ) + @require_torch @require_torchaudio diff --git a/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py b/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py index 88d0550675..d9717b4060 100644 --- a/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py +++ b/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py @@ -179,6 +179,7 @@ class Speech2Text2StandaloneDecoderModelTester: class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else () all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else () + fx_compatible = True test_pruning = False def setUp( diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index 47f219d482..0c1f266816 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -14,7 +14,6 @@ # limitations under the License. """ Testing suite for the PyTorch Swin model. 
""" -import copy import inspect import os import pickle @@ -26,7 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor if is_torch_available(): @@ -45,14 +44,6 @@ if is_torch_fx_available(): from transformers.utils.fx import symbolic_trace -def _config_zero_init(config): - configs_no_init = copy.deepcopy(config) - for key in configs_no_init.__dict__.keys(): - if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key: - setattr(configs_no_init, key, 1e-10) - return configs_no_init - - class SwinModelTester: def __init__( self, @@ -407,7 +398,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] if labels is not None: input_names.append("labels") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) model_output = model(**filtered_inputs) @@ -427,7 +420,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): input_names.append("end_positions") filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = filtered_inputs.keys() + input_names = list(filtered_inputs.keys()) model_output = model(**filtered_inputs) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 05a962e354..035e00c05c 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -509,8 +509,8 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, 
T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () - fx_compatible = True all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () + fx_compatible = True test_pruning = False test_resize_embeddings = True test_model_parallel = True diff --git a/tests/models/trocr/test_modeling_trocr.py b/tests/models/trocr/test_modeling_trocr.py index 6d8ff0aa60..0c5e6f7ae8 100644 --- a/tests/models/trocr/test_modeling_trocr.py +++ b/tests/models/trocr/test_modeling_trocr.py @@ -161,6 +161,7 @@ class TrOCRStandaloneDecoderModelTester: class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else () all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else () + fx_compatible = True test_pruning = False def setUp(self): diff --git a/tests/models/xglm/test_modeling_xglm.py b/tests/models/xglm/test_modeling_xglm.py index 37301a79ed..f4da499426 100644 --- a/tests/models/xglm/test_modeling_xglm.py +++ b/tests/models/xglm/test_modeling_xglm.py @@ -13,17 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import datetime import math +import os +import pickle +import tempfile import unittest from transformers import XGLMConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device +from transformers.utils import is_torch_fx_available from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, + random_attention_mask, +) if is_torch_available(): @@ -31,6 +40,9 @@ if is_torch_available(): from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer +if is_torch_fx_available(): + from transformers.utils.fx import symbolic_trace + class XGLMModelTester: def __init__( @@ -299,6 +311,7 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else () all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else () + fx_compatible = True test_missing_keys = False test_pruning = False @@ -337,6 +350,112 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs) + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): + if not is_torch_fx_available() or not self.fx_compatible: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.return_dict = False + + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class, 
return_labels=output_loss) + + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = [ + "input_ids", + "attention_mask", + "decoder_input_ids", + "decoder_attention_mask", + "input_features", + ] + if labels is not None: + input_names.append("labels") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + else: + input_names = [ + "input_ids", + "attention_mask", + "token_type_ids", + "pixel_values", + "bbox", + "input_features", + ] + + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + + except RuntimeError as e: + self.fail(f"Couldn't trace module: {e}") + + def flatten_output(output): + flatten = [] + for x in output: + if isinstance(x, (tuple, list)): + flatten += flatten_output(x) + elif not isinstance(x, torch.Tensor): + continue + else: + flatten.append(x) + return flatten + + model_output = flatten_output(model_output) + traced_output = flatten_output(traced_output) + num_outputs = len(model_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], traced_output[i]), + f"traced 
{i}th output doesn't match model {i}th output for {model_class}", + ) + + # Test that the model can be serialized and restored properly + with tempfile.TemporaryDirectory() as tmp_dir_name: + pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") + try: + with open(pkl_file_name, "wb") as f: + pickle.dump(traced_model, f) + with open(pkl_file_name, "rb") as f: + loaded = pickle.load(f) + except Exception as e: + self.fail(f"Couldn't serialize / deserialize the traced model: {e}") + + loaded_output = loaded(**filtered_inputs) + loaded_output = flatten_output(loaded_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], loaded_output[i]), + f"serialized model {i}th output doesn't match model {i}th output for {model_class}", + ) + @slow def test_batch_generation(self): model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") diff --git a/tests/models/xlnet/test_modeling_xlnet.py b/tests/models/xlnet/test_modeling_xlnet.py index 2c26315ceb..dca727b299 100644 --- a/tests/models/xlnet/test_modeling_xlnet.py +++ b/tests/models/xlnet/test_modeling_xlnet.py @@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) all_generative_model_classes = ( (XLNetLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable + fx_compatible = False test_pruning = False # XLNet has 2 QA models -> need to manually set the correct labels for one of them here diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4b9227cbfd..46d464b0c0 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -738,17 +738,32 @@ class ModelTesterMixin: if model.config.is_encoder_decoder: model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward labels = inputs.get("labels", None) - input_names = ["input_ids", "attention_mask", 
"decoder_input_ids", "decoder_attention_mask"] + input_names = [ + "input_ids", + "attention_mask", + "decoder_input_ids", + "decoder_attention_mask", + "input_features", + ] if labels is not None: input_names.append("labels") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) model_output = model(**filtered_inputs) traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) else: - input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"] + input_names = [ + "input_ids", + "attention_mask", + "token_type_ids", + "pixel_values", + "bbox", + "input_features", + ] labels = inputs.get("labels", None) start_positions = inputs.get("start_positions", None) @@ -761,7 +776,7 @@ class ModelTesterMixin: input_names.append("end_positions") filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = filtered_inputs.keys() + input_names = list(filtered_inputs.keys()) model_output = model(**filtered_inputs)