[tests] remove flax-pt equivalence and cross tests (#36283)
This commit is contained in:
@@ -17,9 +17,8 @@ import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import transformers
|
||||
from transformers import WhisperConfig, is_flax_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
from transformers.utils import cached_property
|
||||
from transformers.utils.import_utils import is_datasets_available
|
||||
|
||||
@@ -45,7 +44,6 @@ if is_flax_available():
|
||||
WhisperFeatureExtractor,
|
||||
WhisperProcessor,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
|
||||
|
||||
|
||||
@@ -245,99 +243,6 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# We override with a slightly higher tol value, as test recently became flaky
|
||||
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
# overwrite because of `input_features`
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ == base_class.__name__:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite because of `input_features`
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_from_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ == base_class.__name__:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# save pt model
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite because of `input_features`
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ == base_class.__name__:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite because of `input_features`
|
||||
def test_save_load_from_base(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -899,18 +804,3 @@ class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
# WhisperEncoder does not have any base model
|
||||
def test_save_load_from_base(self):
|
||||
pass
|
||||
|
||||
# WhisperEncoder does not have any base model
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_from_base_pt(self):
|
||||
pass
|
||||
|
||||
# WhisperEncoder does not have any base model
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_to_base_pt(self):
|
||||
pass
|
||||
|
||||
# WhisperEncoder does not have any base model
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
pass
|
||||
|
||||
@@ -32,7 +32,6 @@ import transformers
|
||||
from transformers import WhisperConfig
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
is_pt_flax_cross_test,
|
||||
require_flash_attn,
|
||||
require_non_xpu,
|
||||
require_torch,
|
||||
@@ -44,7 +43,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property, is_flax_available, is_torch_available, is_torchaudio_available
|
||||
from transformers.utils import cached_property, is_torch_available, is_torchaudio_available
|
||||
from transformers.utils.import_utils import is_datasets_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@@ -155,15 +154,6 @@ if is_torchaudio_available():
|
||||
import torchaudio
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
|
||||
def prepare_whisper_inputs_dict(
|
||||
config,
|
||||
input_features,
|
||||
@@ -1069,161 +1059,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# We override with a slightly higher tol value, as test recently became flaky
|
||||
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_mask_feature_prob(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.mask_feature_prob = 0.2
|
||||
@@ -3622,157 +3457,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="Flax model does not exist")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest("Flax model does not exist")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
|
||||
class WhisperStandaloneDecoderModelTester:
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user