Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e752ead46 | ||
|
|
785b5cf444 | ||
|
|
3b09464364 | ||
|
|
b00807fac2 | ||
|
|
612bfd0801 |
@@ -506,8 +506,6 @@
|
||||
title: MobileBERT
|
||||
- local: model_doc/modernbert
|
||||
title: ModernBert
|
||||
- local: model_doc/moonshine
|
||||
title: moonshine
|
||||
- local: model_doc/mpnet
|
||||
title: MPNet
|
||||
- local: model_doc/mpt
|
||||
@@ -770,6 +768,8 @@
|
||||
title: Mimi
|
||||
- local: model_doc/mms
|
||||
title: MMS
|
||||
- local: model_doc/moonshine
|
||||
title: Moonshine
|
||||
- local: model_doc/moshi
|
||||
title: Moshi
|
||||
- local: model_doc/musicgen
|
||||
|
||||
2
setup.py
2
setup.py
@@ -437,7 +437,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.48.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.48.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.48.0"
|
||||
__version__ = "4.48.1"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -1,448 +0,0 @@
|
||||
# Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
Emu3Config,
|
||||
Emu3ForConditionalGeneration,
|
||||
Emu3ImageProcessor,
|
||||
Emu3Processor,
|
||||
Emu3TextConfig,
|
||||
GenerationConfig,
|
||||
)
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
|
||||
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
```
|
||||
python src/transformers/models/emu3/convert_emu3_weights_to_hf.py \
|
||||
--vq_model_id BAAI/Emu3-VisionTokenizer --llm_model_id BAAI/Emu3-Chat --output_dir /output/path
|
||||
```
|
||||
|
||||
Thereafter, models can be loaded via:
|
||||
|
||||
```py
|
||||
from transformers import Emu3ForConditionalGeneration, Emu3Processor
|
||||
|
||||
model = Emu3ForConditionalGeneration.from_pretrained("/output/path")
|
||||
processor = Emu3Processor.from_pretrained("/output/path")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# Byte-value -> printable-character table from the GPT-2 tokenizer module; used by
# `token_bytes_to_string` to render raw token bytes as strings.
byte_encoder = bytes_to_unicode()
# Chat template handed to `Emu3Processor` in `convert_model`. As written, it prefixes every
# non-system message with the upper-cased role, emits one `<image>` per image content item
# before any text items, wraps assistant text in `generation` blocks, and appends
# `ASSISTANT:` when `add_generation_prompt` is set.
CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
|
||||
|
||||
|
||||
# Tiktoken to HF conversion, thanks to Xenova
|
||||
def token_bytes_to_string(b):
    """Render raw token bytes as a string using the module-level `byte_encoder` table.

    Every byte value of `b` is looked up in `byte_encoder` and the resulting
    characters are concatenated in order.
    """
    return "".join(byte_encoder[byte_value] for byte_value in b)
|
||||
|
||||
|
||||
# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960
|
||||
def bpe(mergeable_ranks: Dict[bytes, int], token: bytes, max_rank: Optional[int] = None):
    """Split `token` into the parts obtained by greedily applying ranked merges.

    Starting from single bytes, the adjacent pair with the lowest rank in
    `mergeable_ranks` is merged repeatedly (the left-most pair wins ties).

    Args:
        mergeable_ranks: Mapping from byte sequences to their merge rank.
        token: The byte string to decompose.
        max_rank: If given, merging stops as soon as the best available rank is
            greater than or equal to this value.

    Returns:
        The list of byte strings left once merging stops.
    """
    pieces = [bytes((value,)) for value in token]
    while True:
        best_pos, best_rank = None, None
        for pos in range(len(pieces) - 1):
            candidate = mergeable_ranks.get(pieces[pos] + pieces[pos + 1])
            if candidate is None:
                continue
            if best_rank is None or candidate < best_rank:
                best_pos, best_rank = pos, candidate

        # Stop when nothing is mergeable or the best merge is not allowed yet.
        if best_rank is None:
            break
        if max_rank is not None and best_rank >= max_rank:
            break

        assert best_pos is not None
        fused = pieces[best_pos] + pieces[best_pos + 1]
        pieces = pieces[:best_pos] + [fused] + pieces[best_pos + 2 :]
    return pieces
|
||||
|
||||
|
||||
def generate_vocab_and_merges(encoder):
    """Build a vocabulary dict and a merges list from a tiktoken-style encoder.

    Args:
        encoder: Object exposing `_mergeable_ranks` (bytes -> rank) and
            `_special_tokens` (token string -> id).

    Returns:
        A `(vocab, merges)` tuple. `vocab` maps the string form of every
        mergeable token to its rank and is then updated with the special
        tokens. `merges` holds one "left right" string per multi-byte token.
    """
    ranks = encoder._mergeable_ranks

    vocab, merges = {}, []
    for token_bytes, token_rank in ranks.items():
        vocab[token_bytes_to_string(token_bytes)] = token_rank

        if len(token_bytes) != 1:
            # Only merges ranked below the token itself are applied, which leaves
            # the two halves that were joined to create this token.
            halves = tuple(bpe(ranks, token_bytes, max_rank=token_rank))
            assert len(halves) == 2
            merges.append(" ".join(token_bytes_to_string(half) for half in halves))

    # Also add special tokens
    vocab.update(encoder._special_tokens)
    return vocab, merges
|
||||
|
||||
|
||||
def convert_tiktoken(tokenizer, output_dir):
    """Write Hugging Face tokenizer files for a tiktoken-backed tokenizer.

    The vocabulary and merges are derived from `tokenizer.tokenizer`, and the
    reserved token `<|extra_0|>` is repurposed as the `<image>` placeholder.
    The following files are written into `output_dir` (created if missing):
    `vocab.json`, `tokenizer.json`, `tokenizer_config.json`,
    `special_tokens_map.json` and `merges.txt`.

    Args:
        tokenizer: Object whose `.tokenizer` attribute exposes `_mergeable_ranks`
            and `_special_tokens`.
        output_dir: Directory the tokenizer files are written to.
    """
    encoder = tokenizer.tokenizer
    vocab, merges = generate_vocab_and_merges(encoder)
    # One added-token entry per special token; `<|extra_0|>` is skipped here because it is
    # re-added below under the name `<image>`.
    # NOTE(review): the loop variable `id` shadows the builtin inside this comprehension.
    added_tokens = [
        {
            "id": id,
            "content": content,
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": True,
        }
        for content, id in encoder._special_tokens.items()
        if content != "<|extra_0|>"
    ]

    # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer_config.json
    tokenizer_config_template = {
        "add_prefix_space": False,
        "bos_token": "<|extra_203|>",
        "clean_up_tokenization_spaces": False,
        "eos_token": "<|extra_204|>",
        "pad_token": "<|endoftext|>",
    }
    tokenizer_config_template.update({"tokenizer_class": "GPT2Tokenizer"})
    # Sort by key so the written config has a stable key order.
    tokenizer_config_template = dict(sorted(tokenizer_config_template.items(), key=lambda x: x[0]))

    # add placeholder image token by taking one of the reserved tokens
    reserved_token_id = vocab["<|extra_0|>"]
    vocab["<image>"] = reserved_token_id
    del vocab["<|extra_0|>"]
    added_tokens.append(
        {
            "id": reserved_token_id,
            "content": "<image>",
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": True,
        }
    )

    os.makedirs(output_dir, exist_ok=True)

    pre_tokenizer = {
        "type": "ByteLevel",
        "add_prefix_space": False,
        "trim_offsets": True,
        "use_regex": True,
    }

    # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer.json
    tokenizer_template = {
        "version": "1.0",
        "truncation": None,
        "padding": None,
        "added_tokens": added_tokens,
        "normalizer": None,
        "pre_tokenizer": pre_tokenizer,
        "post_processor": None,
        "decoder": {
            "type": "ByteLevel",
            "add_prefix_space": True,
            "trim_offsets": True,
            "use_regex": True,
        },
        "model": {
            "type": "BPE",
            "dropout": None,
            "unk_token": None,
            "continuing_subword_prefix": "",
            "end_of_word_suffix": "",
            "fuse_unk": False,
            "byte_fallback": False,
            "vocab": vocab,
            "merges": merges,
        },
    }

    # Save to files
    # All files are written as UTF-8 with `ensure_ascii=False`, so non-ASCII vocabulary
    # entries are stored verbatim rather than as \u escapes.
    with open(os.path.join(output_dir, "vocab.json"), "w", encoding="utf-8") as fp:
        json.dump(vocab, fp, indent=2, ensure_ascii=False)

    with open(os.path.join(output_dir, "tokenizer.json"), "w", encoding="utf-8") as fp:
        json.dump(tokenizer_template, fp, indent=2, ensure_ascii=False)

    with open(os.path.join(output_dir, "tokenizer_config.json"), "w", encoding="utf-8") as fp:
        json.dump(tokenizer_config_template, fp, indent=2, ensure_ascii=False)

    with open(os.path.join(output_dir, "special_tokens_map.json"), "w", encoding="utf-8") as fp:
        json.dump(
            {
                "bos_token": "<|extra_203|>",
                "eos_token": "<|extra_204|>",
                "pad_token": "<|endoftext|>",
            },
            fp,
            indent=2,
            ensure_ascii=False,
        )

    with open(os.path.join(output_dir, "merges.txt"), "w", encoding="utf-8") as fp:
        fp.write("#version: 0.2\n")
        fp.write("\n".join(merges))
|
||||
|
||||
|
||||
# Regex pattern -> replacement rules used to rename checkpoint keys.
# `convert_state_dict_to_hf` applies every rule with `re.sub`, in dict order, to the same
# key, so later rules see the output of earlier ones. For example a key starting with
# `encoder` first becomes `model.vqmodel.encoder...`, is then prefixed by the `^model` rule
# to `text_model.model.vqmodel.encoder...`, and is finally shortened by the
# `^text_model\.model\.vqmodel` rule to `vqmodel.encoder...`. Do not reorder the entries.
KEYS_TO_MODIFY_MAPPING = {
    "^encoder": "model.vqmodel.encoder",
    "^decoder": "model.vqmodel.decoder",
    "^post_quant_conv": "model.vqmodel.post_quant_conv",
    "^quant_conv": "model.vqmodel.quant_conv",
    "^quantize": "model.vqmodel.quantize",
    "^model": "text_model.model",
    r"lm_head\.weight": "text_model.lm_head.weight",
    r"^text_model\.model\.vqmodel": "vqmodel",
    # rename QKV proj for the VQ-VAE model because we use SiglipAttention
    r"\.q\.": ".q_proj.",
    r"\.k\.": ".k_proj.",
    r"\.v\.": ".v_proj.",
    r"\.proj_out\.": ".out_proj.",
    # move the attention norms outside of attention modules
    r"mid\.attn_1\.norm\.": "mid.attn_norm.",
    r"attn\.0\.norm\.": "attn_norms.0.",
    r"attn\.1\.norm\.": "attn_norms.1.",
    r"attn\.2\.norm\.": "attn_norms.2.",
    r"attn\.3\.norm\.": "attn_norms.3.",
    # isolate down/mid/up into separate classes for readability
    r"\.down\.": ".down_block.down.",
    r"\.up\.": ".up_block.up.",
    r"\.mid\.": ".middle_block.",
}
|
||||
|
||||
|
||||
def convert_state_dict_to_hf(old_state_dict, new_state_dict):
    """Copy `old_state_dict` into `new_state_dict` under renamed keys.

    Every key is rewritten by applying all `KEYS_TO_MODIFY_MAPPING` rules in
    order. 4-D attention projection weights (`q`, `k`, `v`, `proj_out`) are
    squeezed first.

    Args:
        old_state_dict: Mapping of original parameter names to tensors.
        new_state_dict: Dict that receives the renamed entries (mutated in place).

    Returns:
        `new_state_dict`, for convenience.
    """
    attn_projection_suffixes = ("q.weight", "k.weight", "v.weight", "proj_out.weight")

    for old_key, tensor in old_state_dict.items():
        # convert conv layers in attn to linear
        if old_key.endswith(attn_projection_suffixes) and tensor.ndim == 4:
            tensor = tensor.squeeze()

        renamed_key = old_key
        for pattern, replacement in KEYS_TO_MODIFY_MAPPING.items():
            renamed_key = re.sub(pattern, replacement, renamed_key)

        new_state_dict[renamed_key] = tensor
    return new_state_dict
|
||||
|
||||
|
||||
def convert_model(vq_model_id, llm_model_id, output_dir, hub_model_id=None, test_inference=False):
    """Convert the original Emu3 checkpoints into a Hugging Face Emu3 model and processor.

    Steps, in order:
      1. Convert the tiktoken tokenizer of `llm_model_id`, then build and save the processor.
      2. Load the original LLM and VQ model, rename their weights and load them into an
         `Emu3ForConditionalGeneration` created on empty weights.
      3. Save the model to `output_dir`, optionally pushing model and processor to the hub.
      4. Optionally run a short generation as a smoke test. Only model ids ending in
         "Chat" or "Gen" are exercised.

    Args:
        vq_model_id: Hub id of the VQ model / image processor checkpoint.
        llm_model_id: Hub id of the language model checkpoint.
        output_dir: Directory that receives the converted tokenizer, processor and model.
        hub_model_id: If not None, repository the model and processor are pushed to.
        test_inference: Whether to run the generation smoke test after conversion.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Convert and save processor
    tokenizer_tiktoken = AutoTokenizer.from_pretrained(llm_model_id, trust_remote_code=True)
    convert_tiktoken(tokenizer_tiktoken, output_dir)
    extra_special_tokens = {
        "image_token": "<image>",
        "boi_token": "<|image start|>",
        "eoi_token": "<|image end|>",
        "image_wrapper_token": "<|image token|>",
        "eof_token": "<|extra_201|>",
    }
    tokenizer_converted = AutoTokenizer.from_pretrained(output_dir, extra_special_tokens=extra_special_tokens)
    tokenizer_converted.padding_side = "left"

    image_processor = Emu3ImageProcessor.from_pretrained(vq_model_id)
    processor = Emu3Processor(image_processor, tokenizer_converted, chat_template=CHAT_TEMPLATE)
    processor.save_pretrained(output_dir)

    # load models
    model_llm = AutoModelForCausalLM.from_pretrained(
        llm_model_id,
        trust_remote_code=True,
    )
    model_vqgan = AutoModel.from_pretrained(vq_model_id, trust_remote_code=True)
    # `convert_tiktoken` wrote this file as UTF-8 with `ensure_ascii=False`; read it back with
    # the same encoding instead of the platform default.
    with open(os.path.join(output_dir, "tokenizer.json"), "r", encoding="utf-8") as file:
        tokenizer_config = json.load(file)
    vocabulary_map = tokenizer_config["model"]["vocab"]

    text_config = Emu3TextConfig(
        max_position_embeddings=model_llm.config.max_position_embeddings,
        rope_scaling={"rope_type": "default"},
    )
    config = Emu3Config(text_config=text_config, vocabulary_map=vocabulary_map)

    with init_empty_weights():
        model = Emu3ForConditionalGeneration(config=config)
        model.generation_config = GenerationConfig(
            do_sample=True,
            top_k=2048,
            max_new_tokens=50_000,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )

    # Both source checkpoints are merged into a single renamed state dict.
    state_dict = {}
    state_dict = convert_state_dict_to_hf(model_llm.state_dict(), state_dict)
    state_dict = convert_state_dict_to_hf(model_vqgan.state_dict(), state_dict)

    # `assign=True` replaces the empty-weight parameters with the loaded tensors.
    model.load_state_dict(state_dict, assign=True, strict=True)
    model.save_pretrained(output_dir, safe_serialization=True)

    if hub_model_id is not None:
        model.push_to_hub(hub_model_id)
        processor.push_to_hub(hub_model_id)

    if test_inference and llm_model_id.endswith("Chat"):
        # Short inference on a few examples to check if generation makes sense
        print("Loading the checkpoint in a Emu3 model...")
        print("*" * 100)
        model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
        processor = Emu3Processor.from_pretrained(output_dir)

        conversation = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are a helpful assistant."},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Please tell me about this art work and its artist."},
                    {"type": "image"},
                ],
            },
        ]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

        image = Image.open(
            requests.get(
                "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
            ).raw
        )
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)
        length = inputs.input_ids.shape[1]

        out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
        # Decode only the newly generated tokens, not the prompt.
        generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]

        print(f"Generation for single-image: {generated_text}")
        print("*" * 100)
    elif test_inference and llm_model_id.endswith("Gen"):
        processor = Emu3Processor.from_pretrained(output_dir)
        model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")

        inputs = processor(
            text=[
                "a portrait of young girl. masterpiece, film grained, best quality.",
                "a dog running under the rain",
            ],
            padding=True,
            return_tensors="pt",
            return_for_image_generation=True,
        )
        # Follow the model's device (as the "Chat" branch does) rather than assuming "cuda:0".
        inputs = inputs.to(device=model.device, dtype=torch.bfloat16)

        neg_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
        neg_inputs = processor(text=[neg_prompt] * 2, return_tensors="pt").to(device=model.device)

        image_sizes = inputs.pop("image_sizes")
        HEIGHT, WIDTH = image_sizes[0]
        VISUAL_TOKENS = model.vocabulary_mapping.image_tokens

        def prefix_allowed_tokens_fn(batch_id, input_ids):
            # Restricts the next token based on how far generation is past the first
            # `<|image token|>` in `input_ids`.
            height, width = HEIGHT, WIDTH
            visual_tokens = VISUAL_TOKENS
            image_token_id = processor.tokenizer.encode("<|image token|>", return_tensors="pt")[0].to(model.device)
            eoi_token_id = processor.tokenizer.encode("<|image end|>", return_tensors="pt")[0]
            eos_token_id = processor.tokenizer.encode("<|extra_204|>", return_tensors="pt")[0]
            pad_token_id = processor.tokenizer.encode("<|endoftext|>", return_tensors="pt")[0]
            eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0]
            eof_token_id = processor.tokenizer.encode("<|extra_201|>", return_tensors="pt")[0]

            position = torch.nonzero(input_ids == image_token_id, as_tuple=True)[0][0]
            offset = input_ids.shape[0] - position
            if offset % (width + 1) == 0:
                return (eol_token_id,)
            elif offset == (width + 1) * height + 1:
                return (eof_token_id,)
            elif offset == (width + 1) * height + 2:
                return (eoi_token_id,)
            elif offset == (width + 1) * height + 3:
                return (eos_token_id,)
            elif offset > (width + 1) * height + 3:
                return (pad_token_id,)
            else:
                return visual_tokens

        out = model.generate(
            **inputs,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            negative_prompt_ids=neg_inputs.input_ids,
            negative_prompt_attention_mask=neg_inputs.attention_mask,
        )

        image = model.decode_image_tokens(out[:, inputs.input_ids.shape[1] :], height=HEIGHT, width=WIDTH)
        images = processor.postprocess(
            list(image.float()), return_tensors="PIL.Image.Image"
        )  # internally we convert to np but it's not supported in bf16 precision
        for i, image in enumerate(images["pixel_values"]):
            image.save(f"result_{i}.png")
|
||||
|
||||
|
||||
def main():
    """Parse the command line and run the Emu3 checkpoint conversion.

    `--output_dir` is mandatory: `convert_model` passes it straight to `os.makedirs`,
    so a missing value is rejected by argparse up front instead of failing later
    with a `TypeError`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vq_model_id",
        help="Model ID of Emu3 VQ-VAE on the hub",
        default="BAAI/Emu3-VisionTokenizer",
    )
    parser.add_argument(
        "--llm_model_id",
        help="Model ID of Emu3 backbone LLM on the hub",
        default="BAAI/Emu3-Chat",
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Location to write HF model",
    )
    parser.add_argument(
        "--hub_model_id",
        help="Model ID in the hub where to push the model.",
    )
    parser.add_argument(
        "--test_inference",
        action="store_true",
        help="Whether to load the model for generation to test it's converted correctly.",
    )
    args = parser.parse_args()
    convert_model(
        vq_model_id=args.vq_model_id,
        llm_model_id=args.llm_model_id,
        output_dir=args.output_dir,
        hub_model_id=args.hub_model_id,
        test_inference=args.test_inference,
    )


if __name__ == "__main__":
    main()
|
||||
@@ -1,169 +0,0 @@
|
||||
# Copyright 2025 Useful Sensors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers.models.moonshine.modeling_moonshine import MoonshineConfig, MoonshineForConditionalGeneration
|
||||
|
||||
|
||||
# Copied from https://github.com/usefulsensors/moonshine/blob/a1d77cc573b0471ac4602b86f67b3f48d67df1a9/moonshine/model.py
|
||||
def _get_weights(model_name):
    """Return a lazy iterator over the three weight files of a Moonshine model.

    The files `preprocessor.weights.h5`, `encoder.weights.h5` and
    `decoder.weights.h5` are fetched, in that order, from the `model_name`
    subfolder of the `UsefulSensors/moonshine` repository. Each download only
    happens when the iterator is advanced.
    """
    hub_repo = "UsefulSensors/moonshine"
    components = ("preprocessor", "encoder", "decoder")

    return (
        hf_hub_download(hub_repo, f"{component}.weights.h5", subfolder=model_name) for component in components
    )
|
||||
|
||||
|
||||
def _read_h5_weights(group, current_key="", weights=None):
    """Recursively flatten an h5py group into a dict of torch tensors.

    Nested group names are joined with "." to form the keys. Arrays with more
    than one dimension have their first and last axes swapped. 3-D arrays are
    first reshaped to a square matrix whose side is their largest dimension;
    when that reshape fails the array keeps its shape.

    Args:
        group: h5py file or group to walk.
        current_key: Dotted prefix of `group` within the file.
        weights: Dict to fill. A fresh dict is created when omitted, so separate
            calls never share state.

    Returns:
        The `weights` dict containing every dataset found under `group`.
    """
    if weights is None:
        weights = {}

    for key in group.keys():
        full_key = f"{current_key}.{key}" if current_key else key
        if isinstance(group[key], h5py.Dataset):
            w = np.array(group[key])
            w = torch.from_numpy(w)
            if len(w.shape) > 1:
                if len(w.shape) == 3:
                    hidden_size = max(w.shape)
                    try:
                        w = w.reshape(hidden_size, hidden_size)
                    except RuntimeError:
                        # The element count does not fit a square matrix, meaning
                        # this is a conv layer: keep its shape.
                        pass
                w = w.transpose(0, -1)
            weights[full_key] = w
        else:
            _read_h5_weights(group[key], full_key, weights)
    return weights
|
||||
|
||||
|
||||
def _convert_layer_names(name, gated_mlp=False):
    """Translate an original Moonshine weight name into its Hugging Face name.

    The first `layers.functional[_N].layers` occurrence becomes `layers.N`
    (`layers.0` when no index is present). After that a fixed, ordered list of
    regex substitutions renames the MLP, convolution, attention, layer-norm,
    variable and embedding parts.

    Args:
        name: Dotted weight name as read from the h5 file.
        gated_mlp: Selects which pair of MLP dense-layer patterns is rewritten
            to `mlp.fc1.` / `mlp.fc2.`.

    Returns:
        The converted name.
    """
    name = re.sub(
        r"layers\.functional(?:_(\d+))?\.layers",
        lambda match: f'layers.{match.group(1) or "0"}',
        name,
        count=1,
    )

    if gated_mlp:
        mlp_rules = (
            (r"functional\.layers\.dense\.", "mlp.fc1."),
            (r"functional\.layers\.dense_1\.", "mlp.fc2."),
        )
    else:
        mlp_rules = (
            (r"functional\.layers\.sequential\.layers\.dense\.", "mlp.fc1."),
            (r"functional\.layers\.sequential\.layers\.dense_1\.", "mlp.fc2."),
        )

    # Order matters: every rule is applied to the output of the previous one.
    common_rules = (
        (r"layers\.sequential\.layers\.conv1d\.", "conv1."),
        (r"layers\.sequential\.layers\.conv1d_1\.", "conv2."),
        (r"layers\.sequential\.layers\.conv1d_2\.", "conv3."),
        (r"layers\.sequential\.layers\.group_normalization\.", "groupnorm."),
        (r"mha_with_rope\.key_dense", "self_attn.k_proj"),
        (r"mha_with_rope\.query_dense", "self_attn.q_proj"),
        (r"mha_with_rope\.value_dense", "self_attn.v_proj"),
        (r"mha_with_rope\.output_dense", "self_attn.o_proj"),
        (r"mha_precomputed_kv\.key_dense", "encoder_attn.k_proj"),
        (r"mha_precomputed_kv\.query_dense", "encoder_attn.q_proj"),
        (r"mha_precomputed_kv\.value_dense", "encoder_attn.v_proj"),
        (r"mha_precomputed_kv\.output_dense", "encoder_attn.o_proj"),
        (r"mha_causal_with_rope\.key_dense", "self_attn.k_proj"),
        (r"mha_causal_with_rope\.query_dense", "self_attn.q_proj"),
        (r"mha_causal_with_rope\.value_dense", "self_attn.v_proj"),
        (r"mha_causal_with_rope\.output_dense", "self_attn.o_proj"),
        (r"layer_normalization\.", "input_layernorm."),
        (r"layer_normalization_1\.", "post_attention_layernorm."),
        (r"layer_normalization_2\.", "final_layernorm."),
        (r"vars\.0", "weight"),
        (r"vars\.1", "bias"),
        (r"layers\.reversible_embedding", "embed_tokens"),
    )

    for pattern, replacement in mlp_rules + common_rules:
        name = re.sub(pattern, replacement, name)

    return name
|
||||
|
||||
|
||||
def _convert_weights(weights, encoder=True):
    """Rename the weights of one Moonshine sub-model to Hugging Face names.

    `weights` is modified in place: the rotary-embedding entry is dropped when
    present and the top-level layer-norm entry is moved to `layer_norm.weight`
    (encoder) or `norm.weight` (decoder). Every remaining entry is renamed with
    `_convert_layer_names`, using the gated-MLP patterns for the decoder.

    Args:
        weights: Dict of original names to tensors for one sub-model.
        encoder: Whether `weights` belongs to the encoder.

    Returns:
        A new dict holding the renamed weights.
    """
    weights.pop("layers.rotary_embedding.vars.0", None)

    top_norm_name = "layer_norm.weight" if encoder else "norm.weight"
    renamed = {top_norm_name: weights.pop("layers.layer_normalization.vars.0")}

    uses_gated_mlp = not encoder
    for original_name, tensor in weights.items():
        renamed[_convert_layer_names(original_name, gated_mlp=uses_gated_mlp)] = tensor

    return renamed
|
||||
|
||||
|
||||
def convert_usefulsensors_moonshine_to_hf(model_name, pytorch_dump_folder_path):
    """Convert an original Moonshine checkpoint and save it in Hugging Face format.

    The preprocessor, encoder and decoder h5 files are read and renamed, then
    assembled into the state dict of a `MoonshineForConditionalGeneration`.
    Preprocessor weights are merged into the encoder, and the transposed decoder
    embedding matrix is also used as `proj_out.weight`.

    Args:
        model_name: Model variant, `"tiny"` or `"base"`.
        pytorch_dump_folder_path: Directory passed to `save_pretrained`.

    Raises:
        ValueError: If `model_name` is neither `"tiny"` nor `"base"`.
    """
    preprocessor_path, encoder_path, decoder_path = _get_weights(model_name)

    # Read the three files in the same order they were resolved.
    loaded = []
    for weights_path in (preprocessor_path, encoder_path, decoder_path):
        with h5py.File(weights_path, "r") as h5_file:
            loaded.append(_read_h5_weights(h5_file, weights={}))
    preprocessor_weights, encoder_weights, decoder_weights = loaded

    # Preprocessor entries are merged last so they win over encoder entries on key clashes.
    merged_encoder_weights = dict(encoder_weights)
    merged_encoder_weights.update(preprocessor_weights)
    encoder_converted = _convert_weights(merged_encoder_weights)

    decoder_converted = _convert_weights(decoder_weights, encoder=False)
    embedding_matrix = decoder_converted["embed_tokens.weight"].T
    decoder_converted["embed_tokens.weight"] = embedding_matrix

    final_weights = {f"model.encoder.{name}": tensor for name, tensor in encoder_converted.items()}
    final_weights.update({f"model.decoder.{name}": tensor for name, tensor in decoder_converted.items()})

    if model_name == "tiny":
        config_overrides = {}
    elif model_name == "base":
        config_overrides = {
            "hidden_size": 416,
            "intermediate_size": 1664,
            "encoder_num_hidden_layers": 8,
            "decoder_num_hidden_layers": 8,
            "encoder_num_attention_heads": 8,
            "decoder_num_attention_heads": 8,
            "partial_rotary_factor": 0.62,
        }
    else:
        raise ValueError(f"Unknown model name {model_name}")
    config = MoonshineConfig(**config_overrides)

    final_weights["proj_out.weight"] = embedding_matrix

    model = MoonshineForConditionalGeneration(config)
    model.load_state_dict(final_weights)
    model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    # `--model_name` selects the hub subfolder to download and the matching config, so only
    # the variants known to `convert_usefulsensors_moonshine_to_hf` are accepted.
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        choices=("tiny", "base"),
        help="Name of the Moonshine model variant to convert.",
    )
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    args = parser.parse_args()

    convert_usefulsensors_moonshine_to_hf(args.model_name, args.pytorch_dump_folder_path)
|
||||
@@ -1166,10 +1166,9 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
|
||||
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
|
||||
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
|
||||
and conversion into a tensor of type `torch.FloatTensor`.
|
||||
encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
||||
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
||||
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
|
||||
but it is not used.
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
@@ -1178,9 +1177,6 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
|
||||
but it is not used.
|
||||
decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
@@ -1201,11 +1197,10 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
||||
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
||||
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||
@@ -1228,6 +1223,11 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
@@ -1549,22 +1549,5 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
def generate(self, *args, **kwargs):
    """Generate with a default length cap derived from the audio input length.

    When neither `max_new_tokens` nor `max_length` is given, `max_length` is set
    to the longest input length (the largest `attention_mask` row sum, or the
    last dimension of `input_values` when no mask is passed) multiplied by
    6.5 / 16000. All arguments are then forwarded to the parent `generate`.
    """
    # TODO: @eustlb do it rather with a custom logits processor
    # NOTE(review): 16000 is presumably the input sampling rate in Hz — confirm.
    token_limit_factor = 6.5 / 16000.0  # Maximum of 6.5 tokens per second
    if kwargs.get("max_new_tokens") is None and kwargs.get("max_length") is None:
        if kwargs.get("attention_mask") is not None:
            # One length per batch element: keep the longest.
            longest_input = kwargs["attention_mask"].sum(dim=-1).max().item()
        else:
            # `shape[-1]` is already a plain int, so there is nothing to reduce.
            longest_input = kwargs["input_values"].shape[-1]
        max_length = int(longest_input * token_limit_factor)
        logger.warning_once(
            f"Based on the input length, Moonshine will generate up to {max_length} tokens (ratio of 6.5 tokens/second). "
            "To specify a different length, set either `max_new_tokens` or `max_length`."
        )
        kwargs["max_length"] = max_length

    return super().generate(*args, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["MoonshineModel", "MoonshinePreTrainedModel", "MoonshineForConditionalGeneration"]
|
||||
|
||||
@@ -816,10 +816,9 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
|
||||
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
|
||||
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
|
||||
and conversion into a tensor of type `torch.FloatTensor`.
|
||||
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
|
||||
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
||||
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
||||
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
attention_mask (`torch.Tensor`)`, *optional*):
|
||||
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
|
||||
but it is not used.
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
@@ -828,9 +827,6 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor`)`, *optional*):
|
||||
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
|
||||
but it is not used.
|
||||
decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
@@ -851,11 +847,10 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
|
||||
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
||||
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
||||
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||
@@ -878,6 +873,11 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
@@ -1109,23 +1109,6 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
# TODO: @eustlb do it rather with a custom logits processor
|
||||
token_limit_factor = 6.5 / 16000.0 # Maximum of 6.5 tokens per second
|
||||
if kwargs.get("max_new_tokens") is None and kwargs.get("max_length") is None:
|
||||
if kwargs.get("attention_mask") is not None:
|
||||
seq_lens = kwargs["attention_mask"].sum(dim=-1)
|
||||
else:
|
||||
seq_lens = kwargs["input_values"].shape[-1]
|
||||
max_length = int(seq_lens.max().item() * token_limit_factor)
|
||||
logger.warning_once(
|
||||
f"Based on the input length, Moonshine will generate up to {max_length} tokens (ratio of 6.5 tokens/second). "
|
||||
"To specify a different length, set either `max_new_tokens` or `max_length`."
|
||||
)
|
||||
kwargs["max_length"] = max_length
|
||||
|
||||
return super().generate(*args, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MoonshineConfig",
|
||||
|
||||
@@ -724,7 +724,7 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
|
||||
super().__init__(config)
|
||||
self.model = PhiModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@@ -284,7 +284,9 @@ class PhiModel(LlamaModel):
|
||||
|
||||
|
||||
class PhiForCausalLM(LlamaForCausalLM):
|
||||
pass
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
|
||||
class PhiForSequenceClassification(LlamaForSequenceClassification):
|
||||
|
||||
@@ -3672,10 +3672,7 @@ class Trainer:
|
||||
return loss_mb.reduce_mean().detach().to(self.args.device)
|
||||
|
||||
with self.compute_loss_context_manager():
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss = self.compute_loss(model, inputs)
|
||||
else:
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
|
||||
del inputs
|
||||
if (
|
||||
@@ -3709,7 +3706,7 @@ class Trainer:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
# Finally we need to normalize the loss for reporting
|
||||
if num_items_in_batch is None:
|
||||
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
|
||||
loss = loss / self.args.gradient_accumulation_steps
|
||||
|
||||
self.accelerator.backward(loss, **kwargs)
|
||||
@@ -5157,10 +5154,6 @@ class Trainer:
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
# Keep default behavior the same
|
||||
if not self.model_accepts_loss_kwargs:
|
||||
return batch_samples, None
|
||||
|
||||
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
|
||||
# For now we don't support object detection
|
||||
try:
|
||||
|
||||
@@ -484,9 +484,9 @@ class MoonshineModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor([
|
||||
-9.1107, 4.5538, 6.3902, -6.8141, -7.2459, -7.9077, -7.2842, -7.6045, -8.0387, -7.8354,
|
||||
-7.3870, -7.2453, -7.7423, -7.3914, -7.3869, -7.6982, -7.6422, -7.0507, -7.3982, -7.2486,
|
||||
-8.0799, -7.3303, -7.3675, -6.8769, -7.6879, -7.2684, -6.9868, -6.7459, -7.6858, -7.3052,
|
||||
-9.1106, 4.5542, 6.3892, -6.8139, -7.2456, -7.9074, -7.2839, -7.6043, -8.0384, -7.8351,
|
||||
-7.3867, -7.2450, -7.7420, -7.3912, -7.3866, -7.6979, -7.6420, -7.0504, -7.3979, -7.2483,
|
||||
-8.0796, -7.3300, -7.3672, -6.8765, -7.6876, -7.2682, -6.9866, -6.7457, -7.6855, -7.3050,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(outputs.logits[0][0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
@@ -502,9 +502,9 @@ class MoonshineModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor([
|
||||
-6.7340, 1.9483, 5.2449, -8.0277, -7.9167, -7.8956, -7.9649, -7.9348, -8.1312, -8.0616,
|
||||
-8.1070, -7.7696, -7.8809, -7.9451, -8.1013, -7.8177, -7.8598, -7.8257, -7.8729, -7.9657,
|
||||
-7.9310, -8.1024, -7.8698, -7.8231, -8.0752, -7.9764, -7.8127, -8.0536, -7.9492, -7.9289,
|
||||
-6.7336, 1.9482, 5.2448, -8.0277, -7.9167, -7.8956, -7.9649, -7.9348, -8.1312, -8.0616,
|
||||
-8.1070, -7.7696, -7.8809, -7.9450, -8.1013, -7.8177, -7.8598, -7.8257, -7.8729, -7.9657,
|
||||
-7.9310, -8.1024, -7.8699, -7.8231, -8.0752, -7.9764, -7.8127, -8.0536, -7.9492, -7.9290,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(outputs.logits[0][0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
@@ -519,10 +519,10 @@ class MoonshineModelIntegrationTests(unittest.TestCase):
|
||||
outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True)
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor([
|
||||
[-8.0098, 5.0239, 4.5986, -6.8125, -7.1676, -7.8782, -7.2152, -7.5188, -7.9078, -7.7394],
|
||||
[-4.4394, -1.4429, 6.6715, -6.8927, -7.3748, -7.0967, -6.5255, -7.0255, -7.2583, -7.0007],
|
||||
[-10.0088, 3.2862, 0.7342, -6.5558, -6.8514, -6.5309, -6.4173, -6.9485, -6.6215, -6.6230],
|
||||
[-10.8083, 4.0034, -0.0635, -5.0501, -5.3903, -5.4587, -5.2416, -5.4742, -5.2662, -5.3154]
|
||||
[-8.0109, 5.0241, 4.5979, -6.8125, -7.1675, -7.8783, -7.2152, -7.5188, -7.9077, -7.7394],
|
||||
[-4.4399, -1.4422, 6.6710, -6.8929, -7.3751, -7.0969, -6.5257, -7.0257, -7.2585, -7.0008],
|
||||
[-10.0086, 3.2859, 0.7345, -6.5557, -6.8514, -6.5308, -6.4172, -6.9484, -6.6214, -6.6229],
|
||||
[-10.8078, 4.0030, -0.0633, -5.0505, -5.3906, -5.4590, -5.2420, -5.4746, -5.2665, -5.3158]
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(outputs.logits[0][:, :10].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
@@ -538,10 +538,10 @@ class MoonshineModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor([
|
||||
[-7.7288, 1.4636, 5.2273, -7.7310, -7.6249, -7.6009, -7.6786, -7.6438, -7.8450, -7.7546],
|
||||
[-6.2161, -0.5891, 7.9489, -7.0693, -6.9996, -6.9980, -7.0952, -7.0830, -7.1685, -7.0136],
|
||||
[-7.3186, 3.1192, 3.8938, -5.7208, -5.8429, -5.7610, -5.9997, -5.8213, -5.8616, -5.8720],
|
||||
[-9.5488, 1.0147, 4.1174, -5.9972, -6.0616, -6.0331, -6.2105, -6.0320, -6.0791, -6.0875]
|
||||
[-7.7272, 1.4630, 5.2294, -7.7313, -7.6252, -7.6011, -7.6788, -7.6441, -7.8452, -7.7549],
|
||||
[-6.2173, -0.5891, 7.9493, -7.0694, -6.9997, -6.9982, -7.0953, -7.0831, -7.1686, -7.0137],
|
||||
[-7.3184, 3.1192, 3.8937, -5.7206, -5.8428, -5.7609, -5.9996, -5.8212, -5.8615, -5.8719],
|
||||
[-9.5475, 1.0146, 4.1179, -5.9971, -6.0614, -6.0329, -6.2103, -6.0318, -6.0789, -6.0873]
|
||||
])
|
||||
|
||||
# fmt: on
|
||||
|
||||
@@ -855,7 +855,14 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
|
||||
|
||||
# max diff broken should be very off
|
||||
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
|
||||
self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 2")
|
||||
|
||||
loss_base = sum(base_loss_callback.losses)
|
||||
loss_broken = sum(broken_loss_callback.losses)
|
||||
|
||||
# mean/sum loss should not vary too much.
|
||||
relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken)
|
||||
self.assertLess(relative_diff, 0.1, f"Relative difference {relative_diff} is not within 0.1")
|
||||
|
||||
@slow
|
||||
def test_gradient_accumulation_loss_alignment_with_loss_func(self):
|
||||
|
||||
Reference in New Issue
Block a user