Enable more test_torchscript (#16679)
* update _create_and_check_torchscript * Enable test_torchscript * clear_class_registry Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -617,19 +617,21 @@ class ModelTesterMixin:
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
main_input_name = model_class.main_input_name
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
input_ids = inputs["input_ids"]
|
||||
main_input = inputs[main_input_name]
|
||||
attention_mask = inputs["attention_mask"]
|
||||
decoder_input_ids = inputs["decoder_input_ids"]
|
||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||
traced_model = torch.jit.trace(
|
||||
model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
)
|
||||
else:
|
||||
input_ids = inputs["input_ids"]
|
||||
traced_model = torch.jit.trace(model, input_ids)
|
||||
main_input = inputs[main_input_name]
|
||||
traced_model = torch.jit.trace(model, main_input)
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user