Fx support for multiple model architectures (#17393)
* Support for Bart and LayoutLM, and partial support for XLNet * Support for mbart * A lot of new models supported * Support for other models * LayoutLM fix * Use strings instead of classes
This commit is contained in:
@@ -93,7 +93,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -114,7 +114,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
class BartLearnedPositionalEmbedding(nn.Embedding):
|
||||
@@ -911,7 +911,7 @@ class BartDecoder(BartPretrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -2112,7 +2112,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -83,7 +83,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -105,7 +105,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
|
||||
@@ -850,7 +850,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# Copied from transformers.models.blenderbot.modeling_blenderbot.BlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall
|
||||
@@ -846,7 +846,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -57,7 +57,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# contrastive loss function, adapted from
|
||||
@@ -674,7 +674,7 @@ class CLIPTextTransformer(nn.Module):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(bsz, seq_len, seq_len)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.fill_(torch.tensor(float("-inf")))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask
|
||||
@@ -1042,8 +1042,8 @@ class CLIPModel(CLIPPreTrainedModel):
|
||||
text_embeds = self.text_projection(text_embeds)
|
||||
|
||||
# normalized features
|
||||
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
|
||||
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
|
||||
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
|
||||
@@ -800,7 +800,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
if bbox is None:
|
||||
bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
|
||||
bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -101,7 +101,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
||||
@@ -998,7 +998,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None and combined_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -103,7 +103,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
|
||||
@@ -856,7 +856,7 @@ class MarianDecoder(MarianPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -119,7 +119,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
|
||||
@@ -909,7 +909,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -61,7 +61,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -82,7 +82,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
||||
@@ -513,7 +513,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
|
||||
@@ -876,7 +876,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -116,7 +116,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
|
||||
@@ -883,7 +883,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -69,7 +69,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -91,7 +91,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
class Conv1dSubsampler(nn.Module):
|
||||
@@ -888,7 +888,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -49,7 +49,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -71,7 +71,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2
|
||||
@@ -495,7 +495,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -72,7 +72,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
|
||||
@@ -524,7 +524,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
||||
@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
@@ -142,7 +142,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
||||
@@ -577,7 +577,7 @@ class XGLMModel(XGLMPreTrainedModel):
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
@@ -712,7 +712,7 @@ class XGLMModel(XGLMPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
@@ -1056,7 +1056,6 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
|
||||
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
|
||||
|
||||
pos_emb = pos_emb.to(self.device)
|
||||
return pos_emb
|
||||
|
||||
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@@ -1206,6 +1205,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
|
||||
# Positional encoding
|
||||
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
|
||||
pos_emb = pos_emb.to(output_h.device)
|
||||
pos_emb = self.dropout(pos_emb)
|
||||
|
||||
# Prepare head mask if needed
|
||||
|
||||
@@ -29,27 +29,23 @@ from torch import nn
|
||||
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
||||
from torch.fx.proxy import ParameterProxy
|
||||
|
||||
from .. import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
GPT2DoubleHeadsModel,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
XLNetForQuestionAnswering,
|
||||
logging,
|
||||
)
|
||||
from .. import PretrainedConfig, PreTrainedModel, logging
|
||||
from ..models.auto import get_values
|
||||
from ..models.auto.modeling_auto import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
|
||||
MODEL_FOR_PRETRAINING_MAPPING_NAMES,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_MAPPING_NAMES,
|
||||
)
|
||||
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
|
||||
from ..utils.versions import importlib_metadata
|
||||
|
||||
@@ -57,25 +53,25 @@ from ..utils.versions import importlib_metadata
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _generate_supported_model_classes(
|
||||
def _generate_supported_model_class_names(
|
||||
model_name: Type[PretrainedConfig],
|
||||
supported_tasks: Optional[Union[str, List[str]]] = None,
|
||||
) -> List[Type[PreTrainedModel]]:
|
||||
) -> List[str]:
|
||||
|
||||
model_config_class = CONFIG_MAPPING[model_name]
|
||||
task_mapping = {
|
||||
"default": MODEL_MAPPING,
|
||||
"pretraining": MODEL_FOR_PRETRAINING_MAPPING,
|
||||
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
"masked-lm": MODEL_FOR_MASKED_LM_MAPPING,
|
||||
"causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
"seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
"multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
"default": MODEL_MAPPING_NAMES,
|
||||
"pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
|
||||
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
|
||||
"masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||
"causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
"seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
||||
"speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||
"multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
|
||||
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
}
|
||||
|
||||
if supported_tasks is None:
|
||||
@@ -83,55 +79,78 @@ def _generate_supported_model_classes(
|
||||
if isinstance(supported_tasks, str):
|
||||
supported_tasks = [supported_tasks]
|
||||
|
||||
model_classes = []
|
||||
model_class_names = []
|
||||
for task in supported_tasks:
|
||||
model_class = task_mapping[task].get(model_config_class, None)
|
||||
if model_class:
|
||||
model_classes.append(model_class)
|
||||
class_name = task_mapping[task].get(model_name, None)
|
||||
if class_name:
|
||||
model_class_names.append(class_name)
|
||||
|
||||
return model_classes
|
||||
return model_class_names
|
||||
|
||||
|
||||
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
||||
"albert",
|
||||
"bart",
|
||||
"bert",
|
||||
"blenderbot",
|
||||
"blenderbot-small",
|
||||
"clip",
|
||||
"distilbert",
|
||||
"mobilebert",
|
||||
"electra",
|
||||
"megatron-bert",
|
||||
"gpt2",
|
||||
"gptj",
|
||||
"gpt_neo",
|
||||
"t5",
|
||||
"gptj",
|
||||
"layoutlm",
|
||||
"m2m_100",
|
||||
"marian",
|
||||
"mbart",
|
||||
"megatron-bert",
|
||||
"mobilebert",
|
||||
"mt5",
|
||||
"opt",
|
||||
"pegasus",
|
||||
"plbart",
|
||||
"roberta",
|
||||
"vit",
|
||||
"speech_to_text",
|
||||
"speech_to_text_2",
|
||||
"swin",
|
||||
"t5",
|
||||
"trocr",
|
||||
"vit",
|
||||
"xglm",
|
||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||
# "layoutlm",
|
||||
# "xlnet",
|
||||
]
|
||||
|
||||
_REGULAR_SUPPORTED_MODELS = []
|
||||
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
|
||||
if isinstance(item, dict):
|
||||
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item))
|
||||
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
|
||||
else:
|
||||
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item))
|
||||
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))
|
||||
|
||||
_SPECIAL_SUPPORTED_MODELS = [
|
||||
GPT2DoubleHeadsModel,
|
||||
"CLIPTextModel",
|
||||
"CLIPVisionModel",
|
||||
"GPT2DoubleHeadsModel",
|
||||
"Speech2Text2Decoder",
|
||||
"TrOCRDecoder",
|
||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||
# XLNetForQuestionAnswering,
|
||||
]
|
||||
_SUPPORTED_MODELS = tuple(
|
||||
sorted(list(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)), key=lambda c: c.__name__)
|
||||
)
|
||||
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
|
||||
|
||||
|
||||
def torch_nn_embedding(self, input):
|
||||
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
|
||||
|
||||
|
||||
def torch_nn_functional_embedding(
|
||||
input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
|
||||
):
|
||||
return torch.empty(*input.shape, weight.shape[-1], device="meta")
|
||||
|
||||
|
||||
def torch_nn_layernorm(self, input):
|
||||
return input
|
||||
|
||||
@@ -176,6 +195,12 @@ def torch_arange(*args, **kwargs):
|
||||
start, end = args
|
||||
else:
|
||||
start, end, step = args
|
||||
if isinstance(start, float):
|
||||
start = int(start)
|
||||
if isinstance(end, float):
|
||||
start = int(end)
|
||||
if isinstance(step, float):
|
||||
step = int(step)
|
||||
step = kwargs.get("step", step)
|
||||
dtype = kwargs.get("dtype")
|
||||
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
||||
@@ -265,6 +290,14 @@ def torch_matmul(input, other, *, out=None):
|
||||
return torch.empty(*shape, device="meta")
|
||||
|
||||
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if out is not None:
|
||||
raise ValueError("Don't support in-place abs for MetaTensor analysis")
|
||||
batch_size, n, m = input.shape
|
||||
_, _, p = mat2.shape
|
||||
return torch.empty(batch_size, n, p, device="meta")
|
||||
|
||||
|
||||
def torch_einsum(equation, *operands):
|
||||
# TODO: infer shape without performing the computation, this might be quite hard.
|
||||
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
|
||||
@@ -285,13 +318,39 @@ def torch_index_select(input, dim, index, *, out=None):
|
||||
|
||||
|
||||
def torch_tensor_index_select(self, dim, index):
|
||||
return torch_tensor_index_select(self, dim, index)
|
||||
return torch_index_select(self, dim, index)
|
||||
|
||||
|
||||
def torch_roll(input, shifts, dims=None):
|
||||
return input
|
||||
|
||||
|
||||
def torch_flip(input, dims):
|
||||
return input
|
||||
|
||||
|
||||
def torch_tensor_flip(self, dims):
|
||||
return self
|
||||
|
||||
|
||||
def torch_nn_conv1d(self, input):
|
||||
l_in = input.shape[-1]
|
||||
shape = None
|
||||
padding = self.padding
|
||||
if padding == "valid":
|
||||
padding = (0, 0)
|
||||
if padding == "same":
|
||||
shape = list(input.shape)
|
||||
if shape is None:
|
||||
shape = list(input.shape)
|
||||
l_out = math.floor(
|
||||
(l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
|
||||
)
|
||||
shape[-1] = l_out
|
||||
shape[-2] = self.out_channels
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def torch_nn_conv2d(self, input):
|
||||
h_in, w_in = input.shape[-2:]
|
||||
shape = None
|
||||
@@ -325,6 +384,21 @@ def torch_tensor_unsqueeze(self, dim):
|
||||
return torch_unsqueeze(self, dim)
|
||||
|
||||
|
||||
def torch_unique_consecutive(input, **kwargs):
|
||||
output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
|
||||
if isinstance(output, torch.Tensor):
|
||||
return output.to("meta")
|
||||
else:
|
||||
return tuple(map(output, lambda x: x.to("meta")))
|
||||
|
||||
|
||||
def torch_nn_functional_one_hot(tensor, num_classes=-1):
|
||||
if num_classes < 0:
|
||||
raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
|
||||
shape = list(tensor.shape) + [num_classes]
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def torch_nn_mseloss(self, input, target):
|
||||
if self.reduction == "none":
|
||||
shape = target.shape
|
||||
@@ -350,14 +424,27 @@ def torch_nn_bcewithlogitsloss(self, input, target):
|
||||
|
||||
|
||||
def operator_getitem(a, b):
|
||||
def to_concrete(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
concrete = torch.ones_like(t, device="cpu")
|
||||
if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
|
||||
concrete = concrete.to(torch.int64)
|
||||
return concrete
|
||||
return t
|
||||
|
||||
if isinstance(a, torch.Tensor):
|
||||
# TODO: infer shape without performing the computation.
|
||||
if isinstance(b, tuple):
|
||||
b = tuple(map(to_concrete, b))
|
||||
else:
|
||||
b = to_concrete(b)
|
||||
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
|
||||
return operator.getitem(a, b)
|
||||
|
||||
|
||||
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||
torch.nn.Embedding: torch_nn_embedding,
|
||||
torch.nn.functional.embedding: torch_nn_functional_embedding,
|
||||
torch.nn.LayerNorm: torch_nn_layernorm,
|
||||
torch.nn.Linear: torch_nn_linear,
|
||||
torch.relu: torch_relu,
|
||||
@@ -372,15 +459,20 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||
torch.mul: torch_mul,
|
||||
torch.Tensor.mul: torch_tensor_mul,
|
||||
torch.matmul: torch_matmul,
|
||||
torch.bmm: torch_bmm,
|
||||
torch.einsum: torch_einsum,
|
||||
torch.Tensor.repeat: torch_tensor_repeat,
|
||||
torch.roll: torch_roll,
|
||||
# TODO: those might not be needed.
|
||||
# torch.index_select: torch_index_select,
|
||||
# torch.Tensor.index_select: torch_tensor_index_select,
|
||||
torch.flip: torch_flip,
|
||||
torch.Tensor.flip: torch_tensor_flip,
|
||||
torch.index_select: torch_index_select,
|
||||
torch.Tensor.index_select: torch_tensor_index_select,
|
||||
torch.nn.Conv1d: torch_nn_conv1d,
|
||||
torch.nn.Conv2d: torch_nn_conv2d,
|
||||
torch.unsqueeze: torch_unsqueeze,
|
||||
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
|
||||
torch.unique_consecutive: torch_unique_consecutive,
|
||||
torch.nn.functional.one_hot: torch_nn_functional_one_hot,
|
||||
torch.nn.MSELoss: torch_nn_mseloss,
|
||||
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
|
||||
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
|
||||
@@ -513,7 +605,7 @@ class HFTracer(Tracer):
|
||||
# Feature flag for proxying accesses to buffer values
|
||||
proxy_buffer_attributes: bool = True
|
||||
allow_insert_stateless_mods: bool = True
|
||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
|
||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"]
|
||||
|
||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||
|
||||
@@ -532,22 +624,22 @@ class HFTracer(Tracer):
|
||||
"""Generates dummy input for model inference recording."""
|
||||
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
||||
# from pickle, or from the "__class__" attribute in the general case.
|
||||
model_class = getattr(model, "class_for_deserialization", model.__class__)
|
||||
model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
|
||||
device = model.device
|
||||
inputs_dict = {}
|
||||
|
||||
if input_name in ["labels", "start_positions", "end_positions"]:
|
||||
|
||||
batch_size = shape[0]
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
|
||||
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
elif model_class in [
|
||||
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
|
||||
XLNetForQuestionAnswering,
|
||||
elif model_class_name in [
|
||||
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
|
||||
"XLNetForQuestionAnswering",
|
||||
]:
|
||||
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
elif model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||
elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
|
||||
if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
|
||||
raise ValueError(
|
||||
"Could not retrieve the problem type for the sequence classification task, please set "
|
||||
@@ -571,32 +663,49 @@ class HFTracer(Tracer):
|
||||
)
|
||||
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
|
||||
|
||||
elif model_class in [
|
||||
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
|
||||
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
|
||||
elif model_class_name in [
|
||||
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
elif model_class in [
|
||||
*get_values(MODEL_FOR_PRETRAINING_MAPPING),
|
||||
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
|
||||
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||
GPT2DoubleHeadsModel,
|
||||
elif model_class_name in [
|
||||
*get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
|
||||
"GPT2DoubleHeadsModel",
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||
else:
|
||||
raise NotImplementedError(f"{model_class} not supported yet.")
|
||||
raise NotImplementedError(f"{model_class_name} not supported yet.")
|
||||
elif "pixel_values" in input_name:
|
||||
batch_size = shape[0]
|
||||
image_size = model.config.image_size
|
||||
image_size = getattr(model.config, "image_size", None)
|
||||
if image_size is None:
|
||||
if hasattr(model.config, "vision_config"):
|
||||
image_size = model.config.vision_config.image_size
|
||||
elif hasattr(model.config, "encoder"):
|
||||
image_size = model.config.encoder.image_size
|
||||
else:
|
||||
raise AttributeError('Could not find the "image_size" field in the model config')
|
||||
|
||||
# If no num_channels is in the config, use some arbitrary value.
|
||||
num_channels = getattr(model.config, "num_channels", 3)
|
||||
if not isinstance(image_size, collections.abc.Iterable):
|
||||
image_size = (image_size, image_size)
|
||||
height, width = image_size
|
||||
inputs_dict[input_name] = torch.zeros(
|
||||
batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
|
||||
batch_size, num_channels, height, width, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
elif "bbox" in input_name:
|
||||
inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
|
||||
elif "input_features" in input_name:
|
||||
inputs_dict[input_name] = torch.zeros(
|
||||
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
|
||||
)
|
||||
elif "inputs" in input_name:
|
||||
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
|
||||
elif "mask" in input_name or "ids" in input_name:
|
||||
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||
else:
|
||||
@@ -628,6 +737,8 @@ class HFTracer(Tracer):
|
||||
if kind == "call_function":
|
||||
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
|
||||
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||
if isinstance(meta_out, torch.Tensor):
|
||||
meta_out = meta_out.to(device="meta")
|
||||
elif kind == "call_method":
|
||||
method = getattr(args_metas[0].__class__, target)
|
||||
meta_target = _MANUAL_META_OVERRIDES.get(method, method)
|
||||
@@ -731,7 +842,7 @@ class HFTracer(Tracer):
|
||||
sequence_length = _generate_random_int()
|
||||
shape = [batch_size, sequence_length]
|
||||
|
||||
if root.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
|
||||
num_choices = _generate_random_int(low=2, high=5)
|
||||
shape.insert(1, num_choices)
|
||||
|
||||
@@ -870,11 +981,22 @@ def symbolic_trace(
|
||||
if input_names is None:
|
||||
input_names = model.dummy_inputs.keys()
|
||||
|
||||
input_names = list(input_names)
|
||||
|
||||
sig = inspect.signature(model.forward)
|
||||
|
||||
if not (set(input_names) <= set(sig.parameters.keys())):
|
||||
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
|
||||
formatted_allowed_input_names = ", ".join(sig.parameters.keys())
|
||||
raise ValueError(
|
||||
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
|
||||
f" {formatted_allowed_input_names}"
|
||||
)
|
||||
|
||||
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
|
||||
|
||||
if not isinstance(model, _SUPPORTED_MODELS):
|
||||
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
|
||||
if model.__class__.__name__ not in _SUPPORTED_MODELS:
|
||||
supported_model_names = ", ".join(_SUPPORTED_MODELS)
|
||||
raise NotImplementedError(
|
||||
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
|
||||
)
|
||||
|
||||
@@ -413,6 +413,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
)
|
||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
@@ -1386,6 +1387,7 @@ class BartStandaloneDecoderModelTester:
|
||||
class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
|
||||
fx_comptatible = True
|
||||
test_pruning = False
|
||||
is_encoder_decoder = False
|
||||
|
||||
|
||||
@@ -218,6 +218,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
|
||||
@@ -213,6 +213,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
|
||||
all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
|
||||
@@ -152,7 +152,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
|
||||
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
|
||||
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
@@ -303,6 +303,7 @@ class CLIPTextModelTester:
|
||||
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
@@ -388,6 +389,7 @@ class CLIPModelTester:
|
||||
@require_torch
|
||||
class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (CLIPModel,) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
@@ -215,6 +215,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else None
|
||||
)
|
||||
fx_compatible = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = LayoutLMModelTester(self)
|
||||
|
||||
@@ -231,6 +231,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
)
|
||||
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
|
||||
@@ -230,6 +230,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else ()
|
||||
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
|
||||
@@ -224,6 +224,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
)
|
||||
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
|
||||
is_encoder_decoder = False
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
|
||||
@@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = True
|
||||
test_resize_position_embeddings = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
@@ -219,6 +219,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
)
|
||||
all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -30,7 +31,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property
|
||||
from transformers.utils import cached_property, is_torch_fx_available
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -43,6 +44,9 @@ if is_torch_available():
|
||||
from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
|
||||
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
|
||||
|
||||
if is_torch_fx_available():
|
||||
from transformers.utils.fx import symbolic_trace
|
||||
|
||||
|
||||
def prepare_speech_to_text_inputs_dict(
|
||||
config,
|
||||
@@ -271,6 +275,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
@@ -715,6 +720,105 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||
if not is_torch_fx_available() or not self.fx_compatible:
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.return_dict = False
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"input_features",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
if start_positions is not None:
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
for x in output:
|
||||
if isinstance(x, (tuple, list)):
|
||||
flatten += flatten_output(x)
|
||||
elif not isinstance(x, torch.Tensor):
|
||||
continue
|
||||
else:
|
||||
flatten.append(x)
|
||||
return flatten
|
||||
|
||||
model_output = flatten_output(model_output)
|
||||
traced_output = flatten_output(traced_output)
|
||||
num_outputs = len(model_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], traced_output[i]),
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
||||
@@ -179,6 +179,7 @@ class Speech2Text2StandaloneDecoderModelTester:
|
||||
class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
|
||||
def setUp(
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Swin model. """
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
@@ -26,7 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
|
||||
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -45,14 +44,6 @@ if is_torch_fx_available():
|
||||
from transformers.utils.fx import symbolic_trace
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
for key in configs_no_init.__dict__.keys():
|
||||
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
|
||||
setattr(configs_no_init, key, 1e-10)
|
||||
return configs_no_init
|
||||
|
||||
|
||||
class SwinModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -407,7 +398,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
@@ -427,7 +420,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = filtered_inputs.keys()
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
|
||||
@@ -509,8 +509,8 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_resize_embeddings = True
|
||||
test_model_parallel = True
|
||||
|
||||
@@ -161,6 +161,7 @@ class TrOCRStandaloneDecoderModelTester:
|
||||
class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -13,17 +13,26 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import XGLMConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.utils import is_torch_fx_available
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
random_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -31,6 +40,9 @@ if is_torch_available():
|
||||
|
||||
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
|
||||
|
||||
if is_torch_fx_available():
|
||||
from transformers.utils.fx import symbolic_trace
|
||||
|
||||
|
||||
class XGLMModelTester:
|
||||
def __init__(
|
||||
@@ -299,6 +311,7 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
test_missing_keys = False
|
||||
test_pruning = False
|
||||
|
||||
@@ -337,6 +350,112 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
|
||||
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||
if not is_torch_fx_available() or not self.fx_compatible:
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.return_dict = False
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"input_features",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"token_type_ids",
|
||||
"pixel_values",
|
||||
"bbox",
|
||||
"input_features",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
if start_positions is not None:
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
for x in output:
|
||||
if isinstance(x, (tuple, list)):
|
||||
flatten += flatten_output(x)
|
||||
elif not isinstance(x, torch.Tensor):
|
||||
continue
|
||||
else:
|
||||
flatten.append(x)
|
||||
return flatten
|
||||
|
||||
model_output = flatten_output(model_output)
|
||||
traced_output = flatten_output(traced_output)
|
||||
num_outputs = len(model_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], traced_output[i]),
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
||||
|
||||
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
all_generative_model_classes = (
|
||||
(XLNetLMHeadModel,) if is_torch_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
|
||||
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
||||
|
||||
@@ -738,17 +738,32 @@ class ModelTesterMixin:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"input_features",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"token_type_ids",
|
||||
"pixel_values",
|
||||
"bbox",
|
||||
"input_features",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
@@ -761,7 +776,7 @@ class ModelTesterMixin:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = filtered_inputs.keys()
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user