[Pipelines] Add revision tag to all default pipelines (#17667)
* trigger test failure * upload revision poc * Update src/transformers/pipelines/base.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * up * add test * correct some stuff * Update src/transformers/pipelines/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * correct require flag Co-authored-by: Julien Chaumond <julien@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
4f8361afe7
commit
e4d2588573
@@ -30,7 +30,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, Aut
|
||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ..utils import http_get, is_tf_available, is_torch_available, logging
|
||||
from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, http_get, is_tf_available, is_torch_available, logging
|
||||
from .audio_classification import AudioClassificationPipeline
|
||||
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
|
||||
from .base import (
|
||||
@@ -41,7 +41,7 @@ from .base import (
|
||||
Pipeline,
|
||||
PipelineDataFormat,
|
||||
PipelineException,
|
||||
get_default_model,
|
||||
get_default_model_and_revision,
|
||||
infer_framework_load_model,
|
||||
)
|
||||
from .conversational import Conversation, ConversationalPipeline
|
||||
@@ -131,21 +131,21 @@ SUPPORTED_TASKS = {
|
||||
"impl": AudioClassificationPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
|
||||
"default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},
|
||||
"type": "audio",
|
||||
},
|
||||
"automatic-speech-recognition": {
|
||||
"impl": AutomaticSpeechRecognitionPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
|
||||
"default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
|
||||
"type": "multimodal",
|
||||
},
|
||||
"feature-extraction": {
|
||||
"impl": FeatureExtractionPipeline,
|
||||
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||
"pt": (AutoModel,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
|
||||
"default": {"model": {"pt": ("distilbert-base-cased", "935ac13"), "tf": ("distilbert-base-cased", "935ac13")}},
|
||||
"type": "multimodal",
|
||||
},
|
||||
"text-classification": {
|
||||
@@ -154,8 +154,8 @@ SUPPORTED_TASKS = {
|
||||
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": "distilbert-base-uncased-finetuned-sst-2-english",
|
||||
"tf": "distilbert-base-uncased-finetuned-sst-2-english",
|
||||
"pt": ("distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
|
||||
"tf": ("distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
|
||||
},
|
||||
},
|
||||
"type": "text",
|
||||
@@ -166,8 +166,8 @@ SUPPORTED_TASKS = {
|
||||
"pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
"tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
"pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
|
||||
"tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
|
||||
},
|
||||
},
|
||||
"type": "text",
|
||||
@@ -177,7 +177,10 @@ SUPPORTED_TASKS = {
|
||||
"tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
||||
"model": {
|
||||
"pt": ("distilbert-base-cased-distilled-squad", "626af31"),
|
||||
"tf": ("distilbert-base-cased-distilled-squad", "626af31"),
|
||||
},
|
||||
},
|
||||
"type": "text",
|
||||
},
|
||||
@@ -187,9 +190,8 @@ SUPPORTED_TASKS = {
|
||||
"tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": "google/tapas-base-finetuned-wtq",
|
||||
"tokenizer": "google/tapas-base-finetuned-wtq",
|
||||
"tf": "google/tapas-base-finetuned-wtq",
|
||||
"pt": ("google/tapas-base-finetuned-wtq", "69ceee2"),
|
||||
"tf": ("google/tapas-base-finetuned-wtq", "69ceee2"),
|
||||
},
|
||||
},
|
||||
"type": "text",
|
||||
@@ -199,11 +201,7 @@ SUPPORTED_TASKS = {
|
||||
"pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
|
||||
"tf": (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": "dandelin/vilt-b32-finetuned-vqa",
|
||||
"tokenizer": "dandelin/vilt-b32-finetuned-vqa",
|
||||
"feature_extractor": "dandelin/vilt-b32-finetuned-vqa",
|
||||
},
|
||||
"model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "4355f59")},
|
||||
},
|
||||
"type": "multimodal",
|
||||
},
|
||||
@@ -211,14 +209,14 @@ SUPPORTED_TASKS = {
|
||||
"impl": FillMaskPipeline,
|
||||
"tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
|
||||
"default": {"model": {"pt": ("distilroberta-base", "ec58a5b"), "tf": ("distilroberta-base", "ec58a5b")}},
|
||||
"type": "text",
|
||||
},
|
||||
"summarization": {
|
||||
"impl": SummarizationPipeline,
|
||||
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
|
||||
"default": {"model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("t5-small", "d769bba")}},
|
||||
"type": "text",
|
||||
},
|
||||
# This task is a special case as it's parametrized by SRC, TGT languages.
|
||||
@@ -227,9 +225,9 @@ SUPPORTED_TASKS = {
|
||||
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
||||
"default": {
|
||||
("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
("en", "fr"): {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
|
||||
("en", "de"): {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
|
||||
("en", "ro"): {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
|
||||
},
|
||||
"type": "text",
|
||||
},
|
||||
@@ -237,14 +235,14 @@ SUPPORTED_TASKS = {
|
||||
"impl": Text2TextGenerationPipeline,
|
||||
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
"default": {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
|
||||
"type": "text",
|
||||
},
|
||||
"text-generation": {
|
||||
"impl": TextGenerationPipeline,
|
||||
"tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForCausalLM,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
|
||||
"default": {"model": {"pt": ("gpt2", "6c0e608"), "tf": ("gpt2", "6c0e608")}},
|
||||
"type": "text",
|
||||
},
|
||||
"zero-shot-classification": {
|
||||
@@ -252,9 +250,8 @@ SUPPORTED_TASKS = {
|
||||
"tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
||||
"config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
||||
"tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
||||
"model": {"pt": ("facebook/bart-large-mnli", "c626438"), "tf": ("roberta-large-mnli", "130fb28")},
|
||||
"config": {"pt": ("facebook/bart-large-mnli", "c626438"), "tf": ("roberta-large-mnli", "130fb28")},
|
||||
},
|
||||
"type": "text",
|
||||
},
|
||||
@@ -262,35 +259,42 @@ SUPPORTED_TASKS = {
|
||||
"impl": ZeroShotImageClassificationPipeline,
|
||||
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||
"pt": (AutoModel,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"}},
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": ("openai/clip-vit-base-patch32", "f4881ba"),
|
||||
"tf": ("openai/clip-vit-base-patch32", "f4881ba"),
|
||||
}
|
||||
},
|
||||
"type": "multimodal",
|
||||
},
|
||||
"conversational": {
|
||||
"impl": ConversationalPipeline,
|
||||
"tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
|
||||
"default": {
|
||||
"model": {"pt": ("microsoft/DialoGPT-medium", "8bada3b"), "tf": ("microsoft/DialoGPT-medium", "8bada3b")}
|
||||
},
|
||||
"type": "text",
|
||||
},
|
||||
"image-classification": {
|
||||
"impl": ImageClassificationPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
||||
"default": {"model": {"pt": ("google/vit-base-patch16-224", "5dca96d")}},
|
||||
"type": "image",
|
||||
},
|
||||
"image-segmentation": {
|
||||
"impl": ImageSegmentationPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
|
||||
"default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
|
||||
"type": "image",
|
||||
},
|
||||
"object-detection": {
|
||||
"impl": ObjectDetectionPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "facebook/detr-resnet-50"}},
|
||||
"default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
|
||||
"type": "image",
|
||||
},
|
||||
}
|
||||
@@ -545,8 +549,13 @@ def pipeline(
|
||||
# Use default model/config/tokenizer for the task if no model is provided
|
||||
if model is None:
|
||||
# At that point framework might still be undetermined
|
||||
model = get_default_model(targeted_task, framework, task_options)
|
||||
logger.warning(f"No model was supplied, defaulted to {model} (https://huggingface.co/{model})")
|
||||
model, default_revision = get_default_model_and_revision(targeted_task, framework, task_options)
|
||||
revision = revision if revision is not None else default_revision
|
||||
logger.warning(
|
||||
f"No model was supplied, defaulted to {model} and revision"
|
||||
f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
|
||||
"Using a pipeline without specifying a model name and revision in production is not recommended."
|
||||
)
|
||||
|
||||
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
|
||||
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
|
||||
|
||||
@@ -341,7 +341,9 @@ def get_framework(model, revision: Optional[str] = None):
|
||||
return framework
|
||||
|
||||
|
||||
def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
|
||||
def get_default_model_and_revision(
|
||||
targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]
|
||||
) -> Union[str, Tuple[str, str]]:
|
||||
"""
|
||||
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ from abc import abstractmethod
|
||||
from functools import lru_cache
|
||||
from unittest import skipIf
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import (
|
||||
FEATURE_EXTRACTOR_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
@@ -35,7 +37,15 @@ from transformers import (
|
||||
)
|
||||
from transformers.pipelines import get_task
|
||||
from transformers.pipelines.base import _pad
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_scatter,
|
||||
require_tensorflow_probability,
|
||||
require_tf,
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -461,8 +471,8 @@ class PipelinePadTest(unittest.TestCase):
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
@require_torch
|
||||
class PipelineUtilsTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_pipeline_dataset(self):
|
||||
from transformers.pipelines.pt_utils import PipelineDataset
|
||||
|
||||
@@ -476,6 +486,7 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
outputs = [dataset[i] for i in range(4)]
|
||||
self.assertEqual(outputs, [2, 3, 4, 5])
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
@@ -490,6 +501,7 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(outputs, [2, 3, 4, 5])
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_iterator_no_len(self):
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
@@ -507,6 +519,7 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(outputs, [2, 3, 4, 5])
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_batch_unbatch_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
@@ -520,6 +533,7 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(outputs, [{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}])
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_batch_unbatch_iterator_tensors(self):
|
||||
import torch
|
||||
|
||||
@@ -537,6 +551,7 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
nested_simplify(outputs), [{"id": [[12, 22]]}, {"id": [[2, 3]]}, {"id": [[2, 4]]}, {"id": [[5]]}]
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_chunk_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelineChunkIterator
|
||||
|
||||
@@ -552,6 +567,7 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(outputs, [0, 1, 0, 1, 2])
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_pack_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelinePackIterator
|
||||
|
||||
@@ -584,6 +600,7 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_pack_unbatch_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelinePackIterator
|
||||
|
||||
@@ -607,3 +624,125 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_load_default_pipelines_pt(self):
|
||||
import torch
|
||||
|
||||
from transformers.pipelines import SUPPORTED_TASKS
|
||||
|
||||
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
|
||||
for task in SUPPORTED_TASKS.keys():
|
||||
if task == "table-question-answering":
|
||||
# test table in seperate test due to more dependencies
|
||||
continue
|
||||
|
||||
self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
def test_load_default_pipelines_tf(self):
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.pipelines import SUPPORTED_TASKS
|
||||
|
||||
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
|
||||
for task in SUPPORTED_TASKS.keys():
|
||||
if task == "table-question-answering":
|
||||
# test table in seperate test due to more dependencies
|
||||
continue
|
||||
|
||||
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_scatter
|
||||
def test_load_default_pipelines_pt_table_qa(self):
|
||||
import torch
|
||||
|
||||
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
|
||||
self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
@require_tensorflow_probability
|
||||
def test_load_default_pipelines_tf_table_qa(self):
|
||||
import tensorflow as tf
|
||||
|
||||
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
|
||||
self.check_default_pipeline("table-question-answering", "tf", set_seed_fn, self.check_models_equal_tf)
|
||||
|
||||
def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
|
||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||
|
||||
task_dict = SUPPORTED_TASKS[task]
|
||||
# test to compare pipeline to manually loading the respective model
|
||||
model = None
|
||||
relevant_auto_classes = task_dict[framework]
|
||||
|
||||
if len(relevant_auto_classes) == 0:
|
||||
# task has no default
|
||||
logger.debug(f"{task} in {framework} has no default")
|
||||
return
|
||||
|
||||
# by default use first class
|
||||
auto_model_cls = relevant_auto_classes[0]
|
||||
|
||||
# retrieve correct model ids
|
||||
if task == "translation":
|
||||
# special case for translation pipeline which has multiple languages
|
||||
model_ids = []
|
||||
revisions = []
|
||||
tasks = []
|
||||
for translation_pair in task_dict["default"].keys():
|
||||
model_id, revision = task_dict["default"][translation_pair]["model"][framework]
|
||||
|
||||
model_ids.append(model_id)
|
||||
revisions.append(revision)
|
||||
tasks.append(task + f"_{'_to_'.join(translation_pair)}")
|
||||
else:
|
||||
# normal case - non-translation pipeline
|
||||
model_id, revision = task_dict["default"]["model"][framework]
|
||||
|
||||
model_ids = [model_id]
|
||||
revisions = [revision]
|
||||
tasks = [task]
|
||||
|
||||
# check for equality
|
||||
for model_id, revision, task in zip(model_ids, revisions, tasks):
|
||||
# load default model
|
||||
try:
|
||||
set_seed_fn()
|
||||
model = auto_model_cls.from_pretrained(model_id, revision=revision)
|
||||
except ValueError:
|
||||
# first auto class is possible not compatible with model, go to next model class
|
||||
auto_model_cls = relevant_auto_classes[1]
|
||||
set_seed_fn()
|
||||
model = auto_model_cls.from_pretrained(model_id, revision=revision)
|
||||
|
||||
# load default pipeline
|
||||
set_seed_fn()
|
||||
default_pipeline = pipeline(task, framework=framework)
|
||||
|
||||
# compare pipeline model with default model
|
||||
models_are_equal = check_models_equal_fn(default_pipeline.model, model)
|
||||
self.assertTrue(models_are_equal, f"{task} model doesn't match pipeline.")
|
||||
|
||||
logger.debug(f"{task} in {framework} succeeded with {model_id}.")
|
||||
|
||||
def check_models_equal_pt(self, model1, model2):
|
||||
models_are_equal = True
|
||||
for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
|
||||
if model1_p.data.ne(model2_p.data).sum() > 0:
|
||||
models_are_equal = False
|
||||
|
||||
return models_are_equal
|
||||
|
||||
def check_models_equal_tf(self, model1, model2):
|
||||
models_are_equal = True
|
||||
for model1_p, model2_p in zip(model1.weights, model2.weights):
|
||||
if np.abs(model1_p.numpy() - model2_p.numpy()).sum() > 1e-5:
|
||||
models_are_equal = False
|
||||
|
||||
return models_are_equal
|
||||
|
||||
Reference in New Issue
Block a user