Update tiny model summary file for recent models (#22637)
* Update tiny model summary file for recent models --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -19,7 +19,7 @@ import inspect
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import EfficientNetConfig
|
from transformers import EfficientNetConfig
|
||||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
from transformers.testing_utils import is_pipeline_test, require_torch, require_vision, slow, torch_device
|
||||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -229,6 +229,12 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
|
|||||||
model = EfficientNetModel.from_pretrained(model_name)
|
model = EfficientNetModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@is_pipeline_test
|
||||||
|
@require_vision
|
||||||
|
@slow
|
||||||
|
def test_pipeline_image_classification(self):
|
||||||
|
super().test_pipeline_image_classification()
|
||||||
|
|
||||||
|
|
||||||
# We will verify our results on an image of cute cats
|
# We will verify our results on an image of cute cats
|
||||||
def prepare_img():
|
def prepare_img():
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from transformers.testing_utils import require_torch, slow, tooslow, torch_devic
|
|||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
class GPTSanJapaneseTester:
|
class GPTSanJapaneseTester:
|
||||||
@@ -127,8 +128,19 @@ class GPTSanJapaneseTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GPTSanJapaneseTest(ModelTesterMixin, unittest.TestCase):
|
class GPTSanJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (GPTSanJapaneseModel,) if is_torch_available() else ()
|
all_model_classes = (GPTSanJapaneseModel,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{
|
||||||
|
"conversational": GPTSanJapaneseForConditionalGeneration,
|
||||||
|
"feature-extraction": GPTSanJapaneseForConditionalGeneration,
|
||||||
|
"summarization": GPTSanJapaneseForConditionalGeneration,
|
||||||
|
"text2text-generation": GPTSanJapaneseForConditionalGeneration,
|
||||||
|
"translation": GPTSanJapaneseForConditionalGeneration,
|
||||||
|
}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
@@ -140,6 +152,19 @@ class GPTSanJapaneseTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
# The small GPTSAN_JAPANESE model needs higher percentages for CPU/MP tests
|
# The small GPTSAN_JAPANESE model needs higher percentages for CPU/MP tests
|
||||||
model_split_percents = [0.8, 0.9]
|
model_split_percents = [0.8, 0.9]
|
||||||
|
|
||||||
|
# TODO: Fix the failed tests when this model gets more usage
|
||||||
|
def is_pipeline_test_to_skip(
|
||||||
|
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||||
|
):
|
||||||
|
if pipeline_test_casse_name == "SummarizationPipelineTests":
|
||||||
|
# TODO: fix `_reorder_cache` is not implemented for this model
|
||||||
|
return True
|
||||||
|
elif pipeline_test_casse_name == "Text2TextGenerationPipelineTests":
|
||||||
|
# TODO: check this.
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = GPTSanJapaneseTester(self)
|
self.model_tester = GPTSanJapaneseTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=GPTSanJapaneseConfig, d_model=37)
|
self.config_tester = ConfigTester(self, config_class=GPTSanJapaneseConfig, d_model=37)
|
||||||
|
|||||||
@@ -299,7 +299,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
def is_pipeline_test_to_skip(
|
def is_pipeline_test_to_skip(
|
||||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||||
):
|
):
|
||||||
if pipeline_test_casse_name == "AutomaticSpeechRecognitionPipelineTests":
|
if pipeline_test_casse_name in [
|
||||||
|
"AutomaticSpeechRecognitionPipelineTests",
|
||||||
|
"AudioClassificationPipelineTests",
|
||||||
|
]:
|
||||||
# RuntimeError: The size of tensor a (1500) must match the size of tensor b (30) at non-singleton
|
# RuntimeError: The size of tensor a (1500) must match the size of tensor b (30) at non-singleton
|
||||||
# dimension 1
|
# dimension 1
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -137,7 +137,7 @@
|
|||||||
"model_classes": [
|
"model_classes": [
|
||||||
"BartForCausalLM"
|
"BartForCausalLM"
|
||||||
],
|
],
|
||||||
"sha": "6ca393c5c34d638e70bafdc02488b65b9025872c"
|
"sha": "c25526ac67d2dbe79fe5462af4b7908ca2fbc3ff"
|
||||||
},
|
},
|
||||||
"BartForConditionalGeneration": {
|
"BartForConditionalGeneration": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
@@ -149,7 +149,7 @@
|
|||||||
"BartForConditionalGeneration",
|
"BartForConditionalGeneration",
|
||||||
"TFBartForConditionalGeneration"
|
"TFBartForConditionalGeneration"
|
||||||
],
|
],
|
||||||
"sha": "44a5e3a5616b22b89cb767ac8d05f360e5de2e58"
|
"sha": "3a489a21e4b04705f4a6047924b7616a67be7e37"
|
||||||
},
|
},
|
||||||
"BartForQuestionAnswering": {
|
"BartForQuestionAnswering": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
@@ -160,7 +160,7 @@
|
|||||||
"model_classes": [
|
"model_classes": [
|
||||||
"BartForQuestionAnswering"
|
"BartForQuestionAnswering"
|
||||||
],
|
],
|
||||||
"sha": "291888e031ae29b9defb5a4376460107cfb7a1a9"
|
"sha": "3ebf9aab39a57ceab55128d5fc6f61e4db0dadd4"
|
||||||
},
|
},
|
||||||
"BartForSequenceClassification": {
|
"BartForSequenceClassification": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
@@ -169,9 +169,10 @@
|
|||||||
],
|
],
|
||||||
"processor_classes": [],
|
"processor_classes": [],
|
||||||
"model_classes": [
|
"model_classes": [
|
||||||
"BartForSequenceClassification"
|
"BartForSequenceClassification",
|
||||||
|
"TFBartForSequenceClassification"
|
||||||
],
|
],
|
||||||
"sha": "5ceca1f5dbcf32c04ef44355e4bc66128cd4ea8b"
|
"sha": "ea452fd9a928cfebd71723afa50feb20326917bc"
|
||||||
},
|
},
|
||||||
"BartModel": {
|
"BartModel": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
@@ -183,7 +184,7 @@
|
|||||||
"BartModel",
|
"BartModel",
|
||||||
"TFBartModel"
|
"TFBartModel"
|
||||||
],
|
],
|
||||||
"sha": "26c409f22daa4773a78d7a7c80510cdc8b752a3d"
|
"sha": "e5df6d1aa75f03833b2df328b9c35463f73a421b"
|
||||||
},
|
},
|
||||||
"BeitForImageClassification": {
|
"BeitForImageClassification": {
|
||||||
"tokenizer_classes": [],
|
"tokenizer_classes": [],
|
||||||
@@ -476,6 +477,16 @@
|
|||||||
],
|
],
|
||||||
"sha": "07073b31da84054fd12226e3cae4cb3beb2547f9"
|
"sha": "07073b31da84054fd12226e3cae4cb3beb2547f9"
|
||||||
},
|
},
|
||||||
|
"BioGptForTokenClassification": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"BioGptTokenizer"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"BioGptForTokenClassification"
|
||||||
|
],
|
||||||
|
"sha": "67f8173c1a17273064d452a9031a51b67f327b6a"
|
||||||
|
},
|
||||||
"BioGptModel": {
|
"BioGptModel": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
"BioGptTokenizer"
|
"BioGptTokenizer"
|
||||||
@@ -618,9 +629,10 @@
|
|||||||
"BlipImageProcessor"
|
"BlipImageProcessor"
|
||||||
],
|
],
|
||||||
"model_classes": [
|
"model_classes": [
|
||||||
"BlipForConditionalGeneration"
|
"BlipForConditionalGeneration",
|
||||||
|
"TFBlipForConditionalGeneration"
|
||||||
],
|
],
|
||||||
"sha": "e776bae5de3a4e9c11170b2465775eb37baf2bfe"
|
"sha": "eaf32bc0369349deef0c777442fc185119171d1f"
|
||||||
},
|
},
|
||||||
"BlipModel": {
|
"BlipModel": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
@@ -631,9 +643,10 @@
|
|||||||
"BlipImageProcessor"
|
"BlipImageProcessor"
|
||||||
],
|
],
|
||||||
"model_classes": [
|
"model_classes": [
|
||||||
"BlipModel"
|
"BlipModel",
|
||||||
|
"TFBlipModel"
|
||||||
],
|
],
|
||||||
"sha": "261433f322f7146b0c28c0c025e92b3a33f716bb"
|
"sha": "3d1d1c15eff22d6b2664a2d15757fa6f5d93827d"
|
||||||
},
|
},
|
||||||
"BloomForCausalLM": {
|
"BloomForCausalLM": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
@@ -808,6 +821,19 @@
|
|||||||
],
|
],
|
||||||
"sha": "504271a3c5fd9c2e877f5b4c01848bc18778c7c3"
|
"sha": "504271a3c5fd9c2e877f5b4c01848bc18778c7c3"
|
||||||
},
|
},
|
||||||
|
"ClapModel": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"RobertaTokenizer",
|
||||||
|
"RobertaTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [
|
||||||
|
"ClapFeatureExtractor"
|
||||||
|
],
|
||||||
|
"model_classes": [
|
||||||
|
"ClapModel"
|
||||||
|
],
|
||||||
|
"sha": "a7874595b900f9b2ddc79130dafc3ff48f4fbfb9"
|
||||||
|
},
|
||||||
"CodeGenForCausalLM": {
|
"CodeGenForCausalLM": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
"CodeGenTokenizer",
|
"CodeGenTokenizer",
|
||||||
@@ -2397,7 +2423,7 @@
|
|||||||
"model_classes": [
|
"model_classes": [
|
||||||
"GPTSanJapaneseForConditionalGeneration"
|
"GPTSanJapaneseForConditionalGeneration"
|
||||||
],
|
],
|
||||||
"sha": "83bbd0feb62cd12d9163c7638e15bf2bb6fef1eb"
|
"sha": "ff6a41faaa713c7fbd5d9a1a50539745f9e1178e"
|
||||||
},
|
},
|
||||||
"GitForCausalLM": {
|
"GitForCausalLM": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
@@ -3340,6 +3366,83 @@
|
|||||||
],
|
],
|
||||||
"sha": "473b54a464bc0ccee29bc23b4f6610f32eec05af"
|
"sha": "473b54a464bc0ccee29bc23b4f6610f32eec05af"
|
||||||
},
|
},
|
||||||
|
"MegaForCausalLM": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"RobertaTokenizer",
|
||||||
|
"RobertaTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"MegaForCausalLM"
|
||||||
|
],
|
||||||
|
"sha": "6642b9da860f8b62abcfb0660feabcebf6698418"
|
||||||
|
},
|
||||||
|
"MegaForMaskedLM": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"RobertaTokenizer",
|
||||||
|
"RobertaTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"MegaForMaskedLM"
|
||||||
|
],
|
||||||
|
"sha": "6b2d47ba03bec9e6f7eefdd4a67351fa191aae6f"
|
||||||
|
},
|
||||||
|
"MegaForMultipleChoice": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"RobertaTokenizer",
|
||||||
|
"RobertaTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"MegaForMultipleChoice"
|
||||||
|
],
|
||||||
|
"sha": "2b1e751da36a4410473eef07a62b09227a26d504"
|
||||||
|
},
|
||||||
|
"MegaForQuestionAnswering": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"RobertaTokenizer",
|
||||||
|
"RobertaTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"MegaForQuestionAnswering"
|
||||||
|
],
|
||||||
|
"sha": "612acd9a53c351c42514adb3c04f2057d2870be7"
|
||||||
|
},
|
||||||
|
"MegaForSequenceClassification": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"RobertaTokenizer",
|
||||||
|
"RobertaTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"MegaForSequenceClassification"
|
||||||
|
],
|
||||||
|
"sha": "4871572da1613b7e9cfd3640c6d1129af004eefb"
|
||||||
|
},
|
||||||
|
"MegaForTokenClassification": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"RobertaTokenizer",
|
||||||
|
"RobertaTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"MegaForTokenClassification"
|
||||||
|
],
|
||||||
|
"sha": "450d3722c3b995215d06b9c12544c99f958581c7"
|
||||||
|
},
|
||||||
|
"MegaModel": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"RobertaTokenizer",
|
||||||
|
"RobertaTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"MegaModel"
|
||||||
|
],
|
||||||
|
"sha": "ca0862db27428893fe22f9bb5d2eb0875c2156f3"
|
||||||
|
},
|
||||||
"MegatronBertForCausalLM": {
|
"MegatronBertForCausalLM": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
"BertTokenizer",
|
"BertTokenizer",
|
||||||
@@ -3801,6 +3904,28 @@
|
|||||||
],
|
],
|
||||||
"sha": "80e05ba7c55bcdd7f4d1387ef9a09a7a8e95b5ac"
|
"sha": "80e05ba7c55bcdd7f4d1387ef9a09a7a8e95b5ac"
|
||||||
},
|
},
|
||||||
|
"NllbMoeForConditionalGeneration": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"NllbTokenizer",
|
||||||
|
"NllbTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"NllbMoeForConditionalGeneration"
|
||||||
|
],
|
||||||
|
"sha": "2a7f87dffe826af3d52086888f3f3773246e5528"
|
||||||
|
},
|
||||||
|
"NllbMoeModel": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"NllbTokenizer",
|
||||||
|
"NllbTokenizerFast"
|
||||||
|
],
|
||||||
|
"processor_classes": [],
|
||||||
|
"model_classes": [
|
||||||
|
"NllbMoeModel"
|
||||||
|
],
|
||||||
|
"sha": "9f7a2261eed4658e1aa5623be4672ba64bee7da5"
|
||||||
|
},
|
||||||
"NystromformerForMaskedLM": {
|
"NystromformerForMaskedLM": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
"AlbertTokenizer",
|
"AlbertTokenizer",
|
||||||
@@ -5584,9 +5709,10 @@
|
|||||||
"ViTImageProcessor"
|
"ViTImageProcessor"
|
||||||
],
|
],
|
||||||
"model_classes": [
|
"model_classes": [
|
||||||
|
"TFVisionTextDualEncoderModel",
|
||||||
"VisionTextDualEncoderModel"
|
"VisionTextDualEncoderModel"
|
||||||
],
|
],
|
||||||
"sha": "fcedabfb7cbe3c717c1485613064418acf57ab3d"
|
"sha": "c3569ef17f66acbacb76f7ceb6f71e02d075dd6c"
|
||||||
},
|
},
|
||||||
"VisualBertForPreTraining": {
|
"VisualBertForPreTraining": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
@@ -5791,6 +5917,18 @@
|
|||||||
],
|
],
|
||||||
"sha": "e932275e37cb643be271f655bd1d649f4f4b4bd5"
|
"sha": "e932275e37cb643be271f655bd1d649f4f4b4bd5"
|
||||||
},
|
},
|
||||||
|
"WhisperForAudioClassification": {
|
||||||
|
"tokenizer_classes": [
|
||||||
|
"WhisperTokenizer"
|
||||||
|
],
|
||||||
|
"processor_classes": [
|
||||||
|
"WhisperFeatureExtractor"
|
||||||
|
],
|
||||||
|
"model_classes": [
|
||||||
|
"WhisperForAudioClassification"
|
||||||
|
],
|
||||||
|
"sha": "d71b13674b1a67443cd19d0594a3b5b1e5968f0d"
|
||||||
|
},
|
||||||
"WhisperForConditionalGeneration": {
|
"WhisperForConditionalGeneration": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
"WhisperTokenizer",
|
"WhisperTokenizer",
|
||||||
|
|||||||
@@ -991,6 +991,12 @@ def get_config_overrides(config_class, processors):
|
|||||||
# We use `len(tokenizer)` instead of `tokenizer.vocab_size` to avoid potential issues for tokenizers with non-empty
|
# We use `len(tokenizer)` instead of `tokenizer.vocab_size` to avoid potential issues for tokenizers with non-empty
|
||||||
# `added_tokens_encoder`. One example is the `DebertaV2Tokenizer` where the mask token is the extra token.
|
# `added_tokens_encoder`. One example is the `DebertaV2Tokenizer` where the mask token is the extra token.
|
||||||
vocab_size = len(tokenizer)
|
vocab_size = len(tokenizer)
|
||||||
|
|
||||||
|
# The original checkpoint has length `35998`, but it doesn't have ids `30400` and `30514` but instead `35998` and
|
||||||
|
# `35999`.
|
||||||
|
if config_class.__name__ == "GPTSanJapaneseConfig":
|
||||||
|
vocab_size += 2
|
||||||
|
|
||||||
config_overrides["vocab_size"] = vocab_size
|
config_overrides["vocab_size"] = vocab_size
|
||||||
|
|
||||||
# Used to create a new model tester with `tokenizer.vocab_size` in order to get the (updated) special token ids.
|
# Used to create a new model tester with `tokenizer.vocab_size` in order to get the (updated) special token ids.
|
||||||
@@ -1329,6 +1335,33 @@ def build_simple_report(results):
|
|||||||
return text, failed_text
|
return text, failed_text
|
||||||
|
|
||||||
|
|
||||||
|
def update_tiny_model_summary_file(report_path):
|
||||||
|
with open(os.path.join(report_path, "tiny_model_summary.json")) as fp:
|
||||||
|
new_data = json.load(fp)
|
||||||
|
with open("tests/utils/tiny_model_summary.json") as fp:
|
||||||
|
data = json.load(fp)
|
||||||
|
for key, value in new_data.items():
|
||||||
|
if key not in data:
|
||||||
|
data[key] = value
|
||||||
|
else:
|
||||||
|
for attr in ["tokenizer_classes", "processor_classes", "model_classes"]:
|
||||||
|
# we might get duplication here. We will remove them below when creating `updated_data`.
|
||||||
|
data[key][attr].extend(value[attr])
|
||||||
|
new_sha = value.get("sha", None)
|
||||||
|
if new_sha is not None:
|
||||||
|
data[key]["sha"] = new_sha
|
||||||
|
|
||||||
|
updated_data = {}
|
||||||
|
for key in sorted(data.keys()):
|
||||||
|
updated_data[key] = {}
|
||||||
|
for attr, value in data[key].items():
|
||||||
|
# deduplication and sort
|
||||||
|
updated_data[key][attr] = sorted(set(value)) if attr != "sha" else value
|
||||||
|
|
||||||
|
with open(os.path.join(report_path, "updated_tiny_model_summary.json"), "w") as fp:
|
||||||
|
json.dump(updated_data, fp, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def create_tiny_models(
|
def create_tiny_models(
|
||||||
output_path,
|
output_path,
|
||||||
all,
|
all,
|
||||||
@@ -1444,6 +1477,8 @@ def create_tiny_models(
|
|||||||
with open(os.path.join(report_path, "simple_failed_report.txt"), "w") as fp:
|
with open(os.path.join(report_path, "simple_failed_report.txt"), "w") as fp:
|
||||||
fp.write(failed_report)
|
fp.write(failed_report)
|
||||||
|
|
||||||
|
update_tiny_model_summary_file(report_path=os.path.join(output_path, "reports"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# This has to be `spawn` to avoid hanging forever!
|
# This has to be `spawn` to avoid hanging forever!
|
||||||
|
|||||||
@@ -171,33 +171,6 @@ def get_tiny_model_summary_from_hub(output_path):
|
|||||||
json.dump(summary, fp, ensure_ascii=False, indent=4)
|
json.dump(summary, fp, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
|
||||||
def update_tiny_model_summary_file(report_path):
|
|
||||||
with open(os.path.join(report_path, "tiny_model_summary.json")) as fp:
|
|
||||||
new_data = json.load(fp)
|
|
||||||
with open("tests/utils/tiny_model_summary.json") as fp:
|
|
||||||
data = json.load(fp)
|
|
||||||
for key, value in new_data.items():
|
|
||||||
if key not in data:
|
|
||||||
data[key] = value
|
|
||||||
else:
|
|
||||||
for attr in ["tokenizer_classes", "processor_classes", "model_classes"]:
|
|
||||||
# we might get duplication here. We will remove them below when creating `updated_data`.
|
|
||||||
data[key][attr].extend(value[attr])
|
|
||||||
new_sha = value["sha"]
|
|
||||||
if new_sha is not None:
|
|
||||||
data[key]["sha"] = new_sha
|
|
||||||
|
|
||||||
updated_data = {}
|
|
||||||
for key in sorted(data.keys()):
|
|
||||||
updated_data[key] = {}
|
|
||||||
for attr, value in data[key].items():
|
|
||||||
# deduplication and sort
|
|
||||||
updated_data[key][attr] = sorted(set(value)) if attr != "sha" else value
|
|
||||||
|
|
||||||
with open(os.path.join(report_path, "updated_tiny_model_summary.json"), "w") as fp:
|
|
||||||
json.dump(updated_data, fp, indent=4, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--num_workers", default=1, type=int, help="The number of workers to run.")
|
parser.add_argument("--num_workers", default=1, type=int, help="The number of workers to run.")
|
||||||
@@ -225,5 +198,3 @@ if __name__ == "__main__":
|
|||||||
token=os.environ.get("TOKEN", None),
|
token=os.environ.get("TOKEN", None),
|
||||||
num_workers=args.num_workers,
|
num_workers=args.num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
update_tiny_model_summary_file(report_path=os.path.join(output_path, "reports"))
|
|
||||||
|
|||||||
Reference in New Issue
Block a user