FX tracing improvement (#14321)
* Change the way tracing happens, enabling dynamic axes out of the box * Update the tests and modeling xlnet * Add the non-recording of leaf modules to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors). * Comments and making tracing work for gpt-j and xlnet * Refactor things related to num_choices (and batch_size, sequence_length) * Update fx to work on PyTorch 1.10 * Postpone autowrap_function feature usage for later * Add copyrights * Remove unnecessary file * Fix issue with add_new_model_like * Apply suggestions
This commit is contained in:
@@ -1189,6 +1189,16 @@ def create_new_model_like(
|
|||||||
if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
|
if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def disable_fx_test(filename: Path) -> bool:
    """Switch off the torch.fx symbolic-tracing tests declared in a test file.

    Every `fx_compatible = True` assignment found in `filename` (whatever the spacing around `=`) is rewritten to
    `fx_compatible = False`.

    Args:
        filename: Path of the test file to patch in place.

    Returns:
        `True` if the file content was modified, `False` if there was nothing to replace.
    """
    with open(filename, encoding="utf-8") as fp:
        content = fp.read()

    new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)
    changed = content != new_content

    # Only touch the file when a replacement actually happened, so untouched test files keep their
    # modification time and are never needlessly rewritten.
    if changed:
        with open(filename, "w", encoding="utf-8") as fp:
            fp.write(new_content)

    return changed
|
||||||
|
|
||||||
|
disabled_fx_test = False
|
||||||
|
|
||||||
for test_file in files_to_adapt:
|
for test_file in files_to_adapt:
|
||||||
new_test_file_name = test_file.name.replace(
|
new_test_file_name = test_file.name.replace(
|
||||||
old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
|
old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
|
||||||
@@ -1201,6 +1211,13 @@ def create_new_model_like(
|
|||||||
dest_file=dest_file,
|
dest_file=dest_file,
|
||||||
add_copied_from=False,
|
add_copied_from=False,
|
||||||
)
|
)
|
||||||
|
disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file)
|
||||||
|
|
||||||
|
if disabled_fx_test:
|
||||||
|
print(
|
||||||
|
"The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works "
|
||||||
|
"for your new model."
|
||||||
|
)
|
||||||
|
|
||||||
# 4. Add model to auto classes
|
# 4. Add model to auto classes
|
||||||
add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)
|
add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)
|
||||||
|
|||||||
@@ -322,7 +322,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOIN
|
|||||||
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
|
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
|
||||||
|
|
||||||
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
|
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
|
||||||
TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
|
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
|
||||||
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
|
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
|
||||||
|
|
||||||
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
|
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
|
||||||
|
|||||||
@@ -247,6 +247,27 @@ class ModuleUtilsMixin:
|
|||||||
|
|
||||||
return encoder_extended_attention_mask
|
return encoder_extended_attention_mask
|
||||||
|
|
||||||
|
def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
    """
    Builds the decoder attention mask by combining a causal mask with the given padding mask.

    Args:
        input_shape: pair `(batch_size, seq_length)` describing the current inputs.
        attention_mask: tensor indexed here as `[batch, key_position]`; its second dimension may be longer than
            `seq_length` (e.g. when `past_key_values` are used).
        device: device on which the intermediate tensors are created.

    Returns:
        The product of the causal mask (expanded to `[batch, 1, query, key]`) and the padding mask (expanded to
        `[batch, 1, 1, key]`), in the dtype of `attention_mask`.
    """
    batch_size, seq_length = input_shape
    seq_ids = torch.arange(seq_length, device=device)
    # causal_mask[b, i, j] is True when key position j is not after query position i
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    # causal and attention masks must have same type with pytorch version < 1.3
    causal_mask = causal_mask.to(attention_mask.dtype)

    # in case past_key_values are used we need to add a prefix ones mask to the causal mask
    if causal_mask.shape[1] < attention_mask.shape[1]:
        prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
        causal_mask = torch.cat(
            [
                torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                causal_mask,
            ],
            dim=-1,
        )

    extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    return extended_attention_mask
|
||||||
|
|
||||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
|
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||||
@@ -271,26 +292,9 @@ class ModuleUtilsMixin:
|
|||||||
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
if self.config.is_decoder:
|
if self.config.is_decoder:
|
||||||
batch_size, seq_length = input_shape
|
extended_attention_mask = self.create_extended_attention_mask_for_decoder(
|
||||||
seq_ids = torch.arange(seq_length, device=device)
|
input_shape, attention_mask, device
|
||||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
|
||||||
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
|
||||||
# causal and attention masks must have same type with pytorch version < 1.3
|
|
||||||
causal_mask = causal_mask.to(attention_mask.dtype)
|
|
||||||
|
|
||||||
if causal_mask.shape[1] < attention_mask.shape[1]:
|
|
||||||
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
|
||||||
causal_mask = torch.cat(
|
|
||||||
[
|
|
||||||
torch.ones(
|
|
||||||
(batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
|
|
||||||
),
|
|
||||||
causal_mask,
|
|
||||||
],
|
|
||||||
axis=-1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
|
||||||
else:
|
else:
|
||||||
extended_attention_mask = attention_mask[:, None, None, :]
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
else:
|
else:
|
||||||
@@ -1861,7 +1865,7 @@ class Conv1D(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
size_out = x.size()[:-1] + (self.nf,)
|
size_out = x.size()[:-1] + (self.nf,)
|
||||||
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
||||||
x = x.view(*size_out)
|
x = x.view(size_out)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -293,7 +293,7 @@ class AlbertAttention(nn.Module):
|
|||||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
|
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ class BertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -341,7 +341,7 @@ class BertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ class ElectraSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -334,7 +334,7 @@ class ElectraSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ class GPT2Attention(nn.Module):
|
|||||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||||
|
|
||||||
if self.scale_attn_weights:
|
if self.scale_attn_weights:
|
||||||
attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
|
attn_weights = attn_weights / (value.size(-1) ** 0.5)
|
||||||
|
|
||||||
# Layer-wise attention scaling
|
# Layer-wise attention scaling
|
||||||
if self.scale_attn_by_inverse_layer_idx:
|
if self.scale_attn_by_inverse_layer_idx:
|
||||||
@@ -281,7 +281,7 @@ class GPT2Attention(nn.Module):
|
|||||||
Splits hidden_size dim into attn_head_size and num_heads
|
Splits hidden_size dim into attn_head_size and num_heads
|
||||||
"""
|
"""
|
||||||
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
||||||
tensor = tensor.view(*new_shape)
|
tensor = tensor.view(new_shape)
|
||||||
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||||
|
|
||||||
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
||||||
@@ -915,7 +915,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
hidden_states = hidden_states.view(*output_shape)
|
hidden_states = hidden_states.view(output_shape)
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
@@ -1410,7 +1410,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
|
|||||||
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
)
|
)
|
||||||
|
|
||||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
Splits hidden_size dim into attn_head_size and num_heads
|
Splits hidden_size dim into attn_head_size and num_heads
|
||||||
"""
|
"""
|
||||||
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
||||||
tensor = tensor.view(*new_shape)
|
tensor = tensor.view(new_shape)
|
||||||
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||||
|
|
||||||
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
||||||
@@ -637,7 +637,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
hidden_states = hidden_states.view(*output_shape)
|
hidden_states = hidden_states.view(output_shape)
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
@@ -891,7 +891,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
|||||||
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
)
|
)
|
||||||
|
|
||||||
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
|
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class GPTJAttention(nn.Module):
|
|||||||
Splits hidden dim into attn_head_size and num_attention_heads
|
Splits hidden dim into attn_head_size and num_attention_heads
|
||||||
"""
|
"""
|
||||||
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
|
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
|
||||||
tensor = tensor.view(*new_shape)
|
tensor = tensor.view(new_shape)
|
||||||
if rotary:
|
if rotary:
|
||||||
return tensor
|
return tensor
|
||||||
if len(tensor.shape) == 5:
|
if len(tensor.shape) == 5:
|
||||||
@@ -665,7 +665,7 @@ class GPTJModel(GPTJPreTrainedModel):
|
|||||||
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
hidden_states = hidden_states.view(*output_shape)
|
hidden_states = hidden_states.view(output_shape)
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
@@ -945,7 +945,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
|
|||||||
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
)
|
)
|
||||||
|
|
||||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ class LayoutLMSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -249,7 +249,7 @@ class LayoutLMSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ class MegatronBertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -312,7 +312,7 @@ class MegatronBertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ class MobileBertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -274,7 +274,7 @@ class MobileBertSelfAttention(nn.Module):
|
|||||||
context_layer = torch.matmul(attention_probs, value_layer)
|
context_layer = torch.matmul(attention_probs, value_layer)
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ class RealmSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -349,7 +349,7 @@ class RealmSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ class RobertaSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -276,7 +276,7 @@ class RobertaSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ class SplinterSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -216,7 +216,7 @@ class SplinterSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class XLMRobertaXLSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -270,7 +270,7 @@ class XLMRobertaXLSelfAttention(nn.Module):
|
|||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,24 @@
|
|||||||
import copy
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
from types import ModuleType
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -26,17 +42,11 @@ from .. import (
|
|||||||
GPT2DoubleHeadsModel,
|
GPT2DoubleHeadsModel,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
|
XLNetForQuestionAnswering,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available
|
from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available
|
||||||
from ..models.auto import get_values
|
from ..models.auto import get_values
|
||||||
from .fx_transformations import (
|
|
||||||
_cache_attributes,
|
|
||||||
_patch_arguments_,
|
|
||||||
_restore_attributes_,
|
|
||||||
transform_to_dynamic_input_,
|
|
||||||
transformation,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -46,6 +56,7 @@ def _generate_supported_model_classes(
|
|||||||
model_name: Type[PretrainedConfig],
|
model_name: Type[PretrainedConfig],
|
||||||
supported_tasks: Optional[Union[str, List[str]]] = None,
|
supported_tasks: Optional[Union[str, List[str]]] = None,
|
||||||
) -> List[Type[PreTrainedModel]]:
|
) -> List[Type[PreTrainedModel]]:
|
||||||
|
|
||||||
model_config_class = CONFIG_MAPPING[model_name]
|
model_config_class = CONFIG_MAPPING[model_name]
|
||||||
task_mapping = {
|
task_mapping = {
|
||||||
"default": MODEL_MAPPING,
|
"default": MODEL_MAPPING,
|
||||||
@@ -86,15 +97,10 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
|||||||
"gptj",
|
"gptj",
|
||||||
"gpt_neo",
|
"gpt_neo",
|
||||||
"t5",
|
"t5",
|
||||||
]
|
"roberta",
|
||||||
|
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||||
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES = [
|
# "layoutlm",
|
||||||
"albert",
|
# "xlnet",
|
||||||
"bert",
|
|
||||||
"distilbert",
|
|
||||||
"mobilebert",
|
|
||||||
"electra",
|
|
||||||
"megatron-bert",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
_REGULAR_SUPPORTED_MODELS = []
|
_REGULAR_SUPPORTED_MODELS = []
|
||||||
@@ -106,21 +112,11 @@ for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
|
|||||||
|
|
||||||
_SPECIAL_SUPPORTED_MODELS = [
|
_SPECIAL_SUPPORTED_MODELS = [
|
||||||
GPT2DoubleHeadsModel,
|
GPT2DoubleHeadsModel,
|
||||||
|
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||||
|
# XLNetForQuestionAnswering,
|
||||||
]
|
]
|
||||||
_SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)
|
_SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)
|
||||||
|
|
||||||
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
|
|
||||||
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(**item))
|
|
||||||
else:
|
|
||||||
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(item))
|
|
||||||
|
|
||||||
_SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
|
|
||||||
_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = tuple(
|
|
||||||
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES + _SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HFProxy(Proxy):
|
class HFProxy(Proxy):
|
||||||
"""
|
"""
|
||||||
@@ -134,6 +130,7 @@ class HFProxy(Proxy):
|
|||||||
if hasattr(self, "tracer") and self.tracer is not None:
|
if hasattr(self, "tracer") and self.tracer is not None:
|
||||||
self.device = self.tracer.root.device
|
self.device = self.tracer.root.device
|
||||||
self.dtype = next(self.tracer.root.parameters()).dtype
|
self.dtype = next(self.tracer.root.parameters()).dtype
|
||||||
|
self.cache = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
@@ -145,13 +142,181 @@ class HFProxy(Proxy):
|
|||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if self.cache is not None:
|
||||||
|
return self.cache == other
|
||||||
|
elif isinstance(other, HFProxy):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return super().__eq__(other)
|
||||||
|
|
||||||
def _wrap_method_for_model_recording(model, method_name, cache_name):
|
def __ne__(self, other):
|
||||||
|
return not self == other
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.cache is not None:
|
||||||
|
if isinstance(self.cache, int):
|
||||||
|
return self.cache
|
||||||
|
elif isinstance(self.cache, (torch.Size, list, tuple)):
|
||||||
|
return len(self.cache)
|
||||||
|
else:
|
||||||
|
return super().__len__(self)
|
||||||
|
return super().__len__(self)
|
||||||
|
|
||||||
|
def __torch_function__(self, orig_method, types, args=None, kwargs=None):
    """
    Delegates to the parent `__torch_function__` and copies this proxy's cached value onto the returned proxy.
    """
    result = super().__torch_function__(orig_method, types, args=args, kwargs=kwargs)
    result.cache = self.cache
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
"""Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer."""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _function_leaf_getter(func_name: str, mapping: Dict[str, Callable[..., Any]]) -> Callable[..., Any]:
|
||||||
|
@functools.wraps(mapping[func_name])
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return mapping[func_name](*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _create_recorded_proxy_method(proxy: HFProxy, method_name: str, cache_name: str, return_proxy: bool):
|
||||||
|
"""
|
||||||
|
Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values
|
||||||
|
during symbolic tracing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
original_method = getattr(torch.Tensor, method_name)
|
||||||
|
|
||||||
|
@functools.wraps(original_method)
|
||||||
|
def method(*args, **kwargs):
|
||||||
|
cache = getattr(args[0].tracer.root, cache_name)
|
||||||
|
res = cache.pop(0)
|
||||||
|
if return_proxy:
|
||||||
|
proxy = args[0].__torch_function__(
|
||||||
|
original_method,
|
||||||
|
None,
|
||||||
|
args=args,
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
proxy.cache = res
|
||||||
|
return proxy
|
||||||
|
return res
|
||||||
|
|
||||||
|
method.__name__ = method_name
|
||||||
|
bound_method = method.__get__(proxy, proxy.__class__)
|
||||||
|
setattr(proxy, method_name, bound_method)
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_tensor_methods(original_methods: Dict[str, Callable[..., Any]]):
|
||||||
|
"""Helper function that resets the monkey patched torch.Tensor methods to their original values."""
|
||||||
|
for name, method in original_methods.items():
|
||||||
|
setattr(torch.Tensor, name, method)
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
|
||||||
|
if forbidden_values is None:
|
||||||
|
forbidden_values = []
|
||||||
|
value = random.randint(low, high)
|
||||||
|
while value in forbidden_values:
|
||||||
|
value = random.randint(low, high)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class HFTracer(Tracer):
|
||||||
|
"""
|
||||||
|
Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
|
||||||
|
regular PyTorch torch.fx.Proxy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_METHODS_TO_RECORD = {"__bool__": False, "size": True, "dim": False}
|
||||||
|
from transformers import modeling_utils
|
||||||
|
|
||||||
|
_FUNCTIONS_TO_AUTOWRAP = {
|
||||||
|
torch: {"arange", "zeros", "ones", "full_like", "eye"},
|
||||||
|
modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
|
||||||
|
|
||||||
|
# Loading the leaf functions register
|
||||||
|
self._leaf_functions_register = {}
|
||||||
|
for module, names in self._FUNCTIONS_TO_AUTOWRAP.items():
|
||||||
|
for name in names:
|
||||||
|
self._register_leaf_function(module, name)
|
||||||
|
|
||||||
|
# TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer.
|
||||||
|
# autowrap_functions = autowrap_functions + tuple(
|
||||||
|
# patched for (_, _, patched) in self._leaf_functions_register.values()
|
||||||
|
# )
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_torch_fx_available():
|
||||||
|
torch_version = version.parse(importlib_metadata.version("torch"))
|
||||||
|
raise ImportError(
|
||||||
|
f"Found an incompatible version of torch. Found version {torch_version}, but only version "
|
||||||
|
f"{TORCH_FX_REQUIRED_VERSION} is supported."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.prev_module = None
|
||||||
|
self.recorded_methods = None
|
||||||
|
|
||||||
|
def _register_leaf_function(self, module: ModuleType, name: str):
|
||||||
|
"""Registers the function called name in module as a leaf function."""
|
||||||
|
orig_func = getattr(module, name)
|
||||||
|
patched_func = _function_to_leaf(orig_func)
|
||||||
|
patched_func.__module__ = __name__
|
||||||
|
self._leaf_functions_register[name] = (module, orig_func, patched_func)
|
||||||
|
|
||||||
|
def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore: bool = False):
|
||||||
|
"""Patches leaf functions specifically for root."""
|
||||||
|
for name in self._leaf_functions_register:
|
||||||
|
module, orig_func, patched_func = self._leaf_functions_register[name]
|
||||||
|
if restore:
|
||||||
|
root.__class__.forward.__globals__.pop(name)
|
||||||
|
setattr(module, name, orig_func)
|
||||||
|
else:
|
||||||
|
root.__class__.forward.__globals__[name] = patched_func
|
||||||
|
leaf_getter = _function_leaf_getter(name, root.__class__.forward.__globals__)
|
||||||
|
leaf_getter.__module__ = __name__
|
||||||
|
setattr(module, name, leaf_getter)
|
||||||
|
|
||||||
|
def _method_is_called_in_leaf_module(self, module_ids: List[int]) -> bool:
|
||||||
|
"""
|
||||||
|
Finds out if the method (that is being recorded) is called inside a leaf module, this allows to not record
|
||||||
|
outputs that will not be encountered by the tracer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
currentframe = inspect.currentframe()
|
||||||
|
while currentframe:
|
||||||
|
if currentframe is None:
|
||||||
|
return False
|
||||||
|
module = currentframe.f_locals.get("self", None)
|
||||||
|
if id(module) in module_ids and self.is_leaf_module(module, "Not used anyway"):
|
||||||
|
return True
|
||||||
|
currentframe = currentframe.f_back
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _wrap_method_for_model_recording(
|
||||||
|
self, model: PreTrainedModel, method_name: str, cache_name: str, module_ids: List[int]
|
||||||
|
):
|
||||||
"""Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
|
"""Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
|
||||||
method = getattr(torch.Tensor, method_name)
|
method = getattr(torch.Tensor, method_name)
|
||||||
|
|
||||||
@functools.wraps(method)
|
@functools.wraps(method)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
|
if self._method_is_called_in_leaf_module(module_ids):
|
||||||
|
return method(*args, **kwargs)
|
||||||
if not hasattr(model, cache_name):
|
if not hasattr(model, cache_name):
|
||||||
setattr(model, cache_name, [])
|
setattr(model, cache_name, [])
|
||||||
cache = getattr(model, cache_name)
|
cache = getattr(model, cache_name)
|
||||||
@@ -161,50 +326,14 @@ def _wrap_method_for_model_recording(model, method_name, cache_name):
|
|||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedModel, method_names: Iterable[str]):
|
||||||
def _create_recorded_proxy_method(proxy, method_name, cache_name):
|
|
||||||
"""
|
"""
|
||||||
Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values
|
Helper function that patches torch.Tensor methods (specified by the method_names list) to record model
|
||||||
during symbolic tracing.
|
inference before symbolic tracing.
|
||||||
"""
|
"""
|
||||||
|
cache_names = {}
|
||||||
def method(self, *args, **kwargs):
|
original_methods = {}
|
||||||
cache = getattr(self.tracer.root, cache_name)
|
module_ids = set(id(mod) for mod in model.modules())
|
||||||
res = cache.pop(0)
|
|
||||||
return res
|
|
||||||
|
|
||||||
method.__name__ = method_name
|
|
||||||
bound_method = method.__get__(proxy, proxy.__class__)
|
|
||||||
setattr(proxy, method_name, bound_method)
|
|
||||||
|
|
||||||
|
|
||||||
def _wrap_method_for_model_tracing(model, method_name, cache_name):
|
|
||||||
"""
|
|
||||||
Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values
|
|
||||||
during symbolic tracing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
original_method = getattr(torch.Tensor, method_name)
|
|
||||||
|
|
||||||
@functools.wraps(original_method)
|
|
||||||
def method(*args, **kwargs):
|
|
||||||
cache = getattr(model, cache_name)
|
|
||||||
res = cache.pop(0)
|
|
||||||
return res
|
|
||||||
|
|
||||||
setattr(torch.Tensor, method_name, method)
|
|
||||||
|
|
||||||
if method_name == "size":
|
|
||||||
setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))
|
|
||||||
|
|
||||||
|
|
||||||
def _monkey_patch_tensor_methods_for_model_recording(model, method_names):
|
|
||||||
"""
|
|
||||||
Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference
|
|
||||||
before symbolic tracing.
|
|
||||||
"""
|
|
||||||
cache_names = dict()
|
|
||||||
original_methods = dict()
|
|
||||||
for method_name in method_names:
|
for method_name in method_names:
|
||||||
cache_name = f"cache_{method_name}"
|
cache_name = f"cache_{method_name}"
|
||||||
cache_names[method_name] = cache_name
|
cache_names[method_name] = cache_name
|
||||||
@@ -212,7 +341,11 @@ def _monkey_patch_tensor_methods_for_model_recording(model, method_names):
|
|||||||
logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
|
logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
|
||||||
continue
|
continue
|
||||||
original_methods[method_name] = getattr(torch.Tensor, method_name)
|
original_methods[method_name] = getattr(torch.Tensor, method_name)
|
||||||
setattr(torch.Tensor, method_name, _wrap_method_for_model_recording(model, method_name, cache_name))
|
setattr(
|
||||||
|
torch.Tensor,
|
||||||
|
method_name,
|
||||||
|
self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids),
|
||||||
|
)
|
||||||
|
|
||||||
if method_name == "size":
|
if method_name == "size":
|
||||||
original_methods["shape"] = torch.Tensor.shape
|
original_methods["shape"] = torch.Tensor.shape
|
||||||
@@ -220,65 +353,22 @@ def _monkey_patch_tensor_methods_for_model_recording(model, method_names):
|
|||||||
|
|
||||||
return cache_names, original_methods
|
return cache_names, original_methods
|
||||||
|
|
||||||
|
def _generate_dummy_input(
|
||||||
def _reset_tensor_methods(original_methods):
|
self, model: PreTrainedModel, input_name: str, shape: List[int]
|
||||||
"""Helper function that resets the monkey patched torch.Tensor methods to their original values."""
|
) -> Dict[str, torch.Tensor]:
|
||||||
for name, method in original_methods.items():
|
|
||||||
setattr(torch.Tensor, name, method)
|
|
||||||
|
|
||||||
|
|
||||||
class HFTracer(Tracer):
|
|
||||||
"""
|
|
||||||
Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
|
|
||||||
regular PyTorch torch.fx.Proxy.
|
|
||||||
"""
|
|
||||||
|
|
||||||
default_methods_to_record = {"__bool__", "size", "dim"}
|
|
||||||
|
|
||||||
def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if not is_torch_fx_available():
|
|
||||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
|
||||||
raise ImportError(
|
|
||||||
f"Found an incompatible version of torch. Found version {torch_version}, but only version "
|
|
||||||
f"{TORCH_FX_REQUIRED_VERSION} is supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length
|
|
||||||
decoder_sequence_length = (
|
|
||||||
sequence_length[1] if isinstance(sequence_length, (list, tuple)) else encoder_sequence_length
|
|
||||||
)
|
|
||||||
self.encoder_shape = [batch_size, encoder_sequence_length]
|
|
||||||
self.decoder_shape = (
|
|
||||||
[batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape)
|
|
||||||
)
|
|
||||||
self.num_choices = num_choices
|
|
||||||
if self.num_choices > 0:
|
|
||||||
self.encoder_shape = [batch_size, self.num_choices, encoder_sequence_length]
|
|
||||||
self.decoder_shape = [batch_size, self.num_choices, decoder_sequence_length]
|
|
||||||
|
|
||||||
self.prev_module = None
|
|
||||||
self.recorded_methods = None
|
|
||||||
|
|
||||||
def proxy(self, node: Node):
|
|
||||||
p = HFProxy(node, self)
|
|
||||||
if self.recorded_methods:
|
|
||||||
for method_name, cache_name in self.recorded_methods.items():
|
|
||||||
_create_recorded_proxy_method(p, method_name, cache_name)
|
|
||||||
return p
|
|
||||||
|
|
||||||
def _generate_dummy_input(self, model, input_name):
|
|
||||||
"""Generates dummy input for model inference recording."""
|
"""Generates dummy input for model inference recording."""
|
||||||
model_class = model.__class__
|
model_class = model.__class__
|
||||||
device = model.device
|
device = model.device
|
||||||
inputs_dict = dict()
|
inputs_dict = {}
|
||||||
|
|
||||||
if input_name in ["labels", "start_positions", "end_positions"]:
|
if input_name in ["labels", "start_positions", "end_positions"]:
|
||||||
batch_size = self.encoder_shape[0]
|
batch_size = shape[0]
|
||||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||||
inputs_dict["labels"] = torch.ones(batch_size, dtype=torch.long, device=device)
|
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
elif model_class in [
|
||||||
|
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
|
||||||
|
XLNetForQuestionAnswering,
|
||||||
|
]:
|
||||||
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
elif model_class in [
|
elif model_class in [
|
||||||
@@ -288,59 +378,56 @@ class HFTracer(Tracer):
|
|||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
elif model_class in [
|
elif model_class in [
|
||||||
|
*get_values(MODEL_FOR_PRETRAINING_MAPPING),
|
||||||
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
|
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||||
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
|
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
|
||||||
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||||
GPT2DoubleHeadsModel,
|
GPT2DoubleHeadsModel,
|
||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = torch.zeros(self.decoder_shape, dtype=torch.long, device=device)
|
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||||
elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
|
||||||
inputs_dict["labels"] = torch.zeros(self.encoder_shape, dtype=torch.long, device=device)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{model_class} not supported yet.")
|
raise NotImplementedError(f"{model_class} not supported yet.")
|
||||||
|
|
||||||
elif "mask" in input_name or "ids" in input_name:
|
elif "mask" in input_name or "ids" in input_name:
|
||||||
shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
|
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||||
inputs_dict[input_name] = torch.ones(shape, dtype=torch.long, device=device)
|
|
||||||
else:
|
else:
|
||||||
shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
|
shape_with_hidden_size = shape + [model.config.hidden_size]
|
||||||
shape += [model.config.hidden_size]
|
inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)
|
||||||
inputs_dict[input_name] = torch.ones(shape, dtype=torch.float, device=device)
|
|
||||||
|
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
def record(self, model, input_names, method_names=None):
|
def record(self, model: PreTrainedModel, input_names: List[str], method_names: Optional[Iterable[str]] = None):
|
||||||
"""
|
"""
|
||||||
Records torch.Tensor method outputs (specified by the method_names list) that will then be used during symbolic
|
Records torch.Tensor method outputs (specified by method_names) that will then be used during symbolic tracing.
|
||||||
tracing.
|
|
||||||
"""
|
"""
|
||||||
if method_names is None:
|
if method_names is None:
|
||||||
method_names = self.default_methods_to_record
|
method_names = self._DEFAULT_METHODS_TO_RECORD
|
||||||
|
|
||||||
|
# Creating a random input shape to generate dummy inputs.
|
||||||
|
batch_size = _generate_random_int()
|
||||||
|
sequence_length = _generate_random_int()
|
||||||
|
shape = [batch_size, sequence_length]
|
||||||
|
|
||||||
|
if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||||
|
num_choices = _generate_random_int(low=2, high=5)
|
||||||
|
shape.insert(1, num_choices)
|
||||||
|
|
||||||
inputs = {}
|
inputs = {}
|
||||||
for input_name in input_names:
|
for input_name in input_names:
|
||||||
inputs.update(self._generate_dummy_input(model, input_name))
|
inputs.update(self._generate_dummy_input(model, input_name, shape))
|
||||||
|
|
||||||
clone = copy.deepcopy(model)
|
cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names)
|
||||||
cache_names, original_methods = _monkey_patch_tensor_methods_for_model_recording(clone, method_names)
|
|
||||||
self.original_methods = original_methods
|
self.original_methods = original_methods
|
||||||
|
|
||||||
clone(**inputs)
|
model(**inputs)
|
||||||
|
|
||||||
# Useful because sometime the config is changed at inference time, for instance for
|
|
||||||
# classification tasks where config.problem_type can be set.
|
|
||||||
model.config = clone.config
|
|
||||||
|
|
||||||
_reset_tensor_methods(original_methods)
|
_reset_tensor_methods(original_methods)
|
||||||
|
|
||||||
self.recorded_methods = {
|
self.recorded_methods = {
|
||||||
method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(clone, cache_name)
|
method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name)
|
||||||
}
|
}
|
||||||
|
|
||||||
for cache_name in self.recorded_methods.values():
|
|
||||||
setattr(model, cache_name, getattr(clone, cache_name))
|
|
||||||
|
|
||||||
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
|
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
|
||||||
if isinstance(attr_val, torch.nn.Parameter):
|
if isinstance(attr_val, torch.nn.Parameter):
|
||||||
for n, p in self.root.named_parameters():
|
for n, p in self.root.named_parameters():
|
||||||
@@ -357,7 +444,20 @@ class HFTracer(Tracer):
|
|||||||
return parameter_proxy_cache[n]
|
return parameter_proxy_cache[n]
|
||||||
return attr_val
|
return attr_val
|
||||||
|
|
||||||
def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph:
|
def proxy(self, node: Node):
|
||||||
|
p = HFProxy(node, self)
|
||||||
|
if self.recorded_methods:
|
||||||
|
for method_name, cache_name in self.recorded_methods.items():
|
||||||
|
return_proxy = self._DEFAULT_METHODS_TO_RECORD[method_name]
|
||||||
|
_create_recorded_proxy_method(p, method_name, cache_name, return_proxy)
|
||||||
|
return p
|
||||||
|
|
||||||
|
def trace(
|
||||||
|
self,
|
||||||
|
root: PreTrainedModel,
|
||||||
|
concrete_args: Optional[Dict[str, Any]] = None,
|
||||||
|
method_names: Optional[Iterable[str]] = None,
|
||||||
|
) -> Graph:
|
||||||
if concrete_args is None:
|
if concrete_args is None:
|
||||||
concrete_args = {}
|
concrete_args = {}
|
||||||
|
|
||||||
@@ -366,11 +466,16 @@ class HFTracer(Tracer):
|
|||||||
|
|
||||||
self.record(root, input_names, method_names=method_names)
|
self.record(root, input_names, method_names=method_names)
|
||||||
|
|
||||||
for method_name, cache_name in self.recorded_methods.items():
|
# TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer.
|
||||||
_wrap_method_for_model_tracing(root, method_name, cache_name)
|
autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()]
|
||||||
|
self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
|
||||||
|
|
||||||
|
self._patch_leaf_functions_for_root(root)
|
||||||
|
|
||||||
graph = super().trace(root, concrete_args=concrete_args)
|
graph = super().trace(root, concrete_args=concrete_args)
|
||||||
|
|
||||||
|
self._patch_leaf_functions_for_root(root, restore=True)
|
||||||
|
|
||||||
_reset_tensor_methods(self.original_methods)
|
_reset_tensor_methods(self.original_methods)
|
||||||
|
|
||||||
# TODO: keep this until necessary.
|
# TODO: keep this until necessary.
|
||||||
@@ -388,7 +493,7 @@ class HFTracer(Tracer):
|
|||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
def _insert_module_as_submodule(self, mod):
|
def _insert_module_as_submodule(self, mod: nn.Module) -> str:
|
||||||
"""
|
"""
|
||||||
Helper method which tries to insert a module that was not declared as submodule.
|
Helper method which tries to insert a module that was not declared as submodule.
|
||||||
"""
|
"""
|
||||||
@@ -434,72 +539,19 @@ class HFTracer(Tracer):
|
|||||||
self.prev_module = path
|
self.prev_module = path
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
|
||||||
|
is_loss_module = m.__module__.startswith("torch.nn.modules.loss")
|
||||||
|
return (not is_loss_module) and super().is_leaf_module(m, module_qualified_name)
|
||||||
|
|
||||||
def create_arg(self, a: Any) -> Argument:
|
def create_arg(self, a: Any) -> Argument:
|
||||||
if isinstance(a, range):
|
if isinstance(a, range):
|
||||||
return super().create_arg(list(a))
|
return super().create_arg(list(a))
|
||||||
return super().create_arg(a)
|
return super().create_arg(a)
|
||||||
|
|
||||||
|
|
||||||
@transformation
|
|
||||||
def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Prepares a GraphModule produced by symbolic_trace for retracing by:
|
|
||||||
|
|
||||||
- Caching all the attributes specific to the way the model was initially traced
|
|
||||||
- Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes
|
|
||||||
For instance, the need to retrace a GraphModule can happen when applying quantization.
|
|
||||||
"""
|
|
||||||
attributes = _cache_attributes(gm)
|
|
||||||
_patch_arguments_(gm, gm.dynamic2static)
|
|
||||||
|
|
||||||
return gm, attributes
|
|
||||||
|
|
||||||
|
|
||||||
def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]):
|
|
||||||
"""Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes."""
|
|
||||||
_restore_attributes_(gm, attributes)
|
|
||||||
# transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries which is the desired
|
|
||||||
# behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values.
|
|
||||||
transform_to_dynamic_input_(gm, is_retracing=True)
|
|
||||||
_patch_arguments_(gm, gm.static2dynamic)
|
|
||||||
return gm
|
|
||||||
|
|
||||||
|
|
||||||
def retrace_graph_with(
|
|
||||||
gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None
|
|
||||||
) -> GraphModule:
|
|
||||||
"""
|
|
||||||
Retraces a GraphModule by either using a tracer or a function using a tracer (for instance
|
|
||||||
torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and
|
|
||||||
restoring anything necessary after the retrace.
|
|
||||||
"""
|
|
||||||
if tracer is None and func is None:
|
|
||||||
raise ValueError("Either a tracer or a function using a tracer must be provided.")
|
|
||||||
elif tracer is not None and func is not None:
|
|
||||||
raise ValueError("Either provide a tracer or a function using a tracer, but not both.")
|
|
||||||
else:
|
|
||||||
gm, attributes = prepare_for_retracing(gm)
|
|
||||||
tracing_func = tracer.trace if tracer else func
|
|
||||||
traced = tracing_func(gm)
|
|
||||||
restore_after_retracing_(traced, attributes)
|
|
||||||
return traced
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
|
|
||||||
if forbidden_values is None:
|
|
||||||
forbidden_values = []
|
|
||||||
value = random.randint(low, high)
|
|
||||||
while value in forbidden_values:
|
|
||||||
value = random.randint(low, high)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def symbolic_trace(
|
def symbolic_trace(
|
||||||
model: PreTrainedModel,
|
model: PreTrainedModel,
|
||||||
input_names: Optional[List[str]] = None,
|
input_names: Optional[List[str]] = None,
|
||||||
batch_size: int = 1,
|
|
||||||
sequence_length: Union[int, List[int], Tuple[int]] = (128, 128),
|
|
||||||
num_choices: int = -1,
|
|
||||||
) -> GraphModule:
|
) -> GraphModule:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -510,14 +562,6 @@ def symbolic_trace(
|
|||||||
The model to trace.
|
The model to trace.
|
||||||
input_names (`List[str]`, *optional*):
|
input_names (`List[str]`, *optional*):
|
||||||
The names of the inputs of the traced model. If unset, model.dummy_inputs().keys() are used instead.
|
The names of the inputs of the traced model. If unset, model.dummy_inputs().keys() are used instead.
|
||||||
batch_size (`int`, *optional*, defaults to 1):
|
|
||||||
The batch size of the traced model inputs.
|
|
||||||
sequence_length (`int` or `List[int]]`):
|
|
||||||
The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence
|
|
||||||
lengths between the encoder and the decoder inputs, this must be `[encoder_sequence_length,
|
|
||||||
decoder_sequence_length]`.
|
|
||||||
num_choices (`int`, *optional*, defaults to -1):
|
|
||||||
The number of possible choices for a multiple choice task.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
|
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
|
||||||
@@ -527,72 +571,24 @@ def symbolic_trace(
|
|||||||
```python
|
```python
|
||||||
from transformers.utils.fx import symbolic_trace
|
from transformers.utils.fx import symbolic_trace
|
||||||
|
|
||||||
traced_model = symbolic_trace(
|
traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
|
||||||
model,
|
```
|
||||||
input_names=["input_ids", "attention_mask", "token_type_ids"],
|
"""
|
||||||
batch_size=1,
|
|
||||||
sequence_length=128,
|
|
||||||
)
|
|
||||||
```"""
|
|
||||||
if input_names is None:
|
if input_names is None:
|
||||||
input_names = model.dummy_inputs.keys()
|
input_names = model.dummy_inputs.keys()
|
||||||
|
|
||||||
sig = inspect.signature(model.forward)
|
sig = inspect.signature(model.forward)
|
||||||
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
|
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
|
||||||
|
|
||||||
# Preparing HFTracer batch_size and sequence_lenght values for potential dynamic axes.
|
|
||||||
use_dynamic_batch_size = batch_size <= 0
|
|
||||||
if isinstance(sequence_length, (list, tuple)):
|
|
||||||
use_dynamic_sequence_length = sequence_length[0] <= 0 or sequence_length[1] <= 0
|
|
||||||
else:
|
|
||||||
use_dynamic_sequence_length = sequence_length <= 0
|
|
||||||
|
|
||||||
if use_dynamic_batch_size or use_dynamic_sequence_length:
|
|
||||||
forbidden_values = [
|
|
||||||
model.config.num_attention_heads,
|
|
||||||
model.config.hidden_size,
|
|
||||||
model.config.hidden_size // model.config.num_attention_heads,
|
|
||||||
]
|
|
||||||
if use_dynamic_batch_size:
|
|
||||||
batch_size = _generate_random_int(forbidden_values=forbidden_values)
|
|
||||||
forbidden_values.append(batch_size)
|
|
||||||
if use_dynamic_sequence_length:
|
|
||||||
encoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
|
|
||||||
forbidden_values.append(encoder_sequence_length)
|
|
||||||
decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
|
|
||||||
sequence_length = [encoder_sequence_length, decoder_sequence_length]
|
|
||||||
|
|
||||||
if not isinstance(model, _SUPPORTED_MODELS):
|
if not isinstance(model, _SUPPORTED_MODELS):
|
||||||
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
|
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
|
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
|
||||||
)
|
)
|
||||||
if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance(
|
|
||||||
model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES
|
|
||||||
):
|
|
||||||
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES))
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Tracing.
|
# Tracing.
|
||||||
tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)
|
tracer = HFTracer()
|
||||||
|
|
||||||
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
||||||
traced = torch.fx.GraphModule(model, traced_graph)
|
traced = torch.fx.GraphModule(model, traced_graph)
|
||||||
|
|
||||||
traced.config = copy.deepcopy(model.config)
|
|
||||||
traced.num_choices = num_choices
|
|
||||||
traced.dummy_inputs = {}
|
|
||||||
|
|
||||||
for name in input_names:
|
|
||||||
traced.dummy_inputs.update(tracer._generate_dummy_input(model, name))
|
|
||||||
|
|
||||||
traced.use_dynamic_batch_size = use_dynamic_batch_size
|
|
||||||
traced.use_dynamic_sequence_length = use_dynamic_sequence_length
|
|
||||||
traced.static_batch_size = batch_size
|
|
||||||
traced.static_sequence_length = sequence_length
|
|
||||||
|
|
||||||
transform_to_dynamic_input_(traced)
|
|
||||||
|
|
||||||
return traced
|
return traced
|
||||||
|
|||||||
@@ -1,321 +0,0 @@
|
|||||||
import copy
|
|
||||||
import functools
|
|
||||||
import operator
|
|
||||||
from inspect import signature
|
|
||||||
from typing import Any, Callable, Dict, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.fx import Graph, GraphModule, Node
|
|
||||||
|
|
||||||
|
|
||||||
# Torch FX transformation convention:
|
|
||||||
# - transformations that are supposed to act on a copy of the original GraphModule are decorated with @transformation
|
|
||||||
# - transformations that are inplace have a name ending with "_"
|
|
||||||
|
|
||||||
|
|
||||||
def _cache_attributes(gm: GraphModule) -> Dict[str, Any]:
|
|
||||||
attributes_to_keep = [
|
|
||||||
"config",
|
|
||||||
"num_choices",
|
|
||||||
"dummy_inputs",
|
|
||||||
"use_dynamic_batch_size",
|
|
||||||
"use_dynamic_sequence_length",
|
|
||||||
"static_batch_size",
|
|
||||||
"static_sequence_length",
|
|
||||||
"static2dynamic",
|
|
||||||
"dynamic2static",
|
|
||||||
]
|
|
||||||
attributes = {k: getattr(gm, k, None) for k in attributes_to_keep}
|
|
||||||
return attributes
|
|
||||||
|
|
||||||
|
|
||||||
def _restore_attributes_(gm: GraphModule, attributes: Dict[str, Any]):
|
|
||||||
for name, attr in attributes.items():
|
|
||||||
setattr(gm, name, attr)
|
|
||||||
|
|
||||||
|
|
||||||
def deepcopy_graph(gm: GraphModule) -> GraphModule:
|
|
||||||
"""
|
|
||||||
Performs a deepcopy of the GraphModule while also copying the relevant attributes to know whether the model was
|
|
||||||
traced with dynamic axes, and what were the values if that is the case.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# First, create a copy of the module without the graph.
|
|
||||||
graph = gm.__dict__.pop("_graph")
|
|
||||||
fake_mod = torch.nn.Module()
|
|
||||||
fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
|
|
||||||
gm.__dict__["_graph"] = graph
|
|
||||||
|
|
||||||
# Then, copy the graph.
|
|
||||||
val_map = {}
|
|
||||||
graph_clone = Graph()
|
|
||||||
output_val = graph_clone.graph_copy(graph, val_map=val_map)
|
|
||||||
graph_clone.output(output_val)
|
|
||||||
|
|
||||||
# Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies.
|
|
||||||
# gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
|
|
||||||
clone = gm.__class__(fake_mod, graph_clone)
|
|
||||||
|
|
||||||
# Restore the dynamic axes related attributes to the clone.
|
|
||||||
attributes = _cache_attributes(gm)
|
|
||||||
attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
|
|
||||||
attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
|
|
||||||
_restore_attributes_(clone, attributes)
|
|
||||||
|
|
||||||
return clone
|
|
||||||
|
|
||||||
|
|
||||||
def transformation(func):
|
|
||||||
"""
|
|
||||||
Decorator that wraps a torch.fx transformation by feeding it a copy of the GraphModule to transform instead of the
|
|
||||||
original.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def map_fn(arg):
|
|
||||||
if isinstance(arg, GraphModule):
|
|
||||||
return deepcopy_graph(arg)
|
|
||||||
return arg
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
new_args = tuple(map_fn(arg) for arg in args)
|
|
||||||
new_kwargs = {k: map_fn(v) for k, v in kwargs.items()}
|
|
||||||
return func(*new_args, **new_kwargs)
|
|
||||||
|
|
||||||
wrapper._is_transformation = True
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def compose_transformations(
|
|
||||||
*args: Callable[[GraphModule], Optional[GraphModule]], inplace: bool = False
|
|
||||||
) -> GraphModule:
|
|
||||||
"""
|
|
||||||
Allows to compose transformations together and takes of:
|
|
||||||
|
|
||||||
1. Performing the transformations on a copy of the GraphModule if inplace is set to False, transformations that
|
|
||||||
are decorated with @transformation (which means that they are not modifying the original GraphModule) are
|
|
||||||
unwrapped to make them inplace.
|
|
||||||
2. Linting and recompiling only at the end of the composition for performance purposes.
|
|
||||||
"""
|
|
||||||
args = list(args)
|
|
||||||
if not inplace:
|
|
||||||
args.insert(0, deepcopy_graph)
|
|
||||||
|
|
||||||
for i, transformation in enumerate(args[:-1]):
|
|
||||||
sig = signature(transformation)
|
|
||||||
|
|
||||||
# Unwrapping @transformation decorated transformations as performing the transformations inplace or on a copy is
|
|
||||||
# already handled by this function.
|
|
||||||
if getattr(transformation, "_is_transformation", False):
|
|
||||||
transformation = transformation.__wrapped__
|
|
||||||
|
|
||||||
# Linting and recompiling only after the last transformation applied to make composition efficient.
|
|
||||||
if "lint_and_recompile" in sig.parameters:
|
|
||||||
args[i] = functools.partial(transformation, lint_and_recompile=False)
|
|
||||||
|
|
||||||
def reduce_func(f, g):
|
|
||||||
def compose_f_and_g(gm):
|
|
||||||
output_g = g(gm)
|
|
||||||
if output_g is None:
|
|
||||||
output_g = gm
|
|
||||||
output_f = f(output_g)
|
|
||||||
if output_f is None:
|
|
||||||
output_f = gm
|
|
||||||
return output_f
|
|
||||||
|
|
||||||
return compose_f_and_g
|
|
||||||
|
|
||||||
return functools.reduce(reduce_func, reversed(args), lambda x: x)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True):
|
|
||||||
"""Removes all the unused nodes in a GraphModule."""
|
|
||||||
graph = gm.graph
|
|
||||||
for node in graph.nodes:
|
|
||||||
if not node.users and node.op not in ["placeholder", "output"]:
|
|
||||||
graph.erase_node(node)
|
|
||||||
|
|
||||||
if lint_and_recompile:
|
|
||||||
graph.lint()
|
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
|
|
||||||
def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
|
|
||||||
"""Inserts a node that retrieves the batch size dynamically from the input of the model."""
|
|
||||||
graph = gm.graph
|
|
||||||
input_names = set(gm.dummy_inputs.keys())
|
|
||||||
batch_size_node = None
|
|
||||||
for node in graph.nodes:
|
|
||||||
if node.op == "placeholder" and node.name in input_names:
|
|
||||||
with graph.inserting_after(node):
|
|
||||||
batch_size_node = graph.call_method("size", args=(node, 0))
|
|
||||||
|
|
||||||
if batch_size_node is None:
|
|
||||||
raise ValueError("Could not insert the node that computes the batch size")
|
|
||||||
|
|
||||||
if lint_and_recompile:
|
|
||||||
graph.lint()
|
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
# Useful when retracing for quantization.
|
|
||||||
if hasattr(gm, "_qconfig_map"):
|
|
||||||
gm._qconfig_map[batch_size_node.name] = None
|
|
||||||
|
|
||||||
return batch_size_node
|
|
||||||
|
|
||||||
|
|
||||||
def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
|
|
||||||
"""Inserts a node that retrieves the encoder sequence length dynamically from the input of the model."""
|
|
||||||
graph = gm.graph
|
|
||||||
input_names = set(gm.dummy_inputs.keys())
|
|
||||||
encoder_sequence_length_node = None
|
|
||||||
for node in graph.nodes:
|
|
||||||
if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name:
|
|
||||||
with graph.inserting_after(node):
|
|
||||||
# There are two cases to handle:
|
|
||||||
# 1. num_choices < 0, meaning that the model is not performing a "multiple choice" task, in this case the
|
|
||||||
# input shapes is [batch_size, sequence_length] => index 1
|
|
||||||
# 2. num_choices > 0, meaning the model is performing a "multiple choice" task, in this case the input
|
|
||||||
# shape is [batch_size, num_choices, sequence_length] => index 2
|
|
||||||
encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2))
|
|
||||||
|
|
||||||
if encoder_sequence_length_node is None:
|
|
||||||
raise ValueError("Could not insert the node that computes the encoder sequence length")
|
|
||||||
|
|
||||||
if lint_and_recompile:
|
|
||||||
graph.lint()
|
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
# Useful when retracing for quantization.
|
|
||||||
if hasattr(gm, "_qconfig_map"):
|
|
||||||
gm._qconfig_map[encoder_sequence_length_node.name] = None
|
|
||||||
|
|
||||||
return encoder_sequence_length_node
|
|
||||||
|
|
||||||
|
|
||||||
def _change_view_methods_(
|
|
||||||
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the
|
|
||||||
batch_size / sequence_length nodes.
|
|
||||||
"""
|
|
||||||
graph = gm.graph
|
|
||||||
for node in graph.nodes:
|
|
||||||
if node.op == "call_method" and node.target == "view":
|
|
||||||
if isinstance(node.args[1], tuple):
|
|
||||||
node.args = (node.args[0], *node.args[1])
|
|
||||||
node.args = tuple((mapping.get(arg, arg) for arg in node.args))
|
|
||||||
|
|
||||||
if lint_and_recompile:
|
|
||||||
graph.lint()
|
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
|
|
||||||
def _patch_getitem_(
|
|
||||||
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
|
|
||||||
):
|
|
||||||
"""Patches getitem nodes by replacing current arguments to their corresponding values in mapping."""
|
|
||||||
# TODO: combine this with the patch_argument function which seems to do almost the same thing.
|
|
||||||
graph = gm.graph
|
|
||||||
for node in graph.nodes:
|
|
||||||
if node.op == "call_function" and node.target == operator.getitem:
|
|
||||||
indices = node.args[1]
|
|
||||||
if isinstance(indices, tuple):
|
|
||||||
new_indices = []
|
|
||||||
for idx in indices:
|
|
||||||
if isinstance(idx, slice):
|
|
||||||
new_indices.append(
|
|
||||||
slice(
|
|
||||||
mapping.get(idx.start, idx.start),
|
|
||||||
mapping.get(idx.stop, idx.stop),
|
|
||||||
mapping.get(idx.step, idx.step),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(idx, int):
|
|
||||||
new_indices.append(mapping.get(idx, idx))
|
|
||||||
else:
|
|
||||||
new_indices.append(idx)
|
|
||||||
|
|
||||||
node.args = (node.args[0], tuple(new_indices))
|
|
||||||
else:
|
|
||||||
node.args = (node.args[0], mapping.get(node.args[1], node.args[1]))
|
|
||||||
|
|
||||||
if lint_and_recompile:
|
|
||||||
graph.lint()
|
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
|
|
||||||
def _patch_arguments_(
|
|
||||||
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Patches node by replacing their argument to their corresponding values in mapping (supports regular types, tuples
|
|
||||||
and slices).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _patch_slice(s, mapping):
|
|
||||||
return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step))
|
|
||||||
|
|
||||||
graph = gm.graph
|
|
||||||
supported_types = (Node, str, int, float)
|
|
||||||
for node in graph.nodes:
|
|
||||||
new_args = []
|
|
||||||
for arg in node.args:
|
|
||||||
if isinstance(arg, tuple):
|
|
||||||
new_arg = []
|
|
||||||
for a in arg:
|
|
||||||
if isinstance(a, slice):
|
|
||||||
new_arg.append(_patch_slice(a, mapping))
|
|
||||||
else:
|
|
||||||
new_arg.append(mapping.get(a, a))
|
|
||||||
new_args.append(tuple(new_arg))
|
|
||||||
elif isinstance(arg, slice):
|
|
||||||
new_args.append(_patch_slice(arg, mapping))
|
|
||||||
elif isinstance(arg, supported_types):
|
|
||||||
new_args.append(mapping.get(arg, arg))
|
|
||||||
else:
|
|
||||||
new_args.append(arg)
|
|
||||||
node.args = tuple(new_args)
|
|
||||||
|
|
||||||
if lint_and_recompile:
|
|
||||||
graph.lint()
|
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
|
|
||||||
def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False):
|
|
||||||
"""Transformation that enables traced models to perform inference on dynamic input shapes."""
|
|
||||||
graph = gm.graph
|
|
||||||
static2dynamic = {}
|
|
||||||
|
|
||||||
# Inserting the nodes that will fetch the batch size and sequence lengths dynamically.
|
|
||||||
if gm.use_dynamic_batch_size:
|
|
||||||
batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False)
|
|
||||||
static2dynamic[gm.static_batch_size] = batch_size_node
|
|
||||||
if gm.num_choices > 0:
|
|
||||||
with graph.inserting_after(batch_size_node):
|
|
||||||
static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function(
|
|
||||||
operator.mul, args=(batch_size_node, gm.num_choices)
|
|
||||||
)
|
|
||||||
# Useful when retracing for quantization.
|
|
||||||
if hasattr(gm, "_qconfig_map"):
|
|
||||||
gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None
|
|
||||||
|
|
||||||
if gm.use_dynamic_sequence_length:
|
|
||||||
encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False)
|
|
||||||
static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node
|
|
||||||
|
|
||||||
# TODO: do the same for the decoder.
|
|
||||||
pass
|
|
||||||
|
|
||||||
_change_view_methods_(gm, static2dynamic, lint_and_recompile=False)
|
|
||||||
_patch_getitem_(gm, static2dynamic, lint_and_recompile=False)
|
|
||||||
|
|
||||||
remove_unused_nodes_(gm, lint_and_recompile=False)
|
|
||||||
|
|
||||||
graph.lint()
|
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
gm.static2dynamic = static2dynamic
|
|
||||||
gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()}
|
|
||||||
@@ -231,8 +231,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_compatible = True
|
||||||
fx_dynamic_ready_model_classes = all_model_classes
|
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
@@ -444,8 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_compatible = True
|
||||||
fx_dynamic_ready_model_classes = all_model_classes
|
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
@@ -116,8 +116,7 @@ class ModelTesterMixin:
|
|||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
all_generative_model_classes = ()
|
all_generative_model_classes = ()
|
||||||
fx_ready_model_classes = ()
|
fx_compatible = False
|
||||||
fx_dynamic_ready_model_classes = ()
|
|
||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
test_pruning = True
|
test_pruning = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
@@ -666,19 +665,14 @@ class ModelTesterMixin:
|
|||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
|
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
|
||||||
|
|
||||||
def test_torch_fx_dynamic_axes(self):
|
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
if not is_torch_fx_available() or not self.fx_compatible:
|
||||||
self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True)
|
|
||||||
|
|
||||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False):
|
|
||||||
if not is_torch_fx_available():
|
|
||||||
return
|
return
|
||||||
|
|
||||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||||
configs_no_init.return_dict = False
|
configs_no_init.return_dict = False
|
||||||
|
|
||||||
model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes
|
for model_class in self.all_model_classes:
|
||||||
for model_class in model_classes:
|
|
||||||
model = model_class(config=configs_no_init)
|
model = model_class(config=configs_no_init)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -687,8 +681,6 @@ class ModelTesterMixin:
|
|||||||
try:
|
try:
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
input_ids = inputs["input_ids"]
|
|
||||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
|
||||||
labels = inputs.get("labels", None)
|
labels = inputs.get("labels", None)
|
||||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
@@ -697,17 +689,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
batch_size = input_ids.shape[0]
|
traced_model = symbolic_trace(model, input_names)
|
||||||
encoder_sequence_length = input_ids.shape[1]
|
|
||||||
decoder_sequence_length = decoder_attention_mask.shape[1]
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(
|
|
||||||
model,
|
|
||||||
input_names,
|
|
||||||
batch_size=batch_size if not dynamic_axes else -1,
|
|
||||||
sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1,
|
|
||||||
)
|
|
||||||
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
else:
|
else:
|
||||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||||
@@ -729,23 +711,12 @@ class ModelTesterMixin:
|
|||||||
model_output = model(**filtered_inputs)
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
rank = len(input_ids.shape)
|
rank = len(input_ids.shape)
|
||||||
if rank == 2:
|
if rank not in [2, 3]:
|
||||||
batch_size, sequence_length = input_ids.shape
|
|
||||||
num_choices = -1
|
|
||||||
elif rank == 3:
|
|
||||||
batch_size, num_choices, sequence_length = input_ids.shape
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
|
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
|
||||||
)
|
)
|
||||||
|
|
||||||
traced_model = symbolic_trace(
|
traced_model = symbolic_trace(model, input_names)
|
||||||
model,
|
|
||||||
input_names,
|
|
||||||
batch_size=batch_size if not dynamic_axes else -1,
|
|
||||||
sequence_length=sequence_length if not dynamic_axes else -1,
|
|
||||||
num_choices=num_choices,
|
|
||||||
)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
|||||||
@@ -209,8 +209,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_compatible = True
|
||||||
fx_dynamic_ready_model_classes = all_model_classes
|
|
||||||
test_pruning = True
|
test_pruning = True
|
||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
|
|||||||
@@ -369,10 +369,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else ()
|
fx_compatible = True
|
||||||
|
|
||||||
fx_ready_model_classes = all_model_classes
|
|
||||||
fx_dynamic_ready_model_classes = all_model_classes
|
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
@@ -433,7 +433,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||||
all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_compatible = True
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
test_model_parallel = True
|
test_model_parallel = True
|
||||||
|
|
||||||
|
|||||||
@@ -372,7 +372,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
|||||||
(GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
|
(GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_compatible = True
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_model_parallel = False
|
test_model_parallel = False
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_compatible = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
test_model_parallel = False
|
test_model_parallel = False
|
||||||
|
|||||||
@@ -283,9 +283,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_compatible = True
|
||||||
fx_dynamic_ready_model_classes = all_model_classes
|
|
||||||
|
|
||||||
# test_resize_embeddings = False
|
# test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -269,8 +269,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_compatible = True
|
||||||
fx_dynamic_ready_model_classes = all_model_classes
|
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
@@ -356,6 +356,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
|
||||||
|
fx_compatible = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = RobertaModelTester(self)
|
self.model_tester = RobertaModelTester(self)
|
||||||
|
|||||||
@@ -509,7 +509,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_compatible = True
|
||||||
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
|
|||||||
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
|||||||
all_generative_model_classes = (
|
all_generative_model_classes = (
|
||||||
(XLNetLMHeadModel,) if is_torch_available() else ()
|
(XLNetLMHeadModel,) if is_torch_available() else ()
|
||||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
||||||
|
|||||||
Reference in New Issue
Block a user