A cleaner and more scalable implementation of symbolic tracing (#11763)
Cleaner and more scalable implementation of symbolic tracing with torch.fx, and provides support for new architectures: - ALBERT - DistilBERT - MobileBERT - MegatronBERT - GPT2 - GPT Neo Co-authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
@@ -1,11 +1,31 @@
|
||||
import dis
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node, Proxy, Tracer
|
||||
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
|
||||
from torch.fx.node import Argument
|
||||
|
||||
from . import PreTrainedModel
|
||||
from . import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
GPT2DoubleHeadsModel,
|
||||
PreTrainedModel,
|
||||
logging,
|
||||
)
|
||||
from .models.auto import get_values
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class HFProxy(Proxy):
|
||||
@@ -21,98 +41,10 @@ class HFProxy(Proxy):
|
||||
self.device = self.tracer.root.device
|
||||
self.dtype = next(self.tracer.root.parameters()).dtype
|
||||
|
||||
def dim(self):
|
||||
return len(self.tracer.encoder_shape)
|
||||
|
||||
def _shape(self, calling_frame):
|
||||
module = calling_frame.f_locals.get("self", None)
|
||||
is_decoder = hasattr(module, "is_decoder") and module.is_decoder
|
||||
return list(self.tracer.decoder_shape) if is_decoder else list(self.tracer.encoder_shape)
|
||||
|
||||
def size(self, dim=None):
|
||||
frame = inspect.currentframe()
|
||||
calling_frame = frame.f_back
|
||||
|
||||
# self.size can be called through the shape property, in which case we need to get the outer
|
||||
# frame, containing the meaningful information.
|
||||
if calling_frame.f_code.co_name == "shape":
|
||||
calling_frame = calling_frame.f_back
|
||||
|
||||
instructions = list(reversed(list(dis.get_instructions(calling_frame.f_code))[: calling_frame.f_lasti]))
|
||||
code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()
|
||||
|
||||
shape = self._shape(calling_frame)
|
||||
|
||||
if calling_frame.f_code.co_name == "transpose_for_scores":
|
||||
# Provides the proper "x.size()" for:
|
||||
# new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
shape = shape + [-1]
|
||||
elif "context_layer" in calling_frame.f_locals:
|
||||
# Provides the proper "context_layer.size()" for:
|
||||
# new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
shape = shape + [-1, -1]
|
||||
elif calling_frame.f_locals.get("do_cross_attention", False):
|
||||
# Provides the proper shape for:
|
||||
# query_length = present_key_value_state[0].shape[2]
|
||||
# (modeling_t5.py)
|
||||
shape = list(self.tracer.encoder_shape)
|
||||
shape = shape[:1] + [-1] + shape[1:2]
|
||||
elif "key_length" in code_context or "encoder_seq_length" in code_context:
|
||||
shape = list(self.tracer.encoder_shape)
|
||||
elif "lm_logits.size(-1)" in code_context:
|
||||
shape = [self.tracer.root.config.vocab_size]
|
||||
elif "start_positions" in code_context or "end_positions" in code_context:
|
||||
# For question answering tasks.
|
||||
shape = [1]
|
||||
elif "num_choices" in code_context:
|
||||
if self.tracer.num_choices <= 0:
|
||||
raise ValueError("num_choices must be given to the CustomTracer for MultipleChoice tasks.")
|
||||
shape = shape[:1] + [self.tracer.num_choices] + shape[1:]
|
||||
elif "hidden_states.s" in code_context:
|
||||
shape = shape + [self.tracer.root.config.hidden_size]
|
||||
else:
|
||||
# Default case:
|
||||
# - If self.size is called for an unpacking, retrieves the corresponding unpacking
|
||||
# instruction, and returns the shape padded as much as necessary to match the expected
|
||||
# number of items.
|
||||
# - If self.size is called outside of an unpacking context, simply return the shape.
|
||||
is_unpack = False
|
||||
|
||||
for inst in instructions:
|
||||
if inst.opname == "UNPACK_SEQUENCE":
|
||||
is_unpack = True
|
||||
break
|
||||
|
||||
if is_unpack and inst.argval >= 3:
|
||||
shape += [self.tracer.root.config.hidden_size]
|
||||
dummy_values = [1] * (inst.argval - 3)
|
||||
shape += dummy_values
|
||||
|
||||
if dim is not None:
|
||||
return shape[dim]
|
||||
|
||||
return tuple(shape)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.size()
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
frame = inspect.currentframe()
|
||||
calling_frame = frame.f_back
|
||||
code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()
|
||||
if calling_frame.f_code.co_name == "apply_chunking_to_forward":
|
||||
# Returning True to every assertion in "apply_chuncking_to_forward"
|
||||
return True
|
||||
elif "assert" in code_context:
|
||||
# Returning True to any assertion.
|
||||
return True
|
||||
elif calling_frame.f_code.co_name == "get_extended_attention_mask":
|
||||
# Corresponding to:
|
||||
# if causal_mask.shape[1] < attention_mask.shape[1]:
|
||||
return calling_frame.f_back.f_locals["past_key_values"][0] is not None
|
||||
raise NotImplementedError("__bool__ was called for CustomProxy, but this case is not covered yet.")
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
pass
|
||||
|
||||
@@ -120,28 +52,203 @@ class HFProxy(Proxy):
|
||||
return False
|
||||
|
||||
|
||||
def _wrap_method_for_model_recording(model, method_name, cache_name):
|
||||
"""Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
|
||||
method = getattr(torch.Tensor, method_name)
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapped(*args, **kwargs):
|
||||
if not hasattr(model, cache_name):
|
||||
setattr(model, cache_name, [])
|
||||
cache = getattr(model, cache_name)
|
||||
res = method(*args, **kwargs)
|
||||
cache.append(res)
|
||||
return res
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _create_recorded_proxy_method(proxy, method_name, cache_name):
|
||||
"""
|
||||
Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values
|
||||
during symbolic tracing.
|
||||
"""
|
||||
|
||||
def method(self, *args, **kwargs):
|
||||
cache = getattr(self.tracer.root, cache_name)
|
||||
res = cache.pop(0)
|
||||
return res
|
||||
|
||||
method.__name__ = method_name
|
||||
bound_method = method.__get__(proxy, proxy.__class__)
|
||||
setattr(proxy, method_name, bound_method)
|
||||
|
||||
|
||||
def _wrap_method_for_model_tracing(model, method_name, cache_name):
|
||||
"""
|
||||
Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values
|
||||
during symbolic tracing.
|
||||
"""
|
||||
|
||||
original_method = getattr(torch.Tensor, method_name)
|
||||
|
||||
@functools.wraps(original_method)
|
||||
def method(*args, **kwargs):
|
||||
cache = getattr(model, cache_name)
|
||||
res = cache.pop(0)
|
||||
return res
|
||||
|
||||
setattr(torch.Tensor, method_name, method)
|
||||
|
||||
if method_name == "size":
|
||||
setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))
|
||||
|
||||
|
||||
def _monkey_patch_tensor_methods_for_model_recording(model, method_names):
|
||||
"""
|
||||
Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference
|
||||
before symbolic tracing.
|
||||
"""
|
||||
cache_names = dict()
|
||||
original_methods = dict()
|
||||
for method_name in method_names:
|
||||
cache_name = f"cache_{method_name}"
|
||||
cache_names[method_name] = cache_name
|
||||
if not hasattr(torch.Tensor, method_name):
|
||||
logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
|
||||
continue
|
||||
original_methods[method_name] = getattr(torch.Tensor, method_name)
|
||||
setattr(torch.Tensor, method_name, _wrap_method_for_model_recording(model, method_name, cache_name))
|
||||
|
||||
if method_name == "size":
|
||||
original_methods["shape"] = torch.Tensor.shape
|
||||
setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))
|
||||
|
||||
return cache_names, original_methods
|
||||
|
||||
|
||||
def _reset_tensor_methods(original_methods):
|
||||
"""Helper function that resets the monkey patched torch.Tensor methods to their original values."""
|
||||
for name, method in original_methods.items():
|
||||
setattr(torch.Tensor, name, method)
|
||||
|
||||
|
||||
class HFTracer(Tracer):
|
||||
"""
|
||||
Tracer that is able to symbolically trace models from the library (currently BERT, ELECTRA and T5). To do that, it
|
||||
uses the HFProxy instead of the regular PyTorch torch.fx.Proxy.
|
||||
Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
|
||||
regular PyTorch torch.fx.Proxy.
|
||||
"""
|
||||
|
||||
default_methods_to_record = {"__bool__", "size", "dim"}
|
||||
|
||||
def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1):
|
||||
super().__init__()
|
||||
encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length
|
||||
decoder_sequence_length = sequence_length[1] if isinstance(sequence_length, (list, tuple)) else -1
|
||||
decoder_sequence_length = (
|
||||
sequence_length[1] if isinstance(sequence_length, (list, tuple)) else encoder_sequence_length
|
||||
)
|
||||
self.encoder_shape = [batch_size, encoder_sequence_length]
|
||||
self.decoder_shape = (
|
||||
[batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape)
|
||||
)
|
||||
self.num_choices = num_choices
|
||||
if self.num_choices > 0:
|
||||
self.encoder_shape[0] *= self.num_choices
|
||||
self.encoder_shape = [batch_size, self.num_choices, encoder_sequence_length]
|
||||
self.decoder_shape = [batch_size, self.num_choices, decoder_sequence_length]
|
||||
|
||||
self.prev_module = None
|
||||
self.recorded_methods = None
|
||||
|
||||
def proxy(self, node: Node):
|
||||
return HFProxy(node, self)
|
||||
p = HFProxy(node, self)
|
||||
if self.recorded_methods:
|
||||
for method_name, cache_name in self.recorded_methods.items():
|
||||
_create_recorded_proxy_method(p, method_name, cache_name)
|
||||
return p
|
||||
|
||||
def _generate_dummy_input(self, model, input_name):
|
||||
"""Generates dummy input for model inference recording."""
|
||||
model_class = model.__class__
|
||||
device = model.device
|
||||
inputs_dict = dict()
|
||||
|
||||
if input_name in ["labels", "start_positions", "end_positions"]:
|
||||
batch_size = self.encoder_shape[0]
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict["labels"] = torch.ones(batch_size, dtype=torch.long, device=device)
|
||||
elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
elif model_class in [
|
||||
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
|
||||
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
|
||||
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
elif model_class in [
|
||||
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
|
||||
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||
GPT2DoubleHeadsModel,
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(self.decoder_shape, dtype=torch.long, device=device)
|
||||
elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(self.encoder_shape, dtype=torch.long, device=device)
|
||||
else:
|
||||
raise NotImplementedError(f"{model_class} not supported yet.")
|
||||
|
||||
elif "mask" in input_name or "ids" in input_name:
|
||||
shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
|
||||
inputs_dict[input_name] = torch.ones(shape, dtype=torch.long, device=device)
|
||||
else:
|
||||
shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
|
||||
shape += [model.config.hidden_size]
|
||||
inputs_dict[input_name] = torch.ones(shape, dtype=torch.float, device=device)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def record(self, model, input_names, method_names=None):
|
||||
"""
|
||||
Records torch.Tensor method outputs (specified by the method_names list) that will then be used during symbolic
|
||||
tracing.
|
||||
"""
|
||||
if method_names is None:
|
||||
method_names = self.default_methods_to_record
|
||||
|
||||
inputs = dict()
|
||||
for input_name in input_names:
|
||||
inputs.update(self._generate_dummy_input(model, input_name))
|
||||
|
||||
clone = copy.deepcopy(model)
|
||||
cache_names, original_methods = _monkey_patch_tensor_methods_for_model_recording(clone, method_names)
|
||||
self.original_methods = original_methods
|
||||
|
||||
clone(**inputs)
|
||||
|
||||
_reset_tensor_methods(original_methods)
|
||||
|
||||
self.recorded_methods = {
|
||||
method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(clone, cache_name)
|
||||
}
|
||||
|
||||
for cache_name in self.recorded_methods.values():
|
||||
setattr(model, cache_name, getattr(clone, cache_name))
|
||||
|
||||
def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph:
|
||||
sig = inspect.signature(root.forward)
|
||||
input_names = sig.parameters.keys() - concrete_args.keys()
|
||||
|
||||
self.record(root, input_names, method_names=method_names)
|
||||
|
||||
for method_name, cache_name in self.recorded_methods.items():
|
||||
_wrap_method_for_model_tracing(root, method_name, cache_name)
|
||||
|
||||
graph = super().trace(root, concrete_args=concrete_args)
|
||||
|
||||
_reset_tensor_methods(self.original_methods)
|
||||
|
||||
return graph
|
||||
|
||||
def _insert_module_as_submodule(self, mod):
|
||||
"""
|
||||
@@ -202,6 +309,11 @@ class HFTracer(Tracer):
|
||||
self.prev_module = path
|
||||
return path
|
||||
|
||||
def create_arg(self, a: Any) -> Argument:
|
||||
if isinstance(a, range):
|
||||
return super().create_arg(list(a))
|
||||
return super().create_arg(a)
|
||||
|
||||
|
||||
def symbolic_trace(
|
||||
model: PreTrainedModel,
|
||||
@@ -249,6 +361,7 @@ def symbolic_trace(
|
||||
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
|
||||
|
||||
tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)
|
||||
|
||||
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
||||
traced = torch.fx.GraphModule(model, traced_graph)
|
||||
|
||||
|
||||
@@ -229,6 +229,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
|
||||
@@ -600,9 +600,9 @@ class ModelTesterMixin:
|
||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
|
||||
model_output = model(**prepared_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
encoder_sequence_length = input_ids.shape[1]
|
||||
@@ -615,26 +615,37 @@ class ModelTesterMixin:
|
||||
sequence_length=[encoder_sequence_length, decoder_sequence_length],
|
||||
)
|
||||
|
||||
traced_output = traced_model(**prepared_inputs)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
else:
|
||||
input_ids = inputs["input_ids"]
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||
input_ids = inputs["input_ids"]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
if start_positions is not None:
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
model_output = model(**prepared_inputs)
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = filtered_inputs.keys()
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
sequence_length = input_ids.shape[2]
|
||||
num_choices = input_ids.shape[1]
|
||||
else:
|
||||
sequence_length = input_ids.shape[1]
|
||||
rank = len(input_ids.shape)
|
||||
if rank == 2:
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
num_choices = -1
|
||||
elif rank == 3:
|
||||
batch_size, num_choices, sequence_length = input_ids.shape
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
|
||||
)
|
||||
|
||||
traced_model = symbolic_trace(
|
||||
model,
|
||||
@@ -643,14 +654,31 @@ class ModelTesterMixin:
|
||||
sequence_length=sequence_length,
|
||||
num_choices=num_choices,
|
||||
)
|
||||
traced_output = traced_model(**prepared_inputs)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
for x in output:
|
||||
if isinstance(x, (tuple, list)):
|
||||
flatten += flatten_output(x)
|
||||
elif not isinstance(x, torch.Tensor):
|
||||
continue
|
||||
else:
|
||||
flatten.append(x)
|
||||
return flatten
|
||||
|
||||
model_output = flatten_output(model_output)
|
||||
traced_output = flatten_output(traced_output)
|
||||
num_outputs = len(model_output)
|
||||
outputs_are_close = all(torch.allclose(model_output[i], traced_output[i]) for i in range(num_outputs))
|
||||
self.assertTrue(outputs_are_close)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], traced_output[i]),
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
|
||||
@@ -208,6 +208,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else None
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
|
||||
@@ -399,6 +399,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
)
|
||||
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_missing_keys = False
|
||||
test_model_parallel = True
|
||||
|
||||
|
||||
@@ -276,6 +276,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
|
||||
all_model_classes = (GPTNeoModel, GPTNeoForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_missing_keys = False
|
||||
test_pruning = False
|
||||
test_model_parallel = False
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2021 NVIDIA Corporation. All rights reserved.
|
||||
#
|
||||
@@ -282,6 +281,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
|
||||
# test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
@@ -267,6 +267,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
|
||||
Reference in New Issue
Block a user