Fx support for multiple model architectures (#17393)

* Support for Bart and LayoutLM, and partial support for XLNet

* Support for mbart

* A lot of new models supported

* Support for other models

* LayoutLM fix

* Use strings instead of classes
This commit is contained in:
Michael Benayoun
2022-05-31 10:02:55 +02:00
committed by GitHub
parent 04681c1d81
commit 28d0048218
37 changed files with 515 additions and 146 deletions

View File

@@ -93,7 +93,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -114,7 +114,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class BartLearnedPositionalEmbedding(nn.Embedding): class BartLearnedPositionalEmbedding(nn.Embedding):
@@ -911,7 +911,7 @@ class BartDecoder(BartPretrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -2112,7 +2112,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -83,7 +83,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -105,7 +105,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class BlenderbotLearnedPositionalEmbedding(nn.Embedding): class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
@@ -850,7 +850,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.blenderbot.modeling_blenderbot.BlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall # Copied from transformers.models.blenderbot.modeling_blenderbot.BlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall
@@ -846,7 +846,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -57,7 +57,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# contrastive loss function, adapted from # contrastive loss function, adapted from
@@ -674,7 +674,7 @@ class CLIPTextTransformer(nn.Module):
# lazily create causal attention mask, with full attention between the vision tokens # lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf # pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len) mask = torch.empty(bsz, seq_len, seq_len)
mask.fill_(float("-inf")) mask.fill_(torch.tensor(float("-inf")))
mask.triu_(1) # zero out the lower diagonal mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask mask = mask.unsqueeze(1) # expand mask
return mask return mask
@@ -1042,8 +1042,8 @@ class CLIPModel(CLIPPreTrainedModel):
text_embeds = self.text_projection(text_embeds) text_embeds = self.text_projection(text_embeds)
# normalized features # normalized features
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits # cosine similarity as logits
logit_scale = self.logit_scale.exp() logit_scale = self.logit_scale.exp()

View File

@@ -800,7 +800,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if bbox is None: if bbox is None:
bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device) bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

View File

@@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -101,7 +101,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
@@ -998,7 +998,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None and combined_attention_mask is not None: if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -103,7 +103,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class MarianSinusoidalPositionalEmbedding(nn.Embedding): class MarianSinusoidalPositionalEmbedding(nn.Embedding):
@@ -856,7 +856,7 @@ class MarianDecoder(MarianPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -119,7 +119,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
@@ -909,7 +909,7 @@ class MBartDecoder(MBartPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -61,7 +61,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -82,7 +82,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class OPTLearnedPositionalEmbedding(nn.Embedding): class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -513,7 +513,7 @@ class OPTDecoder(OPTPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
@@ -876,7 +876,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -116,7 +116,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
@@ -883,7 +883,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -69,7 +69,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -91,7 +91,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class Conv1dSubsampler(nn.Module): class Conv1dSubsampler(nn.Module):
@@ -888,7 +888,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -49,7 +49,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -71,7 +71,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2 # Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2
@@ -495,7 +495,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -72,7 +72,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
@@ -524,7 +524,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

View File

@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
@@ -142,7 +142,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
@@ -577,7 +577,7 @@ class XGLMModel(XGLMPreTrainedModel):
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask( combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device) ).to(inputs_embeds.device)
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -712,7 +712,7 @@ class XGLMModel(XGLMPreTrainedModel):
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None

View File

@@ -1056,7 +1056,6 @@ class XLNetModel(XLNetPreTrainedModel):
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len) fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz) pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
pos_emb = pos_emb.to(self.device)
return pos_emb return pos_emb
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1206,6 +1205,7 @@ class XLNetModel(XLNetPreTrainedModel):
# Positional encoding # Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz) pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = pos_emb.to(output_h.device)
pos_emb = self.dropout(pos_emb) pos_emb = self.dropout(pos_emb)
# Prepare head mask if needed # Prepare head mask if needed

View File

@@ -29,27 +29,23 @@ from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx.proxy import ParameterProxy from torch.fx.proxy import ParameterProxy
from .. import ( from .. import PretrainedConfig, PreTrainedModel, logging
CONFIG_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
GPT2DoubleHeadsModel,
PretrainedConfig,
PreTrainedModel,
XLNetForQuestionAnswering,
logging,
)
from ..models.auto import get_values from ..models.auto import get_values
from ..models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
MODEL_FOR_PRETRAINING_MAPPING_NAMES,
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
from ..utils.versions import importlib_metadata from ..utils.versions import importlib_metadata
@@ -57,25 +53,25 @@ from ..utils.versions import importlib_metadata
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def _generate_supported_model_classes( def _generate_supported_model_class_names(
model_name: Type[PretrainedConfig], model_name: Type[PretrainedConfig],
supported_tasks: Optional[Union[str, List[str]]] = None, supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[Type[PreTrainedModel]]: ) -> List[str]:
model_config_class = CONFIG_MAPPING[model_name]
task_mapping = { task_mapping = {
"default": MODEL_MAPPING, "default": MODEL_MAPPING_NAMES,
"pretraining": MODEL_FOR_PRETRAINING_MAPPING, "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
"masked-lm": MODEL_FOR_MASKED_LM_MAPPING, "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
"causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING, "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
"seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
"multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING, "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING, "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
} }
if supported_tasks is None: if supported_tasks is None:
@@ -83,55 +79,78 @@ def _generate_supported_model_classes(
if isinstance(supported_tasks, str): if isinstance(supported_tasks, str):
supported_tasks = [supported_tasks] supported_tasks = [supported_tasks]
model_classes = [] model_class_names = []
for task in supported_tasks: for task in supported_tasks:
model_class = task_mapping[task].get(model_config_class, None) class_name = task_mapping[task].get(model_name, None)
if model_class: if class_name:
model_classes.append(model_class) model_class_names.append(class_name)
return model_classes return model_class_names
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"albert", "albert",
"bart",
"bert", "bert",
"blenderbot",
"blenderbot-small",
"clip",
"distilbert", "distilbert",
"mobilebert",
"electra", "electra",
"megatron-bert",
"gpt2", "gpt2",
"gptj",
"gpt_neo", "gpt_neo",
"t5", "gptj",
"layoutlm",
"m2m_100",
"marian",
"mbart",
"megatron-bert",
"mobilebert",
"mt5",
"opt",
"pegasus",
"plbart",
"roberta", "roberta",
"vit", "speech_to_text",
"speech_to_text_2",
"swin", "swin",
"t5",
"trocr",
"vit",
"xglm",
# TODO: add support for them as it should be quite easy to do so (small blocking issues). # TODO: add support for them as it should be quite easy to do so (small blocking issues).
# "layoutlm",
# "xlnet", # "xlnet",
] ]
_REGULAR_SUPPORTED_MODELS = [] _REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS: for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
if isinstance(item, dict): if isinstance(item, dict):
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item)) _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
else: else:
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item)) _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))
_SPECIAL_SUPPORTED_MODELS = [ _SPECIAL_SUPPORTED_MODELS = [
GPT2DoubleHeadsModel, "CLIPTextModel",
"CLIPVisionModel",
"GPT2DoubleHeadsModel",
"Speech2Text2Decoder",
"TrOCRDecoder",
# TODO: add support for them as it should be quite easy to do so (small blocking issues). # TODO: add support for them as it should be quite easy to do so (small blocking issues).
# XLNetForQuestionAnswering, # XLNetForQuestionAnswering,
] ]
_SUPPORTED_MODELS = tuple( _SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
sorted(list(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)), key=lambda c: c.__name__)
)
def torch_nn_embedding(self, input): def torch_nn_embedding(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device="meta") return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
def torch_nn_functional_embedding(
input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
return torch.empty(*input.shape, weight.shape[-1], device="meta")
def torch_nn_layernorm(self, input): def torch_nn_layernorm(self, input):
return input return input
@@ -176,6 +195,12 @@ def torch_arange(*args, **kwargs):
start, end = args start, end = args
else: else:
start, end, step = args start, end, step = args
if isinstance(start, float):
start = int(start)
if isinstance(end, float):
start = int(end)
if isinstance(step, float):
step = int(step)
step = kwargs.get("step", step) step = kwargs.get("step", step)
dtype = kwargs.get("dtype") dtype = kwargs.get("dtype")
return torch.empty((end - start) // step, dtype=dtype, device="meta") return torch.empty((end - start) // step, dtype=dtype, device="meta")
@@ -265,6 +290,14 @@ def torch_matmul(input, other, *, out=None):
return torch.empty(*shape, device="meta") return torch.empty(*shape, device="meta")
def torch_bmm(input, mat2, *, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
batch_size, n, m = input.shape
_, _, p = mat2.shape
return torch.empty(batch_size, n, p, device="meta")
def torch_einsum(equation, *operands): def torch_einsum(equation, *operands):
# TODO: infer shape without performing the computation, this might be quite hard. # TODO: infer shape without performing the computation, this might be quite hard.
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands) concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
@@ -285,13 +318,39 @@ def torch_index_select(input, dim, index, *, out=None):
def torch_tensor_index_select(self, dim, index): def torch_tensor_index_select(self, dim, index):
return torch_tensor_index_select(self, dim, index) return torch_index_select(self, dim, index)
def torch_roll(input, shifts, dims=None): def torch_roll(input, shifts, dims=None):
return input return input
def torch_flip(input, dims):
return input
def torch_tensor_flip(self, dims):
return self
def torch_nn_conv1d(self, input):
l_in = input.shape[-1]
shape = None
padding = self.padding
if padding == "valid":
padding = (0, 0)
if padding == "same":
shape = list(input.shape)
if shape is None:
shape = list(input.shape)
l_out = math.floor(
(l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
)
shape[-1] = l_out
shape[-2] = self.out_channels
return torch.empty(shape, device="meta")
def torch_nn_conv2d(self, input): def torch_nn_conv2d(self, input):
h_in, w_in = input.shape[-2:] h_in, w_in = input.shape[-2:]
shape = None shape = None
@@ -325,6 +384,21 @@ def torch_tensor_unsqueeze(self, dim):
return torch_unsqueeze(self, dim) return torch_unsqueeze(self, dim)
def torch_unique_consecutive(input, **kwargs):
output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
if isinstance(output, torch.Tensor):
return output.to("meta")
else:
return tuple(map(output, lambda x: x.to("meta")))
def torch_nn_functional_one_hot(tensor, num_classes=-1):
if num_classes < 0:
raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
shape = list(tensor.shape) + [num_classes]
return torch.empty(shape, device="meta")
def torch_nn_mseloss(self, input, target): def torch_nn_mseloss(self, input, target):
if self.reduction == "none": if self.reduction == "none":
shape = target.shape shape = target.shape
@@ -350,14 +424,27 @@ def torch_nn_bcewithlogitsloss(self, input, target):
def operator_getitem(a, b): def operator_getitem(a, b):
def to_concrete(t):
if isinstance(t, torch.Tensor):
concrete = torch.ones_like(t, device="cpu")
if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
concrete = concrete.to(torch.int64)
return concrete
return t
if isinstance(a, torch.Tensor): if isinstance(a, torch.Tensor):
# TODO: infer shape without performing the computation. # TODO: infer shape without performing the computation.
if isinstance(b, tuple):
b = tuple(map(to_concrete, b))
else:
b = to_concrete(b)
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta") return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
return operator.getitem(a, b) return operator.getitem(a, b)
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.nn.Embedding: torch_nn_embedding, torch.nn.Embedding: torch_nn_embedding,
torch.nn.functional.embedding: torch_nn_functional_embedding,
torch.nn.LayerNorm: torch_nn_layernorm, torch.nn.LayerNorm: torch_nn_layernorm,
torch.nn.Linear: torch_nn_linear, torch.nn.Linear: torch_nn_linear,
torch.relu: torch_relu, torch.relu: torch_relu,
@@ -372,15 +459,20 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.mul: torch_mul, torch.mul: torch_mul,
torch.Tensor.mul: torch_tensor_mul, torch.Tensor.mul: torch_tensor_mul,
torch.matmul: torch_matmul, torch.matmul: torch_matmul,
torch.bmm: torch_bmm,
torch.einsum: torch_einsum, torch.einsum: torch_einsum,
torch.Tensor.repeat: torch_tensor_repeat, torch.Tensor.repeat: torch_tensor_repeat,
torch.roll: torch_roll, torch.roll: torch_roll,
# TODO: those might not be needed. torch.flip: torch_flip,
# torch.index_select: torch_index_select, torch.Tensor.flip: torch_tensor_flip,
# torch.Tensor.index_select: torch_tensor_index_select, torch.index_select: torch_index_select,
torch.Tensor.index_select: torch_tensor_index_select,
torch.nn.Conv1d: torch_nn_conv1d,
torch.nn.Conv2d: torch_nn_conv2d, torch.nn.Conv2d: torch_nn_conv2d,
torch.unsqueeze: torch_unsqueeze, torch.unsqueeze: torch_unsqueeze,
torch.Tensor.unsqueeze: torch_tensor_unsqueeze, torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
torch.unique_consecutive: torch_unique_consecutive,
torch.nn.functional.one_hot: torch_nn_functional_one_hot,
torch.nn.MSELoss: torch_nn_mseloss, torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss, torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss, torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
@@ -513,7 +605,7 @@ class HFTracer(Tracer):
# Feature flag for proxying accesses to buffer values # Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True proxy_buffer_attributes: bool = True
allow_insert_stateless_mods: bool = True allow_insert_stateless_mods: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"] _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"]
def __init__(self, autowrap_modules=(math,), autowrap_functions=()): def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
@@ -532,22 +624,22 @@ class HFTracer(Tracer):
"""Generates dummy input for model inference recording.""" """Generates dummy input for model inference recording."""
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
# from pickle, or from the "__class__" attribute in the general case. # from pickle, or from the "__class__" attribute in the general case.
model_class = getattr(model, "class_for_deserialization", model.__class__) model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
device = model.device device = model.device
inputs_dict = {} inputs_dict = {}
if input_name in ["labels", "start_positions", "end_positions"]: if input_name in ["labels", "start_positions", "end_positions"]:
batch_size = shape[0] batch_size = shape[0]
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class in [ elif model_class_name in [
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING), *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
XLNetForQuestionAnswering, "XLNetForQuestionAnswering",
]: ]:
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
if not hasattr(model.config, "problem_type") or model.config.problem_type is None: if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
raise ValueError( raise ValueError(
"Could not retrieve the problem type for the sequence classification task, please set " "Could not retrieve the problem type for the sequence classification task, please set "
@@ -571,32 +663,49 @@ class HFTracer(Tracer):
) )
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device) inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
elif model_class in [ elif model_class_name in [
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING), *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING), *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
]: ]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class in [ elif model_class_name in [
*get_values(MODEL_FOR_PRETRAINING_MAPPING), *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING), *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_MASKED_LM_MAPPING), *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING), *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
GPT2DoubleHeadsModel, "GPT2DoubleHeadsModel",
]: ]:
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
else: else:
raise NotImplementedError(f"{model_class} not supported yet.") raise NotImplementedError(f"{model_class_name} not supported yet.")
elif "pixel_values" in input_name: elif "pixel_values" in input_name:
batch_size = shape[0] batch_size = shape[0]
image_size = model.config.image_size image_size = getattr(model.config, "image_size", None)
if image_size is None:
if hasattr(model.config, "vision_config"):
image_size = model.config.vision_config.image_size
elif hasattr(model.config, "encoder"):
image_size = model.config.encoder.image_size
else:
raise AttributeError('Could not find the "image_size" field in the model config')
# If no num_channels is in the config, use some arbitrary value.
num_channels = getattr(model.config, "num_channels", 3)
if not isinstance(image_size, collections.abc.Iterable): if not isinstance(image_size, collections.abc.Iterable):
image_size = (image_size, image_size) image_size = (image_size, image_size)
height, width = image_size height, width = image_size
inputs_dict[input_name] = torch.zeros( inputs_dict[input_name] = torch.zeros(
batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device batch_size, num_channels, height, width, dtype=torch.float32, device=device
) )
elif "bbox" in input_name:
inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
elif "input_features" in input_name:
inputs_dict[input_name] = torch.zeros(
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
)
elif "inputs" in input_name:
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
elif "mask" in input_name or "ids" in input_name: elif "mask" in input_name or "ids" in input_name:
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
else: else:
@@ -628,6 +737,8 @@ class HFTracer(Tracer):
if kind == "call_function": if kind == "call_function":
meta_target = _MANUAL_META_OVERRIDES.get(target, target) meta_target = _MANUAL_META_OVERRIDES.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas) meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
elif kind == "call_method": elif kind == "call_method":
method = getattr(args_metas[0].__class__, target) method = getattr(args_metas[0].__class__, target)
meta_target = _MANUAL_META_OVERRIDES.get(method, method) meta_target = _MANUAL_META_OVERRIDES.get(method, method)
@@ -731,7 +842,7 @@ class HFTracer(Tracer):
sequence_length = _generate_random_int() sequence_length = _generate_random_int()
shape = [batch_size, sequence_length] shape = [batch_size, sequence_length]
if root.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
num_choices = _generate_random_int(low=2, high=5) num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices) shape.insert(1, num_choices)
@@ -870,11 +981,22 @@ def symbolic_trace(
if input_names is None: if input_names is None:
input_names = model.dummy_inputs.keys() input_names = model.dummy_inputs.keys()
input_names = list(input_names)
sig = inspect.signature(model.forward) sig = inspect.signature(model.forward)
if not (set(input_names) <= set(sig.parameters.keys())):
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
formatted_allowed_input_names = ", ".join(sig.parameters.keys())
raise ValueError(
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
f" {formatted_allowed_input_names}"
)
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
if not isinstance(model, _SUPPORTED_MODELS): if model.__class__.__name__ not in _SUPPORTED_MODELS:
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) supported_model_names = ", ".join(_SUPPORTED_MODELS)
raise NotImplementedError( raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
) )

View File

@@ -413,6 +413,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
) )
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
@@ -1386,6 +1387,7 @@ class BartStandaloneDecoderModelTester:
class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else () all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else () all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
fx_comptatible = True
test_pruning = False test_pruning = False
is_encoder_decoder = False is_encoder_decoder = False

View File

@@ -218,6 +218,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else () all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False

View File

@@ -213,6 +213,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else () all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False

View File

@@ -152,7 +152,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
""" """
all_model_classes = (CLIPVisionModel,) if is_torch_available() else () all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
fx_compatible = True
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
@@ -303,6 +303,7 @@ class CLIPTextModelTester:
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase): class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPTextModel,) if is_torch_available() else () all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
fx_compatible = True
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
@@ -388,6 +389,7 @@ class CLIPModelTester:
@require_torch @require_torch
class CLIPModelTest(ModelTesterMixin, unittest.TestCase): class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPModel,) if is_torch_available() else () all_model_classes = (CLIPModel,) if is_torch_available() else ()
fx_compatible = True
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False

View File

@@ -215,6 +215,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else None else None
) )
fx_compatible = True
def setUp(self): def setUp(self):
self.model_tester = LayoutLMModelTester(self) self.model_tester = LayoutLMModelTester(self)

View File

@@ -231,6 +231,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
) )
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False

View File

@@ -230,6 +230,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else () all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else ()
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else () all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False

View File

@@ -224,6 +224,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
) )
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False

View File

@@ -178,6 +178,7 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else () all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else () all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
is_encoder_decoder = False is_encoder_decoder = False
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False

View File

@@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else () all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_resize_position_embeddings = True test_resize_position_embeddings = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False

View File

@@ -219,6 +219,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
) )
all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False

View File

@@ -17,6 +17,7 @@
import copy import copy
import inspect import inspect
import os import os
import pickle
import tempfile import tempfile
import unittest import unittest
@@ -30,7 +31,7 @@ from transformers.testing_utils import (
slow, slow,
torch_device, torch_device,
) )
from transformers.utils import cached_property from transformers.utils import cached_property, is_torch_fx_available
from ...generation.test_generation_utils import GenerationTesterMixin from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
@@ -43,6 +44,9 @@ if is_torch_available():
from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace
def prepare_speech_to_text_inputs_dict( def prepare_speech_to_text_inputs_dict(
config, config,
@@ -271,6 +275,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else () all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
@@ -715,6 +720,105 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
self.assertTrue(models_equal) self.assertTrue(models_equal)
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
if not is_torch_fx_available() or not self.fx_compatible:
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
labels = inputs.get("labels", None)
input_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"input_features",
]
if labels is not None:
input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None)
if labels is not None:
input_names.append("labels")
if start_positions is not None:
input_names.append("start_positions")
if end_positions is not None:
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
except RuntimeError as e:
self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
flatten = []
for x in output:
if isinstance(x, (tuple, list)):
flatten += flatten_output(x)
elif not isinstance(x, torch.Tensor):
continue
else:
flatten.append(x)
return flatten
model_output = flatten_output(model_output)
traced_output = flatten_output(traced_output)
num_outputs = len(model_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], traced_output[i]),
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
try:
with open(pkl_file_name, "wb") as f:
pickle.dump(traced_model, f)
with open(pkl_file_name, "rb") as f:
loaded = pickle.load(f)
except Exception as e:
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
loaded_output = loaded(**filtered_inputs)
loaded_output = flatten_output(loaded_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], loaded_output[i]),
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
)
@require_torch @require_torch
@require_torchaudio @require_torchaudio

View File

@@ -179,6 +179,7 @@ class Speech2Text2StandaloneDecoderModelTester:
class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else () all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else () all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
fx_compatible = True
test_pruning = False test_pruning = False
def setUp( def setUp(

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch Swin model. """ """ Testing suite for the PyTorch Swin model. """
import copy
import inspect import inspect
import os import os
import pickle import pickle
@@ -26,7 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available(): if is_torch_available():
@@ -45,14 +44,6 @@ if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace from transformers.utils.fx import symbolic_trace
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
setattr(configs_no_init, key, 1e-10)
return configs_no_init
class SwinModelTester: class SwinModelTester:
def __init__( def __init__(
self, self,
@@ -407,7 +398,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
if labels is not None: if labels is not None:
input_names.append("labels") input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs) model_output = model(**filtered_inputs)
@@ -427,7 +420,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
input_names.append("end_positions") input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = filtered_inputs.keys() input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs) model_output = model(**filtered_inputs)

View File

@@ -509,8 +509,8 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = True
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
fx_compatible = True
test_pruning = False test_pruning = False
test_resize_embeddings = True test_resize_embeddings = True
test_model_parallel = True test_model_parallel = True

View File

@@ -161,6 +161,7 @@ class TrOCRStandaloneDecoderModelTester:
class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else () all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else () all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
fx_compatible = True
test_pruning = False test_pruning = False
def setUp(self): def setUp(self):

View File

@@ -13,17 +13,26 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import datetime import datetime
import math import math
import os
import pickle
import tempfile
import unittest import unittest
from transformers import XGLMConfig, is_torch_available from transformers import XGLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_fx_available
from ...generation.test_generation_utils import GenerationTesterMixin from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
random_attention_mask,
)
if is_torch_available(): if is_torch_available():
@@ -31,6 +40,9 @@ if is_torch_available():
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace
class XGLMModelTester: class XGLMModelTester:
def __init__( def __init__(
@@ -299,6 +311,7 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else () all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else () all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
fx_compatible = True
test_missing_keys = False test_missing_keys = False
test_pruning = False test_pruning = False
@@ -337,6 +350,112 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs) self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
if not is_torch_fx_available() or not self.fx_compatible:
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
labels = inputs.get("labels", None)
input_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"input_features",
]
if labels is not None:
input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
input_names = [
"input_ids",
"attention_mask",
"token_type_ids",
"pixel_values",
"bbox",
"input_features",
]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None)
if labels is not None:
input_names.append("labels")
if start_positions is not None:
input_names.append("start_positions")
if end_positions is not None:
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
except RuntimeError as e:
self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
flatten = []
for x in output:
if isinstance(x, (tuple, list)):
flatten += flatten_output(x)
elif not isinstance(x, torch.Tensor):
continue
else:
flatten.append(x)
return flatten
model_output = flatten_output(model_output)
traced_output = flatten_output(traced_output)
num_outputs = len(model_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], traced_output[i]),
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
try:
with open(pkl_file_name, "wb") as f:
pickle.dump(traced_model, f)
with open(pkl_file_name, "rb") as f:
loaded = pickle.load(f)
except Exception as e:
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
loaded_output = loaded(**filtered_inputs)
loaded_output = flatten_output(loaded_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], loaded_output[i]),
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
)
@slow @slow
def test_batch_generation(self): def test_batch_generation(self):
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")

View File

@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
all_generative_model_classes = ( all_generative_model_classes = (
(XLNetLMHeadModel,) if is_torch_available() else () (XLNetLMHeadModel,) if is_torch_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable ) # TODO (PVP): Check other models whether language generation is also applicable
fx_compatible = False
test_pruning = False test_pruning = False
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here # XLNet has 2 QA models -> need to manually set the correct labels for one of them here

View File

@@ -738,17 +738,32 @@ class ModelTesterMixin:
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
labels = inputs.get("labels", None) labels = inputs.get("labels", None)
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] input_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"input_features",
]
if labels is not None: if labels is not None:
input_names.append("labels") input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs) model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names) traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs) traced_output = traced_model(**filtered_inputs)
else: else:
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"] input_names = [
"input_ids",
"attention_mask",
"token_type_ids",
"pixel_values",
"bbox",
"input_features",
]
labels = inputs.get("labels", None) labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None) start_positions = inputs.get("start_positions", None)
@@ -761,7 +776,7 @@ class ModelTesterMixin:
input_names.append("end_positions") input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = filtered_inputs.keys() input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs) model_output = model(**filtered_inputs)