Update tiny model creation script (#27674)
update Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -420,17 +420,25 @@ def get_tiny_config(config_class, model_class=None, **model_tester_kwargs):
|
||||
error = f"Tiny config not created for {model_type} - no model tester is found in the testing module."
|
||||
raise ValueError(error)
|
||||
|
||||
# CLIP-like models have `text_model_tester` and `vision_model_tester`, and we need to pass `vocab_size` to
|
||||
# `text_model_tester` via `text_kwargs`. The same trick is also necessary for `Flava`.
|
||||
|
||||
if "vocab_size" in model_tester_kwargs:
|
||||
if "text_kwargs" in inspect.signature(model_tester_class.__init__).parameters.keys():
|
||||
vocab_size = model_tester_kwargs.pop("vocab_size")
|
||||
model_tester_kwargs["text_kwargs"] = {"vocab_size": vocab_size}
|
||||
|
||||
# `parent` is an instance of `unittest.TestCase`, but we don't need it here.
|
||||
model_tester = model_tester_class(parent=None, **model_tester_kwargs)
|
||||
|
||||
if hasattr(model_tester, "get_pipeline_config"):
|
||||
return model_tester.get_pipeline_config()
|
||||
config = model_tester.get_pipeline_config()
|
||||
elif hasattr(model_tester, "prepare_config_and_inputs"):
|
||||
# `PoolFormer` has no `get_config` defined. Furthermore, it's better to use `prepare_config_and_inputs` even if
|
||||
# `get_config` is defined, since there might be some extra changes in `prepare_config_and_inputs`.
|
||||
return model_tester.prepare_config_and_inputs()[0]
|
||||
config = model_tester.prepare_config_and_inputs()[0]
|
||||
elif hasattr(model_tester, "get_config"):
|
||||
return model_tester.get_config()
|
||||
config = model_tester.get_config()
|
||||
else:
|
||||
error = (
|
||||
f"Tiny config not created for {model_type} - the model tester {model_tester_class.__name__} lacks"
|
||||
@@ -438,6 +446,26 @@ def get_tiny_config(config_class, model_class=None, **model_tester_kwargs):
|
||||
)
|
||||
raise ValueError(error)
|
||||
|
||||
# make sure this is long enough (some model tester has `20` for this attr.) to pass `text-generation`
|
||||
# pipeline tests.
|
||||
max_positions = []
|
||||
for key in ["max_position_embeddings", "max_source_positions", "max_target_positions"]:
|
||||
if getattr(config, key, 0) > 0:
|
||||
max_positions.append(getattr(config, key))
|
||||
if getattr(config, "text_config", None) is not None:
|
||||
if getattr(config.text_config, key, None) is not None:
|
||||
max_positions.append(getattr(config.text_config, key))
|
||||
if len(max_positions) > 0:
|
||||
max_position = max(200, min(max_positions))
|
||||
for key in ["max_position_embeddings", "max_source_positions", "max_target_positions"]:
|
||||
if getattr(config, key, 0) > 0:
|
||||
setattr(config, key, max_position)
|
||||
if getattr(config, "text_config", None) is not None:
|
||||
if getattr(config.text_config, key, None) is not None:
|
||||
setattr(config.text_config, key, max_position)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def convert_tokenizer(tokenizer_fast: PreTrainedTokenizerFast):
|
||||
new_tokenizer = tokenizer_fast.train_new_from_iterator(
|
||||
@@ -1006,26 +1034,8 @@ def get_config_overrides(config_class, processors):
|
||||
|
||||
# Used to create a new model tester with `tokenizer.vocab_size` in order to get the (updated) special token ids.
|
||||
model_tester_kwargs = {"vocab_size": vocab_size}
|
||||
# CLIP-like models have `text_model_tester` and `vision_model_tester`, and we need to pass `vocab_size` to
|
||||
# `text_model_tester` via `text_kwargs`. The same trick is also necessary for `Flava`.
|
||||
if config_class.__name__ in [
|
||||
"AlignConfig",
|
||||
"AltCLIPConfig",
|
||||
"ChineseCLIPConfig",
|
||||
"CLIPSegConfig",
|
||||
"ClapConfig",
|
||||
"CLIPConfig",
|
||||
"GroupViTConfig",
|
||||
"OwlViTConfig",
|
||||
"XCLIPConfig",
|
||||
"FlavaConfig",
|
||||
"BlipConfig",
|
||||
"Blip2Config",
|
||||
]:
|
||||
del model_tester_kwargs["vocab_size"]
|
||||
model_tester_kwargs["text_kwargs"] = {"vocab_size": vocab_size}
|
||||
# `FSMTModelTester` accepts `src_vocab_size` and `tgt_vocab_size` but not `vocab_size`.
|
||||
elif config_class.__name__ == "FSMTConfig":
|
||||
if config_class.__name__ == "FSMTConfig":
|
||||
del model_tester_kwargs["vocab_size"]
|
||||
model_tester_kwargs["src_vocab_size"] = tokenizer.src_vocab_size
|
||||
model_tester_kwargs["tgt_vocab_size"] = tokenizer.tgt_vocab_size
|
||||
@@ -1158,7 +1168,9 @@ def build(config_class, models_to_create, output_dir):
|
||||
if hasattr(tiny_config, k):
|
||||
setattr(tiny_config, k, v)
|
||||
# So far, we only have to deal with `text_config`, as `config_overrides` contains text-related attributes only.
|
||||
elif (
|
||||
# `FuyuConfig` saves data under both FuyuConfig and its `text_config`. This is not good, but let's just update
|
||||
# every involved fields to avoid potential failure.
|
||||
if (
|
||||
hasattr(tiny_config, "text_config")
|
||||
and tiny_config.text_config is not None
|
||||
and hasattr(tiny_config.text_config, k)
|
||||
|
||||
Reference in New Issue
Block a user