Fix pylint warnings (#39477)
* Fix pylint warnings Signed-off-by: cyy <cyyever@outlook.com> * Fix variable names Signed-off-by: cyy <cyyever@outlook.com> --------- Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -1079,10 +1079,10 @@ def add_model_to_auto_classes(
|
|||||||
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
||||||
model_classes (`dict[str, list[str]]`): A dictionary framework to list of model classes implemented.
|
model_classes (`dict[str, list[str]]`): A dictionary framework to list of model classes implemented.
|
||||||
"""
|
"""
|
||||||
for filename in AUTO_CLASSES_PATTERNS:
|
for filename, patterns in AUTO_CLASSES_PATTERNS.items():
|
||||||
# Extend patterns with all model classes if necessary
|
# Extend patterns with all model classes if necessary
|
||||||
new_patterns = []
|
new_patterns = []
|
||||||
for pattern in AUTO_CLASSES_PATTERNS[filename]:
|
for pattern in patterns:
|
||||||
if re.search("any_([a-z]*)_class", pattern) is not None:
|
if re.search("any_([a-z]*)_class", pattern) is not None:
|
||||||
framework = re.search("any_([a-z]*)_class", pattern).groups()[0]
|
framework = re.search("any_([a-z]*)_class", pattern).groups()[0]
|
||||||
if framework in model_classes:
|
if framework in model_classes:
|
||||||
|
|||||||
@@ -146,14 +146,14 @@ class TextStreamer(BaseStreamer):
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -428,8 +428,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
|
|||||||
if isinstance(value, str) and architecture in value:
|
if isinstance(value, str) and architecture in value:
|
||||||
value = value.replace(architecture, updated_architecture)
|
value = value.replace(architecture, updated_architecture)
|
||||||
|
|
||||||
for parameter in GGUF_TO_TRANSFORMERS_MAPPING:
|
for parameter, parameter_renames in GGUF_TO_TRANSFORMERS_MAPPING.items():
|
||||||
parameter_renames = GGUF_TO_TRANSFORMERS_MAPPING[parameter]
|
|
||||||
if prefix in parameter_renames and config_key in parameter_renames[prefix]:
|
if prefix in parameter_renames and config_key in parameter_renames[prefix]:
|
||||||
renamed_config_key = parameter_renames[prefix][config_key]
|
renamed_config_key = parameter_renames[prefix][config_key]
|
||||||
if renamed_config_key == -1:
|
if renamed_config_key == -1:
|
||||||
|
|||||||
@@ -1572,8 +1572,8 @@ def _find_mismatched_keys(
|
|||||||
# Fix the key names
|
# Fix the key names
|
||||||
new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping}
|
new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping}
|
||||||
|
|
||||||
for key in new_state_dict.keys():
|
for key, tensor in new_state_dict.items():
|
||||||
if key in model_state_dict and new_state_dict[key].shape != model_state_dict[key].shape:
|
if key in model_state_dict and tensor.shape != model_state_dict[key].shape:
|
||||||
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
|
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
|
||||||
# Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights.
|
# Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights.
|
||||||
if not (
|
if not (
|
||||||
@@ -1582,7 +1582,7 @@ def _find_mismatched_keys(
|
|||||||
and new_state_dict[key].numel() * 2 == model_state_dict[key].numel()
|
and new_state_dict[key].numel() * 2 == model_state_dict[key].numel()
|
||||||
):
|
):
|
||||||
mismatched_keys.append(key)
|
mismatched_keys.append(key)
|
||||||
mismatched_shapes.append((new_state_dict[key].shape, model_state_dict[key].shape))
|
mismatched_shapes.append((tensor.shape, model_state_dict[key].shape))
|
||||||
|
|
||||||
return mismatched_keys, mismatched_shapes
|
return mismatched_keys, mismatched_shapes
|
||||||
|
|
||||||
|
|||||||
@@ -130,12 +130,12 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
|||||||
state_dict = checkpoint["model"]
|
state_dict = checkpoint["model"]
|
||||||
# fixup checkpoint
|
# fixup checkpoint
|
||||||
unwanted_prefix = "_orig_mod."
|
unwanted_prefix = "_orig_mod."
|
||||||
for k, v in list(state_dict.items()):
|
for k in state_dict:
|
||||||
if k.startswith(unwanted_prefix):
|
if k.startswith(unwanted_prefix):
|
||||||
# replace part of the key with corresponding layer name in HF implementation
|
# replace part of the key with corresponding layer name in HF implementation
|
||||||
new_k = k[len(unwanted_prefix) :]
|
new_k = k[len(unwanted_prefix) :]
|
||||||
for old_layer_name in new_layer_name_dict:
|
for old_layer_name, new_layer_name in new_layer_name_dict.items():
|
||||||
new_k = new_k.replace(old_layer_name, new_layer_name_dict[old_layer_name])
|
new_k = new_k.replace(old_layer_name, new_layer_name)
|
||||||
|
|
||||||
state_dict[new_k] = state_dict.pop(k)
|
state_dict[new_k] = state_dict.pop(k)
|
||||||
|
|
||||||
|
|||||||
@@ -392,14 +392,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -796,14 +796,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -225,14 +225,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -396,14 +396,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -290,7 +290,7 @@ def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_fo
|
|||||||
id2label = id2label
|
id2label = id2label
|
||||||
label2id = {v: k for k, v in id2label.items()}
|
label2id = {v: k for k, v in id2label.items()}
|
||||||
|
|
||||||
config = config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)
|
config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)
|
||||||
|
|
||||||
# For depth size 13 (13 = 1+2+10)
|
# For depth size 13 (13 = 1+2+10)
|
||||||
if cvt_model.rsplit("/", 1)[-1][4:6] == "13":
|
if cvt_model.rsplit("/", 1)[-1][4:6] == "13":
|
||||||
|
|||||||
@@ -448,14 +448,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -389,14 +389,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -405,14 +405,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -395,14 +395,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ def convert_model(vq_model_id, llm_model_id, output_dir, hub_model_id=None, test
|
|||||||
# Convert and save processor
|
# Convert and save processor
|
||||||
tokenizer_tiktoken = AutoTokenizer.from_pretrained(llm_model_id, trust_remote_code=True)
|
tokenizer_tiktoken = AutoTokenizer.from_pretrained(llm_model_id, trust_remote_code=True)
|
||||||
convert_tiktoken(tokenizer_tiktoken, output_dir)
|
convert_tiktoken(tokenizer_tiktoken, output_dir)
|
||||||
extra_special_tokens = extra_special_tokens = {
|
extra_special_tokens = {
|
||||||
"image_token": "<image>",
|
"image_token": "<image>",
|
||||||
"boi_token": "<|image start|>",
|
"boi_token": "<|image start|>",
|
||||||
"eoi_token": "<|image end|>",
|
"eoi_token": "<|image end|>",
|
||||||
|
|||||||
@@ -254,7 +254,6 @@ class FalconH1Config(PretrainedConfig):
|
|||||||
if ssm_multipliers is not None:
|
if ssm_multipliers is not None:
|
||||||
self.ssm_multipliers = ssm_multipliers
|
self.ssm_multipliers = ssm_multipliers
|
||||||
else:
|
else:
|
||||||
#
|
|
||||||
self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0]
|
self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||||
|
|
||||||
if ssm_in_multiplier is not None:
|
if ssm_in_multiplier is not None:
|
||||||
|
|||||||
@@ -455,14 +455,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -248,14 +248,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -367,11 +367,11 @@ class ImageGPTAttention(nn.Module):
|
|||||||
|
|
||||||
if layer_past is not None and is_updated:
|
if layer_past is not None and is_updated:
|
||||||
# reuse k,v, cross_attentions, and compute only q
|
# reuse k,v, cross_attentions, and compute only q
|
||||||
query = query = self.q_attn(hidden_states)
|
query = self.q_attn(hidden_states)
|
||||||
key = curr_past_key_value.key_cache[self.layer_idx]
|
key = curr_past_key_value.key_cache[self.layer_idx]
|
||||||
value = curr_past_key_value.value_cache[self.layer_idx]
|
value = curr_past_key_value.value_cache[self.layer_idx]
|
||||||
else:
|
else:
|
||||||
query = query = self.q_attn(hidden_states)
|
query = self.q_attn(hidden_states)
|
||||||
key, value = self.c_attn(current_states).split(self.split_size, dim=2)
|
key, value = self.c_attn(current_states).split(self.split_size, dim=2)
|
||||||
key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ class TFLayoutLMEmbeddings(keras.layers.Layer):
|
|||||||
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
||||||
|
|
||||||
if bbox is None:
|
if bbox is None:
|
||||||
bbox = bbox = tf.fill(input_shape + [4], value=0)
|
bbox = tf.fill(input_shape + [4], value=0)
|
||||||
try:
|
try:
|
||||||
left_position_embeddings = tf.gather(self.x_position_embeddings, bbox[:, :, 0])
|
left_position_embeddings = tf.gather(self.x_position_embeddings, bbox[:, :, 0])
|
||||||
upper_position_embeddings = tf.gather(self.y_position_embeddings, bbox[:, :, 1])
|
upper_position_embeddings = tf.gather(self.y_position_embeddings, bbox[:, :, 1])
|
||||||
|
|||||||
@@ -396,14 +396,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1458,14 +1458,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -395,14 +395,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -397,14 +397,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -450,14 +450,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -178,14 +178,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ def convert_clip_backbone(flax_params, torch_config):
|
|||||||
# Copy flax CLIP backbone params to PyTorch params
|
# Copy flax CLIP backbone params to PyTorch params
|
||||||
for name, param in new_torch_params.items():
|
for name, param in new_torch_params.items():
|
||||||
if name in torch_clip_params.keys():
|
if name in torch_clip_params.keys():
|
||||||
new_param = torch.from_numpy(new_torch_params[name])
|
new_param = torch.from_numpy(param)
|
||||||
torch_clip_params[name].copy_(new_param)
|
torch_clip_params[name].copy_(new_param)
|
||||||
else:
|
else:
|
||||||
attn_params[name] = param
|
attn_params[name] = param
|
||||||
|
|||||||
@@ -291,9 +291,6 @@ def convert_paligemma_checkpoint(
|
|||||||
processor.save_pretrained(pytorch_dump_folder_path)
|
processor.save_pretrained(pytorch_dump_folder_path)
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -93,11 +93,11 @@ def rename_and_convert_flax_params(flax_dict):
|
|||||||
|
|
||||||
converted_torch_dict = {}
|
converted_torch_dict = {}
|
||||||
# convert converted_dict into torch format
|
# convert converted_dict into torch format
|
||||||
for key in converted_dict.keys():
|
for key, value in converted_dict.items():
|
||||||
if ("embed_tokens" not in key) and ("embedder" not in key):
|
if ("embed_tokens" not in key) and ("embedder" not in key):
|
||||||
converted_torch_dict[key] = torch.from_numpy(converted_dict[key].T)
|
converted_torch_dict[key] = torch.from_numpy(value.T)
|
||||||
else:
|
else:
|
||||||
converted_torch_dict[key] = torch.from_numpy(converted_dict[key])
|
converted_torch_dict[key] = torch.from_numpy(value)
|
||||||
|
|
||||||
return converted_torch_dict
|
return converted_torch_dict
|
||||||
|
|
||||||
|
|||||||
@@ -174,14 +174,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1005,14 +1005,14 @@ class RoCBertBasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -186,14 +186,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -421,14 +421,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -396,14 +396,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -89,9 +89,9 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, w
|
|||||||
else:
|
else:
|
||||||
all_layers[curr_real_layer_name] = {split_layer[-1]: content}
|
all_layers[curr_real_layer_name] = {split_layer[-1]: content}
|
||||||
|
|
||||||
for key in all_layers.keys():
|
for key, layer in all_layers.items():
|
||||||
# open tensorstore file
|
# open tensorstore file
|
||||||
raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result()
|
raw_weights = ts.open(unflatten_dict(layer)).result().read().result()
|
||||||
raw_weights = torch.tensor(raw_weights)
|
raw_weights = torch.tensor(raw_weights)
|
||||||
weight_size = raw_weights.numel() * raw_weights.element_size()
|
weight_size = raw_weights.numel() * raw_weights.element_size()
|
||||||
|
|
||||||
|
|||||||
@@ -2123,14 +2123,14 @@ class BasicTokenizer:
|
|||||||
# like the all of the other languages.
|
# like the all of the other languages.
|
||||||
if (
|
if (
|
||||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||||
): #
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ def t5x_relpos_bias_lookup(params, i, prefix):
|
|||||||
|
|
||||||
def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
|
def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
|
||||||
"""Returns the KOQV parameters of (self-)attention. Does not transpose."""
|
"""Returns the KOQV parameters of (self-)attention. Does not transpose."""
|
||||||
k_tmp = k_tmp = np.ascontiguousarray(params[f"{prefix}/{prefix}/{layer_name}/key/kernel"][:, i, :, :])
|
k_tmp = np.ascontiguousarray(params[f"{prefix}/{prefix}/{layer_name}/key/kernel"][:, i, :, :])
|
||||||
k = k_tmp.reshape(k_tmp.shape[0], k_tmp.shape[1] * k_tmp.shape[2])
|
k = k_tmp.reshape(k_tmp.shape[0], k_tmp.shape[1] * k_tmp.shape[2])
|
||||||
o_tmp = np.ascontiguousarray(params[f"{prefix}/{prefix}/{layer_name}/out/kernel"][:, i, :, :])
|
o_tmp = np.ascontiguousarray(params[f"{prefix}/{prefix}/{layer_name}/out/kernel"][:, i, :, :])
|
||||||
o = o_tmp.reshape(o_tmp.shape[0] * o_tmp.shape[1], o_tmp.shape[2])
|
o = o_tmp.reshape(o_tmp.shape[0] * o_tmp.shape[1], o_tmp.shape[2])
|
||||||
|
|||||||
@@ -1184,7 +1184,7 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
used_keys = set()
|
used_keys = set()
|
||||||
|
|
||||||
# get defaults from set model processor kwargs if they exist
|
# get defaults from set model processor kwargs if they exist
|
||||||
for modality in default_kwargs:
|
for modality in default_kwargs: # noqa: PLC0206
|
||||||
default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
|
default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
|
||||||
# update defaults with arguments from tokenizer init
|
# update defaults with arguments from tokenizer init
|
||||||
for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
|
for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
|
||||||
@@ -1202,7 +1202,7 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
|
|
||||||
# update modality kwargs with passed kwargs
|
# update modality kwargs with passed kwargs
|
||||||
non_modality_kwargs = set(kwargs) - set(output_kwargs)
|
non_modality_kwargs = set(kwargs) - set(output_kwargs)
|
||||||
for modality in output_kwargs:
|
for modality, output_kwarg in output_kwargs.items():
|
||||||
for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
|
for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
|
||||||
# check if we received a structured kwarg dict or not to handle it correctly
|
# check if we received a structured kwarg dict or not to handle it correctly
|
||||||
if modality in kwargs:
|
if modality in kwargs:
|
||||||
@@ -1220,7 +1220,7 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
else:
|
else:
|
||||||
kwarg_value = "__empty__"
|
kwarg_value = "__empty__"
|
||||||
if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
|
if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
|
||||||
output_kwargs[modality][modality_key] = kwarg_value
|
output_kwarg[modality_key] = kwarg_value
|
||||||
used_keys.add(modality_key)
|
used_keys.add(modality_key)
|
||||||
|
|
||||||
# Determine if kwargs is a flat dictionary or contains nested dictionaries
|
# Determine if kwargs is a flat dictionary or contains nested dictionaries
|
||||||
@@ -1234,18 +1234,18 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
used_keys.add(subkey)
|
used_keys.add(subkey)
|
||||||
else:
|
else:
|
||||||
# kwargs is a flat dictionary
|
# kwargs is a flat dictionary
|
||||||
for key in kwargs:
|
for key, kwarg in kwargs.items():
|
||||||
if key not in used_keys:
|
if key not in used_keys:
|
||||||
if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__.keys():
|
if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__.keys():
|
||||||
output_kwargs["common_kwargs"][key] = kwargs[key]
|
output_kwargs["common_kwargs"][key] = kwarg
|
||||||
elif key not in possible_modality_keywords:
|
elif key not in possible_modality_keywords:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
|
f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
|
||||||
)
|
)
|
||||||
|
|
||||||
# all modality-specific kwargs are updated with common kwargs
|
# all modality-specific kwargs are updated with common kwargs
|
||||||
for modality in output_kwargs:
|
for kwarg in output_kwargs.values():
|
||||||
output_kwargs[modality].update(output_kwargs["common_kwargs"])
|
kwarg.update(output_kwargs["common_kwargs"])
|
||||||
return output_kwargs
|
return output_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -253,8 +253,8 @@ class HqqHfQuantizer(HfQuantizer):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Step 1: populate module with weight/bias from module state dict
|
# Step 1: populate module with weight/bias from module state dict
|
||||||
for key in module_state_dict:
|
for key, tensor in module_state_dict.items():
|
||||||
setattr(module, key, torch.nn.Parameter(module_state_dict[key]))
|
setattr(module, key, torch.nn.Parameter(tensor))
|
||||||
|
|
||||||
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
|
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
|
||||||
# directly doesn't work.
|
# directly doesn't work.
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ def generate_attention_matrix_from_mask(
|
|||||||
colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
|
colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
|
||||||
row_display = " ".join(
|
row_display = " ".join(
|
||||||
f"{YELLOW}{BLACK_SQUARE}{RESET}"
|
f"{YELLOW}{BLACK_SQUARE}{RESET}"
|
||||||
if img_token in words[j] and mask[i, j] and img_token in words[i]
|
if img_token in words[j] and mask[i, j] and img_token in word
|
||||||
else f"{GREEN}{BLACK_SQUARE}{RESET}"
|
else f"{GREEN}{BLACK_SQUARE}{RESET}"
|
||||||
if i == j
|
if i == j
|
||||||
else BLACK_SQUARE
|
else BLACK_SQUARE
|
||||||
@@ -130,9 +130,7 @@ def generate_attention_matrix_from_mask(
|
|||||||
if sliding_window is not None:
|
if sliding_window is not None:
|
||||||
sliding_window_row = " ".join(
|
sliding_window_row = " ".join(
|
||||||
f"{YELLOW}{BLACK_SQUARE}{RESET}"
|
f"{YELLOW}{BLACK_SQUARE}{RESET}"
|
||||||
if img_token in words[j]
|
if img_token in words[j] and img_token in word and token_type_buckets[0, i] == token_type_buckets[0, j]
|
||||||
and img_token in words[i]
|
|
||||||
and token_type_buckets[0, i] == token_type_buckets[0, j]
|
|
||||||
else f"{GREEN}{BLACK_SQUARE}{RESET}"
|
else f"{GREEN}{BLACK_SQUARE}{RESET}"
|
||||||
if i == j
|
if i == j
|
||||||
else BLACK_SQUARE
|
else BLACK_SQUARE
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ class GroupViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(self_attentions[i].shape[-2:]),
|
list(self_attn.shape[-2:]),
|
||||||
[
|
[
|
||||||
self.model_tester.num_output_groups[i],
|
self.model_tester.num_output_groups[i],
|
||||||
self.model_tester.num_output_groups[i - 1] if i > 0 else seq_len,
|
self.model_tester.num_output_groups[i - 1] if i > 0 else seq_len,
|
||||||
|
|||||||
@@ -1375,9 +1375,7 @@ class TokenizerTesterMixin:
|
|||||||
self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)
|
self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)
|
||||||
|
|
||||||
for i, conv in enumerate(conversations):
|
for i, conv in enumerate(conversations):
|
||||||
chat_string = tokenizer_r.apply_chat_template(
|
chat_string = tokenizer_r.apply_chat_template(conv, tokenize=False, chat_template=dummy_template)
|
||||||
conversations[i], tokenize=False, chat_template=dummy_template
|
|
||||||
)
|
|
||||||
assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
|
assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
|
||||||
assistant_end = output.char_to_token(
|
assistant_end = output.char_to_token(
|
||||||
i,
|
i,
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ git bisect run python3 target_script.py
|
|||||||
|
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["bash", "run_git_bisect.sh"],
|
["bash", "run_git_bisect.sh"],
|
||||||
|
check=False,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ def get_runner_status(target_runners, token):
|
|||||||
"https://api.github.com/repos/huggingface/transformers/actions/runners",
|
"https://api.github.com/repos/huggingface/transformers/actions/runners",
|
||||||
]
|
]
|
||||||
|
|
||||||
output = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE)
|
output = subprocess.run(cmd, check=False, shell=True, stdout=subprocess.PIPE)
|
||||||
o = output.stdout.decode("utf-8")
|
o = output.stdout.decode("utf-8")
|
||||||
status = json.loads(o)
|
status = json.loads(o)
|
||||||
|
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ def get_prs_by_label(label):
|
|||||||
"--limit",
|
"--limit",
|
||||||
"100",
|
"100",
|
||||||
]
|
]
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, check=False, capture_output=True, text=True)
|
||||||
result.check_returncode()
|
result.check_returncode()
|
||||||
prs = json.loads(result.stdout)
|
prs = json.loads(result.stdout)
|
||||||
for pr in prs:
|
for pr in prs:
|
||||||
@@ -97,7 +97,9 @@ def get_prs_by_label(label):
|
|||||||
|
|
||||||
def get_commit_timestamp(commit_sha):
|
def get_commit_timestamp(commit_sha):
|
||||||
"""Get UNIX timestamp of a commit using git."""
|
"""Get UNIX timestamp of a commit using git."""
|
||||||
result = subprocess.run(["git", "show", "-s", "--format=%ct", commit_sha], capture_output=True, text=True)
|
result = subprocess.run(
|
||||||
|
["git", "show", "-s", "--format=%ct", commit_sha], check=False, capture_output=True, text=True
|
||||||
|
)
|
||||||
result.check_returncode()
|
result.check_returncode()
|
||||||
return int(result.stdout.strip())
|
return int(result.stdout.strip())
|
||||||
|
|
||||||
@@ -115,6 +117,7 @@ def commit_in_history(commit_sha, base_branch="HEAD"):
|
|||||||
"""Return True if commit is already part of base_branch history."""
|
"""Return True if commit is already part of base_branch history."""
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["git", "merge-base", "--is-ancestor", commit_sha, base_branch],
|
["git", "merge-base", "--is-ancestor", commit_sha, base_branch],
|
||||||
|
check=False,
|
||||||
stdout=subprocess.DEVNULL,
|
stdout=subprocess.DEVNULL,
|
||||||
stderr=subprocess.DEVNULL,
|
stderr=subprocess.DEVNULL,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user