Fx support for multiple model architectures (#17393)
* Support for Bart and LayoutLM, and partial support for XLNet * Support for mbart * A lot of new models supported * Support for other models * LayoutLM fix * Use strings instead of classes
This commit is contained in:
@@ -93,7 +93,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -114,7 +114,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
class BartLearnedPositionalEmbedding(nn.Embedding):
|
class BartLearnedPositionalEmbedding(nn.Embedding):
|
||||||
@@ -911,7 +911,7 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -2112,7 +2112,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -105,7 +105,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
|
class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
|
||||||
@@ -850,7 +850,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.blenderbot.modeling_blenderbot.BlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall
|
# Copied from transformers.models.blenderbot.modeling_blenderbot.BlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall
|
||||||
@@ -846,7 +846,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
# contrastive loss function, adapted from
|
# contrastive loss function, adapted from
|
||||||
@@ -674,7 +674,7 @@ class CLIPTextTransformer(nn.Module):
|
|||||||
# lazily create causal attention mask, with full attention between the vision tokens
|
# lazily create causal attention mask, with full attention between the vision tokens
|
||||||
# pytorch uses additive attention mask; fill with -inf
|
# pytorch uses additive attention mask; fill with -inf
|
||||||
mask = torch.empty(bsz, seq_len, seq_len)
|
mask = torch.empty(bsz, seq_len, seq_len)
|
||||||
mask.fill_(float("-inf"))
|
mask.fill_(torch.tensor(float("-inf")))
|
||||||
mask.triu_(1) # zero out the lower diagonal
|
mask.triu_(1) # zero out the lower diagonal
|
||||||
mask = mask.unsqueeze(1) # expand mask
|
mask = mask.unsqueeze(1) # expand mask
|
||||||
return mask
|
return mask
|
||||||
@@ -1042,8 +1042,8 @@ class CLIPModel(CLIPPreTrainedModel):
|
|||||||
text_embeds = self.text_projection(text_embeds)
|
text_embeds = self.text_projection(text_embeds)
|
||||||
|
|
||||||
# normalized features
|
# normalized features
|
||||||
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
|
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||||
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
|
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||||
|
|
||||||
# cosine similarity as logits
|
# cosine similarity as logits
|
||||||
logit_scale = self.logit_scale.exp()
|
logit_scale = self.logit_scale.exp()
|
||||||
|
|||||||
@@ -800,7 +800,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
|||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
if bbox is None:
|
if bbox is None:
|
||||||
bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
|
bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
|
||||||
|
|
||||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -101,7 +101,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
||||||
@@ -998,7 +998,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None and combined_attention_mask is not None:
|
if attention_mask is not None and combined_attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -103,7 +103,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
|
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
|
||||||
@@ -856,7 +856,7 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -119,7 +119,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
|
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
|
||||||
@@ -909,7 +909,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -82,7 +82,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
||||||
@@ -513,7 +513,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
|
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
|
||||||
@@ -876,7 +876,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -116,7 +116,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
|
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
|
||||||
@@ -883,7 +883,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -91,7 +91,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
class Conv1dSubsampler(nn.Module):
|
class Conv1dSubsampler(nn.Module):
|
||||||
@@ -888,7 +888,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -71,7 +71,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2
|
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2
|
||||||
@@ -495,7 +495,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -72,7 +72,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
|
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
|
||||||
@@ -524,7 +524,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
"""
|
"""
|
||||||
bsz, tgt_len = input_ids_shape
|
bsz, tgt_len = input_ids_shape
|
||||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||||
mask_cond = torch.arange(mask.size(-1))
|
mask_cond = torch.arange(mask.size(-1))
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
mask = mask.to(dtype)
|
mask = mask.to(dtype)
|
||||||
@@ -142,7 +142,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
||||||
@@ -577,7 +577,7 @@ class XGLMModel(XGLMPreTrainedModel):
|
|||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(
|
combined_attention_mask = _make_causal_mask(
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
).to(self.device)
|
).to(inputs_embeds.device)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
@@ -712,7 +712,7 @@ class XGLMModel(XGLMPreTrainedModel):
|
|||||||
|
|
||||||
hidden_states = inputs_embeds + positions
|
hidden_states = inputs_embeds + positions
|
||||||
|
|
||||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|||||||
@@ -1056,7 +1056,6 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
|
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
|
||||||
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
|
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
|
||||||
|
|
||||||
pos_emb = pos_emb.to(self.device)
|
|
||||||
return pos_emb
|
return pos_emb
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@@ -1206,6 +1205,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
# Positional encoding
|
# Positional encoding
|
||||||
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
|
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
|
||||||
|
pos_emb = pos_emb.to(output_h.device)
|
||||||
pos_emb = self.dropout(pos_emb)
|
pos_emb = self.dropout(pos_emb)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
|
|||||||
@@ -29,27 +29,23 @@ from torch import nn
|
|||||||
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
||||||
from torch.fx.proxy import ParameterProxy
|
from torch.fx.proxy import ParameterProxy
|
||||||
|
|
||||||
from .. import (
|
from .. import PretrainedConfig, PreTrainedModel, logging
|
||||||
CONFIG_MAPPING,
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
|
||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
|
||||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
|
||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
|
||||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
|
||||||
MODEL_FOR_PRETRAINING_MAPPING,
|
|
||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
|
||||||
MODEL_MAPPING,
|
|
||||||
GPT2DoubleHeadsModel,
|
|
||||||
PretrainedConfig,
|
|
||||||
PreTrainedModel,
|
|
||||||
XLNetForQuestionAnswering,
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
from ..models.auto import get_values
|
from ..models.auto import get_values
|
||||||
|
from ..models.auto.modeling_auto import (
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_PRETRAINING_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
MODEL_MAPPING_NAMES,
|
||||||
|
)
|
||||||
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
|
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
|
||||||
from ..utils.versions import importlib_metadata
|
from ..utils.versions import importlib_metadata
|
||||||
|
|
||||||
@@ -57,25 +53,25 @@ from ..utils.versions import importlib_metadata
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _generate_supported_model_classes(
|
def _generate_supported_model_class_names(
|
||||||
model_name: Type[PretrainedConfig],
|
model_name: Type[PretrainedConfig],
|
||||||
supported_tasks: Optional[Union[str, List[str]]] = None,
|
supported_tasks: Optional[Union[str, List[str]]] = None,
|
||||||
) -> List[Type[PreTrainedModel]]:
|
) -> List[str]:
|
||||||
|
|
||||||
model_config_class = CONFIG_MAPPING[model_name]
|
|
||||||
task_mapping = {
|
task_mapping = {
|
||||||
"default": MODEL_MAPPING,
|
"default": MODEL_MAPPING_NAMES,
|
||||||
"pretraining": MODEL_FOR_PRETRAINING_MAPPING,
|
"pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
|
||||||
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
|
||||||
"masked-lm": MODEL_FOR_MASKED_LM_MAPPING,
|
"masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||||
"causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING,
|
"causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||||
"seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
"seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
||||||
"multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
"speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||||
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
"multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
|
||||||
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||||
|
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
}
|
}
|
||||||
|
|
||||||
if supported_tasks is None:
|
if supported_tasks is None:
|
||||||
@@ -83,55 +79,78 @@ def _generate_supported_model_classes(
|
|||||||
if isinstance(supported_tasks, str):
|
if isinstance(supported_tasks, str):
|
||||||
supported_tasks = [supported_tasks]
|
supported_tasks = [supported_tasks]
|
||||||
|
|
||||||
model_classes = []
|
model_class_names = []
|
||||||
for task in supported_tasks:
|
for task in supported_tasks:
|
||||||
model_class = task_mapping[task].get(model_config_class, None)
|
class_name = task_mapping[task].get(model_name, None)
|
||||||
if model_class:
|
if class_name:
|
||||||
model_classes.append(model_class)
|
model_class_names.append(class_name)
|
||||||
|
|
||||||
return model_classes
|
return model_class_names
|
||||||
|
|
||||||
|
|
||||||
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
||||||
"albert",
|
"albert",
|
||||||
|
"bart",
|
||||||
"bert",
|
"bert",
|
||||||
|
"blenderbot",
|
||||||
|
"blenderbot-small",
|
||||||
|
"clip",
|
||||||
"distilbert",
|
"distilbert",
|
||||||
"mobilebert",
|
|
||||||
"electra",
|
"electra",
|
||||||
"megatron-bert",
|
|
||||||
"gpt2",
|
"gpt2",
|
||||||
"gptj",
|
|
||||||
"gpt_neo",
|
"gpt_neo",
|
||||||
"t5",
|
"gptj",
|
||||||
|
"layoutlm",
|
||||||
|
"m2m_100",
|
||||||
|
"marian",
|
||||||
|
"mbart",
|
||||||
|
"megatron-bert",
|
||||||
|
"mobilebert",
|
||||||
|
"mt5",
|
||||||
|
"opt",
|
||||||
|
"pegasus",
|
||||||
|
"plbart",
|
||||||
"roberta",
|
"roberta",
|
||||||
"vit",
|
"speech_to_text",
|
||||||
|
"speech_to_text_2",
|
||||||
"swin",
|
"swin",
|
||||||
|
"t5",
|
||||||
|
"trocr",
|
||||||
|
"vit",
|
||||||
|
"xglm",
|
||||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||||
# "layoutlm",
|
|
||||||
# "xlnet",
|
# "xlnet",
|
||||||
]
|
]
|
||||||
|
|
||||||
_REGULAR_SUPPORTED_MODELS = []
|
_REGULAR_SUPPORTED_MODELS = []
|
||||||
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
|
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item))
|
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
|
||||||
else:
|
else:
|
||||||
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item))
|
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))
|
||||||
|
|
||||||
_SPECIAL_SUPPORTED_MODELS = [
|
_SPECIAL_SUPPORTED_MODELS = [
|
||||||
GPT2DoubleHeadsModel,
|
"CLIPTextModel",
|
||||||
|
"CLIPVisionModel",
|
||||||
|
"GPT2DoubleHeadsModel",
|
||||||
|
"Speech2Text2Decoder",
|
||||||
|
"TrOCRDecoder",
|
||||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||||
# XLNetForQuestionAnswering,
|
# XLNetForQuestionAnswering,
|
||||||
]
|
]
|
||||||
_SUPPORTED_MODELS = tuple(
|
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
|
||||||
sorted(list(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)), key=lambda c: c.__name__)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_embedding(self, input):
|
def torch_nn_embedding(self, input):
|
||||||
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
|
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_nn_functional_embedding(
    input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
    """Meta override for `torch.nn.functional.embedding`.

    Only the output shape is computed: the shape of `input` with the last
    dimension of `weight` appended. The remaining arguments are accepted so
    the signature matches the real function, but they are not used.
    """
    embedding_dim = weight.shape[-1]
    output_shape = (*input.shape, embedding_dim)
    return torch.empty(output_shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_layernorm(self, input):
|
def torch_nn_layernorm(self, input):
|
||||||
return input
|
return input
|
||||||
|
|
||||||
@@ -176,6 +195,12 @@ def torch_arange(*args, **kwargs):
|
|||||||
start, end = args
|
start, end = args
|
||||||
else:
|
else:
|
||||||
start, end, step = args
|
start, end, step = args
|
||||||
|
if isinstance(start, float):
|
||||||
|
start = int(start)
|
||||||
|
if isinstance(end, float):
|
||||||
|
start = int(end)
|
||||||
|
if isinstance(step, float):
|
||||||
|
step = int(step)
|
||||||
step = kwargs.get("step", step)
|
step = kwargs.get("step", step)
|
||||||
dtype = kwargs.get("dtype")
|
dtype = kwargs.get("dtype")
|
||||||
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
||||||
@@ -265,6 +290,14 @@ def torch_matmul(input, other, *, out=None):
|
|||||||
return torch.empty(*shape, device="meta")
|
return torch.empty(*shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_bmm(input, mat2, *, out=None):
    """Meta override for `torch.bmm`.

    Returns an empty meta tensor of shape `(batch_size, n, p)`, where
    `input` is `(batch_size, n, m)` and `mat2` is `(_, _, p)`. Both operands
    must be 3-dimensional; the shape unpacking fails otherwise.

    Raises:
        ValueError: if `out` is given, since writing into an existing tensor
            is not supported during meta tensor analysis.
    """
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    # The contracted dimension does not appear in the result, so it is discarded.
    batch_size, n, _ = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_einsum(equation, *operands):
|
def torch_einsum(equation, *operands):
|
||||||
# TODO: infer shape without performing the computation, this might be quite hard.
|
# TODO: infer shape without performing the computation, this might be quite hard.
|
||||||
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
|
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
|
||||||
@@ -285,13 +318,39 @@ def torch_index_select(input, dim, index, *, out=None):
|
|||||||
|
|
||||||
|
|
||||||
def torch_tensor_index_select(self, dim, index):
|
def torch_tensor_index_select(self, dim, index):
|
||||||
return torch_tensor_index_select(self, dim, index)
|
return torch_index_select(self, dim, index)
|
||||||
|
|
||||||
|
|
||||||
def torch_roll(input, shifts, dims=None):
|
def torch_roll(input, shifts, dims=None):
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
def torch_flip(input, dims):
    """Meta override for `torch.flip`: hands `input` back unchanged; `dims` is not used."""
    return input
|
||||||
|
|
||||||
|
|
||||||
|
def torch_tensor_flip(self, dims):
    """Meta override for `torch.Tensor.flip`: hands the tensor back unchanged; `dims` is not used."""
    return self
|
||||||
|
|
||||||
|
|
||||||
|
def torch_nn_conv1d(self, input):
    """Meta override for `torch.nn.Conv1d`.

    Builds the output shape from the module's hyper-parameters without
    running the convolution: the second-to-last dimension becomes
    `self.out_channels`, and the last dimension is kept as is for "same"
    padding or recomputed from padding, dilation, kernel size and stride
    otherwise.
    """
    output_shape = list(input.shape)
    # "valid" means no padding at all; any other value is used as stored on the module.
    padding = (0, 0) if self.padding == "valid" else self.padding
    if padding != "same":
        kernel_span = self.dilation[0] * (self.kernel_size[0] - 1)
        output_shape[-1] = math.floor(
            (input.shape[-1] + 2 * padding[0] - kernel_span - 1) / self.stride[0] + 1
        )
    output_shape[-2] = self.out_channels
    return torch.empty(output_shape, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_conv2d(self, input):
|
def torch_nn_conv2d(self, input):
|
||||||
h_in, w_in = input.shape[-2:]
|
h_in, w_in = input.shape[-2:]
|
||||||
shape = None
|
shape = None
|
||||||
@@ -325,6 +384,21 @@ def torch_tensor_unsqueeze(self, dim):
|
|||||||
return torch_unsqueeze(self, dim)
|
return torch_unsqueeze(self, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_unique_consecutive(input, *args, **kwargs):
    """Meta override for `torch.unique_consecutive`.

    The real op is run on a CPU tensor of zeros shaped like `input`, and the
    result is moved to the meta device. Because the stand-in is all zeros,
    the returned shapes are those of an all-zeros input, not of the real data.

    Extra positional and keyword arguments are forwarded unchanged to
    `torch.unique_consecutive`.

    Returns:
        A meta tensor when the op returns a single tensor, otherwise a tuple
        of meta tensors with one entry per returned tensor.
    """
    concrete_input = torch.zeros_like(input, device="cpu")
    output = torch.unique_consecutive(concrete_input, *args, **kwargs)
    if isinstance(output, torch.Tensor):
        return output.to("meta")
    # Several tensors come back when inverse indices and/or counts are requested.
    return tuple(tensor.to("meta") for tensor in output)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_nn_functional_one_hot(tensor, num_classes=-1):
    """Meta override for `torch.nn.functional.one_hot`.

    Returns an empty meta tensor shaped like `tensor` with a trailing
    dimension of size `num_classes`.

    Raises:
        ValueError: if `num_classes` is negative. The class count cannot be
            inferred here, so it has to be passed explicitly.
    """
    if num_classes < 0:
        raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
    return torch.empty([*tensor.shape, num_classes], device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_mseloss(self, input, target):
|
def torch_nn_mseloss(self, input, target):
|
||||||
if self.reduction == "none":
|
if self.reduction == "none":
|
||||||
shape = target.shape
|
shape = target.shape
|
||||||
@@ -350,14 +424,27 @@ def torch_nn_bcewithlogitsloss(self, input, target):
|
|||||||
|
|
||||||
|
|
||||||
def operator_getitem(a, b):
|
def operator_getitem(a, b):
|
||||||
|
def to_concrete(t):
|
||||||
|
if isinstance(t, torch.Tensor):
|
||||||
|
concrete = torch.ones_like(t, device="cpu")
|
||||||
|
if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
|
||||||
|
concrete = concrete.to(torch.int64)
|
||||||
|
return concrete
|
||||||
|
return t
|
||||||
|
|
||||||
if isinstance(a, torch.Tensor):
|
if isinstance(a, torch.Tensor):
|
||||||
# TODO: infer shape without performing the computation.
|
# TODO: infer shape without performing the computation.
|
||||||
|
if isinstance(b, tuple):
|
||||||
|
b = tuple(map(to_concrete, b))
|
||||||
|
else:
|
||||||
|
b = to_concrete(b)
|
||||||
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
|
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
|
||||||
return operator.getitem(a, b)
|
return operator.getitem(a, b)
|
||||||
|
|
||||||
|
|
||||||
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||||
torch.nn.Embedding: torch_nn_embedding,
|
torch.nn.Embedding: torch_nn_embedding,
|
||||||
|
torch.nn.functional.embedding: torch_nn_functional_embedding,
|
||||||
torch.nn.LayerNorm: torch_nn_layernorm,
|
torch.nn.LayerNorm: torch_nn_layernorm,
|
||||||
torch.nn.Linear: torch_nn_linear,
|
torch.nn.Linear: torch_nn_linear,
|
||||||
torch.relu: torch_relu,
|
torch.relu: torch_relu,
|
||||||
@@ -372,15 +459,20 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
|||||||
torch.mul: torch_mul,
|
torch.mul: torch_mul,
|
||||||
torch.Tensor.mul: torch_tensor_mul,
|
torch.Tensor.mul: torch_tensor_mul,
|
||||||
torch.matmul: torch_matmul,
|
torch.matmul: torch_matmul,
|
||||||
|
torch.bmm: torch_bmm,
|
||||||
torch.einsum: torch_einsum,
|
torch.einsum: torch_einsum,
|
||||||
torch.Tensor.repeat: torch_tensor_repeat,
|
torch.Tensor.repeat: torch_tensor_repeat,
|
||||||
torch.roll: torch_roll,
|
torch.roll: torch_roll,
|
||||||
# TODO: those might not be needed.
|
torch.flip: torch_flip,
|
||||||
# torch.index_select: torch_index_select,
|
torch.Tensor.flip: torch_tensor_flip,
|
||||||
# torch.Tensor.index_select: torch_tensor_index_select,
|
torch.index_select: torch_index_select,
|
||||||
|
torch.Tensor.index_select: torch_tensor_index_select,
|
||||||
|
torch.nn.Conv1d: torch_nn_conv1d,
|
||||||
torch.nn.Conv2d: torch_nn_conv2d,
|
torch.nn.Conv2d: torch_nn_conv2d,
|
||||||
torch.unsqueeze: torch_unsqueeze,
|
torch.unsqueeze: torch_unsqueeze,
|
||||||
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
|
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
|
||||||
|
torch.unique_consecutive: torch_unique_consecutive,
|
||||||
|
torch.nn.functional.one_hot: torch_nn_functional_one_hot,
|
||||||
torch.nn.MSELoss: torch_nn_mseloss,
|
torch.nn.MSELoss: torch_nn_mseloss,
|
||||||
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
|
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
|
||||||
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
|
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
|
||||||
@@ -513,7 +605,7 @@ class HFTracer(Tracer):
|
|||||||
# Feature flag for proxying accesses to buffer values
|
# Feature flag for proxying accesses to buffer values
|
||||||
proxy_buffer_attributes: bool = True
|
proxy_buffer_attributes: bool = True
|
||||||
allow_insert_stateless_mods: bool = True
|
allow_insert_stateless_mods: bool = True
|
||||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
|
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"]
|
||||||
|
|
||||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||||
|
|
||||||
@@ -532,22 +624,22 @@ class HFTracer(Tracer):
|
|||||||
"""Generates dummy input for model inference recording."""
|
"""Generates dummy input for model inference recording."""
|
||||||
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
||||||
# from pickle, or from the "__class__" attribute in the general case.
|
# from pickle, or from the "__class__" attribute in the general case.
|
||||||
model_class = getattr(model, "class_for_deserialization", model.__class__)
|
model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
|
||||||
device = model.device
|
device = model.device
|
||||||
inputs_dict = {}
|
inputs_dict = {}
|
||||||
|
|
||||||
if input_name in ["labels", "start_positions", "end_positions"]:
|
if input_name in ["labels", "start_positions", "end_positions"]:
|
||||||
|
|
||||||
batch_size = shape[0]
|
batch_size = shape[0]
|
||||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
|
||||||
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
elif model_class in [
|
elif model_class_name in [
|
||||||
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
|
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
|
||||||
XLNetForQuestionAnswering,
|
"XLNetForQuestionAnswering",
|
||||||
]:
|
]:
|
||||||
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
elif model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
|
||||||
if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
|
if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not retrieve the problem type for the sequence classification task, please set "
|
"Could not retrieve the problem type for the sequence classification task, please set "
|
||||||
@@ -571,32 +663,49 @@ class HFTracer(Tracer):
|
|||||||
)
|
)
|
||||||
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
|
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
|
||||||
|
|
||||||
elif model_class in [
|
elif model_class_name in [
|
||||||
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
|
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
|
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
|
||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
elif model_class in [
|
elif model_class_name in [
|
||||||
*get_values(MODEL_FOR_PRETRAINING_MAPPING),
|
*get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
|
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
|
*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
|
||||||
GPT2DoubleHeadsModel,
|
"GPT2DoubleHeadsModel",
|
||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{model_class} not supported yet.")
|
raise NotImplementedError(f"{model_class_name} not supported yet.")
|
||||||
elif "pixel_values" in input_name:
|
elif "pixel_values" in input_name:
|
||||||
batch_size = shape[0]
|
batch_size = shape[0]
|
||||||
image_size = model.config.image_size
|
image_size = getattr(model.config, "image_size", None)
|
||||||
|
if image_size is None:
|
||||||
|
if hasattr(model.config, "vision_config"):
|
||||||
|
image_size = model.config.vision_config.image_size
|
||||||
|
elif hasattr(model.config, "encoder"):
|
||||||
|
image_size = model.config.encoder.image_size
|
||||||
|
else:
|
||||||
|
raise AttributeError('Could not find the "image_size" field in the model config')
|
||||||
|
|
||||||
|
# If no num_channels is in the config, use some arbitrary value.
|
||||||
|
num_channels = getattr(model.config, "num_channels", 3)
|
||||||
if not isinstance(image_size, collections.abc.Iterable):
|
if not isinstance(image_size, collections.abc.Iterable):
|
||||||
image_size = (image_size, image_size)
|
image_size = (image_size, image_size)
|
||||||
height, width = image_size
|
height, width = image_size
|
||||||
inputs_dict[input_name] = torch.zeros(
|
inputs_dict[input_name] = torch.zeros(
|
||||||
batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
|
batch_size, num_channels, height, width, dtype=torch.float32, device=device
|
||||||
)
|
)
|
||||||
|
elif "bbox" in input_name:
|
||||||
|
inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
|
||||||
|
elif "input_features" in input_name:
|
||||||
|
inputs_dict[input_name] = torch.zeros(
|
||||||
|
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
|
||||||
|
)
|
||||||
|
elif "inputs" in input_name:
|
||||||
|
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
|
||||||
elif "mask" in input_name or "ids" in input_name:
|
elif "mask" in input_name or "ids" in input_name:
|
||||||
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||||
else:
|
else:
|
||||||
@@ -628,6 +737,8 @@ class HFTracer(Tracer):
|
|||||||
if kind == "call_function":
|
if kind == "call_function":
|
||||||
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
|
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
|
||||||
meta_out = meta_target(*args_metas, **kwargs_metas)
|
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||||
|
if isinstance(meta_out, torch.Tensor):
|
||||||
|
meta_out = meta_out.to(device="meta")
|
||||||
elif kind == "call_method":
|
elif kind == "call_method":
|
||||||
method = getattr(args_metas[0].__class__, target)
|
method = getattr(args_metas[0].__class__, target)
|
||||||
meta_target = _MANUAL_META_OVERRIDES.get(method, method)
|
meta_target = _MANUAL_META_OVERRIDES.get(method, method)
|
||||||
@@ -731,7 +842,7 @@ class HFTracer(Tracer):
|
|||||||
sequence_length = _generate_random_int()
|
sequence_length = _generate_random_int()
|
||||||
shape = [batch_size, sequence_length]
|
shape = [batch_size, sequence_length]
|
||||||
|
|
||||||
if root.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
|
||||||
num_choices = _generate_random_int(low=2, high=5)
|
num_choices = _generate_random_int(low=2, high=5)
|
||||||
shape.insert(1, num_choices)
|
shape.insert(1, num_choices)
|
||||||
|
|
||||||
@@ -870,11 +981,22 @@ def symbolic_trace(
|
|||||||
if input_names is None:
|
if input_names is None:
|
||||||
input_names = model.dummy_inputs.keys()
|
input_names = model.dummy_inputs.keys()
|
||||||
|
|
||||||
|
input_names = list(input_names)
|
||||||
|
|
||||||
sig = inspect.signature(model.forward)
|
sig = inspect.signature(model.forward)
|
||||||
|
|
||||||
|
if not (set(input_names) <= set(sig.parameters.keys())):
|
||||||
|
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
|
||||||
|
formatted_allowed_input_names = ", ".join(sig.parameters.keys())
|
||||||
|
raise ValueError(
|
||||||
|
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
|
||||||
|
f" {formatted_allowed_input_names}"
|
||||||
|
)
|
||||||
|
|
||||||
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
|
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
|
||||||
|
|
||||||
if not isinstance(model, _SUPPORTED_MODELS):
|
if model.__class__.__name__ not in _SUPPORTED_MODELS:
|
||||||
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
|
supported_model_names = ", ".join(_SUPPORTED_MODELS)
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
|
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -413,6 +413,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
@@ -1386,6 +1387,7 @@ class BartStandaloneDecoderModelTester:
|
|||||||
class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
|
all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
|
||||||
|
fx_comptatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
|||||||
@@ -218,6 +218,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
|||||||
all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
|
|||||||
@@ -213,6 +213,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
|
|||||||
all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
|
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
@@ -303,6 +303,7 @@ class CLIPTextModelTester:
|
|||||||
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
|
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
|
all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
@@ -388,6 +389,7 @@ class CLIPModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
|
class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (CLIPModel,) if is_torch_available() else ()
|
all_model_classes = (CLIPModel,) if is_torch_available() else ()
|
||||||
|
fx_compatible = True
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
fx_compatible = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = LayoutLMModelTester(self)
|
self.model_tester = LayoutLMModelTester(self)
|
||||||
|
|||||||
@@ -231,6 +231,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
|||||||
)
|
)
|
||||||
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
|
|||||||
@@ -230,6 +230,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
|||||||
all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else ()
|
all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
|
|||||||
@@ -224,6 +224,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
|||||||
)
|
)
|
||||||
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
|
all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
|
|||||||
@@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
fx_compatible = True
|
||||||
test_resize_position_embeddings = True
|
test_resize_position_embeddings = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|||||||
@@ -219,6 +219,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
|||||||
)
|
)
|
||||||
all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import cached_property
|
from transformers.utils import cached_property, is_torch_fx_available
|
||||||
|
|
||||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -43,6 +44,9 @@ if is_torch_available():
|
|||||||
from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
|
from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
|
||||||
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
|
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
|
||||||
|
|
||||||
|
if is_torch_fx_available():
|
||||||
|
from transformers.utils.fx import symbolic_trace
|
||||||
|
|
||||||
|
|
||||||
def prepare_speech_to_text_inputs_dict(
|
def prepare_speech_to_text_inputs_dict(
|
||||||
config,
|
config,
|
||||||
@@ -271,6 +275,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
|||||||
all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
@@ -715,6 +720,105 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
|||||||
|
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||||
|
if not is_torch_fx_available() or not self.fx_compatible:
|
||||||
|
return
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||||
|
configs_no_init.return_dict = False
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
|
labels = inputs.get("labels", None)
|
||||||
|
input_names = [
|
||||||
|
"input_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"decoder_input_ids",
|
||||||
|
"decoder_attention_mask",
|
||||||
|
"input_features",
|
||||||
|
]
|
||||||
|
if labels is not None:
|
||||||
|
input_names.append("labels")
|
||||||
|
|
||||||
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
traced_model = symbolic_trace(model, input_names)
|
||||||
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
else:
|
||||||
|
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"]
|
||||||
|
|
||||||
|
labels = inputs.get("labels", None)
|
||||||
|
start_positions = inputs.get("start_positions", None)
|
||||||
|
end_positions = inputs.get("end_positions", None)
|
||||||
|
if labels is not None:
|
||||||
|
input_names.append("labels")
|
||||||
|
if start_positions is not None:
|
||||||
|
input_names.append("start_positions")
|
||||||
|
if end_positions is not None:
|
||||||
|
input_names.append("end_positions")
|
||||||
|
|
||||||
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
traced_model = symbolic_trace(model, input_names)
|
||||||
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
self.fail(f"Couldn't trace module: {e}")
|
||||||
|
|
||||||
|
def flatten_output(output):
|
||||||
|
flatten = []
|
||||||
|
for x in output:
|
||||||
|
if isinstance(x, (tuple, list)):
|
||||||
|
flatten += flatten_output(x)
|
||||||
|
elif not isinstance(x, torch.Tensor):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
flatten.append(x)
|
||||||
|
return flatten
|
||||||
|
|
||||||
|
model_output = flatten_output(model_output)
|
||||||
|
traced_output = flatten_output(traced_output)
|
||||||
|
num_outputs = len(model_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], traced_output[i]),
|
||||||
|
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the model can be serialized and restored properly
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||||
|
try:
|
||||||
|
with open(pkl_file_name, "wb") as f:
|
||||||
|
pickle.dump(traced_model, f)
|
||||||
|
with open(pkl_file_name, "rb") as f:
|
||||||
|
loaded = pickle.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||||
|
|
||||||
|
loaded_output = loaded(**filtered_inputs)
|
||||||
|
loaded_output = flatten_output(loaded_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], loaded_output[i]),
|
||||||
|
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
|
|||||||
@@ -179,6 +179,7 @@ class Speech2Text2StandaloneDecoderModelTester:
|
|||||||
class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
|
all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
def setUp(
|
def setUp(
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Testing suite for the PyTorch Swin model. """
|
""" Testing suite for the PyTorch Swin model. """
|
||||||
|
|
||||||
import copy
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@@ -26,7 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
|
|||||||
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
|
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -45,14 +44,6 @@ if is_torch_fx_available():
|
|||||||
from transformers.utils.fx import symbolic_trace
|
from transformers.utils.fx import symbolic_trace
|
||||||
|
|
||||||
|
|
||||||
def _config_zero_init(config):
|
|
||||||
configs_no_init = copy.deepcopy(config)
|
|
||||||
for key in configs_no_init.__dict__.keys():
|
|
||||||
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
|
|
||||||
setattr(configs_no_init, key, 1e-10)
|
|
||||||
return configs_no_init
|
|
||||||
|
|
||||||
|
|
||||||
class SwinModelTester:
|
class SwinModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -407,7 +398,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
input_names.append("labels")
|
input_names.append("labels")
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
@@ -427,7 +420,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
input_names.append("end_positions")
|
input_names.append("end_positions")
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
input_names = filtered_inputs.keys()
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
|||||||
@@ -509,8 +509,8 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||||
fx_compatible = True
|
|
||||||
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_model_parallel = True
|
test_model_parallel = True
|
||||||
|
|||||||
@@ -161,6 +161,7 @@ class TrOCRStandaloneDecoderModelTester:
|
|||||||
class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else ()
|
all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
|
||||||
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -13,17 +13,26 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import XGLMConfig, is_torch_available
|
from transformers import XGLMConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
from transformers.utils import is_torch_fx_available
|
||||||
|
|
||||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
from ...test_modeling_common import (
|
||||||
|
ModelTesterMixin,
|
||||||
|
_config_zero_init,
|
||||||
|
floats_tensor,
|
||||||
|
ids_tensor,
|
||||||
|
random_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -31,6 +40,9 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
|
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
|
||||||
|
|
||||||
|
if is_torch_fx_available():
|
||||||
|
from transformers.utils.fx import symbolic_trace
|
||||||
|
|
||||||
|
|
||||||
class XGLMModelTester:
|
class XGLMModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -299,6 +311,7 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
|
all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
|
||||||
|
fx_compatible = True
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
@@ -337,6 +350,112 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
|
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
|
||||||
|
|
||||||
|
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||||
|
if not is_torch_fx_available() or not self.fx_compatible:
|
||||||
|
return
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||||
|
configs_no_init.return_dict = False
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
|
labels = inputs.get("labels", None)
|
||||||
|
input_names = [
|
||||||
|
"input_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"decoder_input_ids",
|
||||||
|
"decoder_attention_mask",
|
||||||
|
"input_features",
|
||||||
|
]
|
||||||
|
if labels is not None:
|
||||||
|
input_names.append("labels")
|
||||||
|
|
||||||
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
traced_model = symbolic_trace(model, input_names)
|
||||||
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
else:
|
||||||
|
input_names = [
|
||||||
|
"input_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"token_type_ids",
|
||||||
|
"pixel_values",
|
||||||
|
"bbox",
|
||||||
|
"input_features",
|
||||||
|
]
|
||||||
|
|
||||||
|
labels = inputs.get("labels", None)
|
||||||
|
start_positions = inputs.get("start_positions", None)
|
||||||
|
end_positions = inputs.get("end_positions", None)
|
||||||
|
if labels is not None:
|
||||||
|
input_names.append("labels")
|
||||||
|
if start_positions is not None:
|
||||||
|
input_names.append("start_positions")
|
||||||
|
if end_positions is not None:
|
||||||
|
input_names.append("end_positions")
|
||||||
|
|
||||||
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
traced_model = symbolic_trace(model, input_names)
|
||||||
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
self.fail(f"Couldn't trace module: {e}")
|
||||||
|
|
||||||
|
def flatten_output(output):
|
||||||
|
flatten = []
|
||||||
|
for x in output:
|
||||||
|
if isinstance(x, (tuple, list)):
|
||||||
|
flatten += flatten_output(x)
|
||||||
|
elif not isinstance(x, torch.Tensor):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
flatten.append(x)
|
||||||
|
return flatten
|
||||||
|
|
||||||
|
model_output = flatten_output(model_output)
|
||||||
|
traced_output = flatten_output(traced_output)
|
||||||
|
num_outputs = len(model_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], traced_output[i]),
|
||||||
|
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the model can be serialized and restored properly
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||||
|
try:
|
||||||
|
with open(pkl_file_name, "wb") as f:
|
||||||
|
pickle.dump(traced_model, f)
|
||||||
|
with open(pkl_file_name, "rb") as f:
|
||||||
|
loaded = pickle.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||||
|
|
||||||
|
loaded_output = loaded(**filtered_inputs)
|
||||||
|
loaded_output = flatten_output(loaded_output)
|
||||||
|
|
||||||
|
for i in range(num_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(model_output[i], loaded_output[i]),
|
||||||
|
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_batch_generation(self):
|
def test_batch_generation(self):
|
||||||
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
||||||
|
|||||||
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
|||||||
all_generative_model_classes = (
|
all_generative_model_classes = (
|
||||||
(XLNetLMHeadModel,) if is_torch_available() else ()
|
(XLNetLMHeadModel,) if is_torch_available() else ()
|
||||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||||
|
fx_compatible = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
||||||
|
|||||||
@@ -738,17 +738,32 @@ class ModelTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
labels = inputs.get("labels", None)
|
labels = inputs.get("labels", None)
|
||||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
input_names = [
|
||||||
|
"input_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"decoder_input_ids",
|
||||||
|
"decoder_attention_mask",
|
||||||
|
"input_features",
|
||||||
|
]
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
input_names.append("labels")
|
input_names.append("labels")
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
traced_model = symbolic_trace(model, input_names)
|
||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
else:
|
else:
|
||||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
|
input_names = [
|
||||||
|
"input_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"token_type_ids",
|
||||||
|
"pixel_values",
|
||||||
|
"bbox",
|
||||||
|
"input_features",
|
||||||
|
]
|
||||||
|
|
||||||
labels = inputs.get("labels", None)
|
labels = inputs.get("labels", None)
|
||||||
start_positions = inputs.get("start_positions", None)
|
start_positions = inputs.get("start_positions", None)
|
||||||
@@ -761,7 +776,7 @@ class ModelTesterMixin:
|
|||||||
input_names.append("end_positions")
|
input_names.append("end_positions")
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
input_names = filtered_inputs.keys()
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user