Add GPT OSS model from OpenAI (#39923)

* fix

* nice

* where i am at

* Bro this works

* Update src/transformers/integrations/tensor_parallel.py

* cleanups

* yups that was breaking

* Update src/transformers/models/openai_moe/modeling_openai_moe.py

* gather on experts and not mlp

* add changes for latest convert branch

* adds options to get output_router_logits from config

* bring chat template + special tokens back into the script.

* initial commit

* update

* working with shards

* add model.safetensors.index.json

* fix

* fix

* mxfp4 flag

* rm print

* Fix PAD/EOS/BOS (#18)

* fix pad/eos/bos

* base model maybe one day

* add some doc

* special tokens based on harmony.

* add in tokenizer config as well.

* prepare for rebase with main

* Fix for initialize_tensor_parallelism  now returning 4-tuple

```
[rank0]:   File "/fsx/edward/work/openai-tsm-examples/examples/generate.py", line 17, in <module>
[rank0]:     model = AutoModelForCausalLM.from_pretrained(
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/edward/work/new-model-addition-openai/src/transformers/models/auto/auto_factory.py", line 600, in from_pretrained
[rank0]:     return model_class.from_pretrained(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/edward/work/new-model-addition-openai/src/transformers/modeling_utils.py", line 316, in _wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/edward/work/new-model-addition-openai/src/transformers/modeling_utils.py", line 4748, in from_pretrained
[rank0]:     tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: ValueError: too many values to unpack (expected 3)
```
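
The traceback above comes from a call site that still unpacks three values after `initialize_tensor_parallelism` started returning four. A minimal sketch of the failure and the fix follows. It uses a hypothetical stand-in function rather than the real helper, and the name of the fourth value (`tp_size`) is an assumption that the log does not confirm:

```python
def initialize_tensor_parallelism_stub(tp_plan, tp_size=None):
    """Hypothetical stand-in that returns a 4-tuple, like the updated helper."""
    device_map, device_mesh = None, None  # placeholders, not real devices
    return tp_plan, device_map, device_mesh, tp_size


def unpack_three():
    # Old call site: three targets for four returned values -> ValueError.
    tp_plan, device_map, device_mesh = initialize_tensor_parallelism_stub("auto")
    return tp_plan


def unpack_four():
    # Updated call site: one target per returned value.
    tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism_stub("auto")
    return tp_plan, tp_size


try:
    unpack_three()
except ValueError as err:
    error_message = str(err)
else:
    error_message = None

fixed = unpack_four()
```

Unpacking into exactly as many names as the function returns is the simplest fix; callers that do not need the extra value can bind it to `_`.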

* mxfp4

* mxfp4 draft

* fix

* fix import

* draft

* draft impl

* finally working !

* simplify

* add import

* working version

* consider blocks and scales

* device mesh fix

* initial commit

* add working dequant + quant logic

* update

* non nan, gibberish output

* working EP + quantization finally !

* start cleaning

* remove reversing process

* style

* some cleaning

* initial commit

* more cleaning

* more cleaning

* simplify

* more cleaning

* rm duplicated function

* changing tp_plan

* update tp plan check

* add loading attribute

* dequantizing logic

* use subfunctions

* import cleaning

* update_param_name

* adds clamped swiglu

* add clamping to training path

* simplify dequant logic

* update

* Bad merge

* more simplifications & tests

* fix !

* fix registering custom attention

* fix order

* fixes

* some test nits

* nits

* nit

* fix

* Clamp sink logits

* Clean

* Soft-max trick

* Clean up

* p

* fix deepspeed

* update both modeling and modular for cleanup

* contiguous

* update tests

* fix top_k router call

* revert renaming

* test nits

* small fixes for EP

* fix path for our local tests

* update as I should not have broken that!

* fix the loss of mixtral

* revert part of the changes related to router_scores, kernel probably not ready for that!

* deleting a small nit

* update arch

* fix post processing

* update

* running version but not expected output

* moving to cuda

* initial commit

* revert

* erroring when loading on cpu

* updates

* del blocks, scales

* fix

* style

* rm comm

* comment

* add comment

* style

* remove duplicated lines

* Fix minor issue with weight_map conversion script

* fix sampling params

* rename to final name

* update pre-final version of template

* Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py

* fix batched inference

* serve fixes

* swizzle !

* update final chat template by Matt.

* fix responses; pin oai

* simplify

* Thanks Matt for his tireless efforts!

Co-authored-by: Rocketknight1 <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* fix

* Use ROCm kernels from HUB

* Make kernel modes explicit

* update final chat template by Matt. x2

* Thanks Matt for his tireless efforts!

Co-authored-by: Rocketknight1 <Rocketknight1@users.noreply.github.com>

* Fix installation

* Update setup.py

Co-authored-by: Ákos Hadnagy <akos.hadnagy@gmail.com>

* allow no content

* fix: update message handling in write_tokenizer function

* Fix template logic for user message role

* last nits for CB and flash_paged!

* there was one bad merge

* fix CB (hardcode for now, its just using kv groups instead)

* fix

* better fix for device_map

* minor device fix

* Fix flash paged

* updates

* Revert "remove dtensors, not explicit (#39840)"

This reverts commit 6dfd561d9c.

* update

* Revert "remove dtensors, not explicit (#39840)"

This reverts commit 6dfd561d9c.

* fix merge

* fix

* Fix line break when custom model identity

* nits testing

* to locals first and pass sliding window to flash paged

* register modes for MegaBlocksMoeMlp

* add integration test in fixtures -> now update the tests to use it!

* update integration tests

* initial fix

* style and update tests

* fix

* chore(gpt oss): remove mlp_bias from configuration

It was just a leftover.

* stats

* Integration tests

* whoops

* Shouldn't move model

* Ensure assistant messages without thinking always go to "final" channel

* More checks to ensure expected format

* Add pad_token_id to model configuration in write_model function (#51)

* Add oai fix fast tests (#59)

* Fix some fast tests

* Force some updates

* Remove unnecessary fixes

* Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py

* reasoning -> Reasoning

* Add additional integration tests

* fixup

* Slight fixes

* align chat template with harmony

* simplify

* Add comment

* torch testing assert close

* torch testing assert close

* torch testing assert close

* torch testing assert close

* torch testing assert close

* torch testing assert close

* Revert fixup

* skip 2 test remove todo

* merge

* padding side should be left for integration tests

* fix modular wrt to changes made to modeling

* style

* isort

* fix copies for the loss

* mmmm

---------

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: edbeeching <edbeeching@gmail.com>
Co-authored-by: Vaibhavs10 <vaibhavs10@gmail.com>
Co-authored-by: MekkCyber <mekk.cyber@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan@openai.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: joao@huggingface.co <joao@ip-10-53-88-32.ec2.internal>
Co-authored-by: Rocketknight1 <Rocketknight1@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Akos Hadnagy <akos@ahadnagy.com>
Co-authored-by: Ákos Hadnagy <akos.hadnagy@gmail.com>
Co-authored-by: Alvaro Moran <alvaro.moran@huggingface.co>
Co-authored-by: Lysandre <hi@lysand.re>
Co-authored-by: Matt <rocketknight1@gmail.com>
Arthur
2025-08-05 18:02:18 +02:00
committed by GitHub
parent 738c1a3899
commit 7c38d8fc23
48 changed files with 4668 additions and 98 deletions


@@ -0,0 +1,346 @@
[
{
"quantized": true,
"model": "120b",
"kernels": false,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here",
"How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the"
]
},
{
"quantized": true,
"model": "120b",
"kernels": false,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here",
"How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the"
]
},
{
"quantized": true,
"model": "120b",
"kernels": true,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "eval",
"outputs": [
"Did not work"
]
},
{
"quantized": true,
"model": "120b",
"kernels": true,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "train",
"outputs": [
"Did not work"
]
},
{
"quantized": true,
"model": "120b",
"kernels": false,
"attn_impl": "eager",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here",
"How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the"
]
},
{
"quantized": true,
"model": "120b",
"kernels": false,
"attn_impl": "eager",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here",
"How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the"
]
},
{
"quantized": true,
"model": "120b",
"kernels": true,
"attn_impl": "eager",
"mode": "eval",
"outputs": [
"Did not work"
]
},
{
"quantized": true,
"model": "120b",
"kernels": true,
"attn_impl": "eager",
"mode": "train",
"outputs": [
"Did not work"
]
},
{
"quantized": true,
"model": "20b",
"kernels": false,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're looking for",
"How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
]
},
{
"quantized": true,
"model": "20b",
"kernels": false,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're looking for",
"How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
]
},
{
"quantized": true,
"model": "20b",
"kernels": true,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "eval",
"outputs": [
"Did not work"
]
},
{
"quantized": true,
"model": "20b",
"kernels": true,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "train",
"outputs": [
"Did not work"
]
},
{
"quantized": true,
"model": "20b",
"kernels": false,
"attn_impl": "eager",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're expressing a",
"How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
]
},
{
"quantized": true,
"model": "20b",
"kernels": false,
"attn_impl": "eager",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're expressing a",
"How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
]
},
{
"quantized": true,
"model": "20b",
"kernels": true,
"attn_impl": "eager",
"mode": "eval",
"outputs": [
"Did not work"
]
},
{
"quantized": true,
"model": "20b",
"kernels": true,
"attn_impl": "eager",
"mode": "train",
"outputs": [
"Did not work"
]
},
{
"quantized": false,
"model": "120b",
"kernels": false,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a",
"How are you? Tell me the name of the president of the United Kingdom?\n\nThe United Kingdom does not have a president. The head of state is the"
]
},
{
"quantized": false,
"model": "120b",
"kernels": false,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue, I am a language model trained by OpenAI.\n\nI am a large language model",
"How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here"
]
},
{
"quantized": false,
"model": "120b",
"kernels": true,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a",
"How are you? Tell me the name of the president of the United Kingdom?\n\nThe United Kingdom does not have a president. The head of state is the"
]
},
{
"quantized": false,
"model": "120b",
"kernels": true,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue, I am a language model trained by OpenAI.\n\nI am a large language model",
"How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here"
]
},
{
"quantized": false,
"model": "120b",
"kernels": false,
"attn_impl": "eager",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a",
"How are you? Tell me the name of the president of the United States?\n\nAs an AI language model, I do not have personal feelings or emotions,"
]
},
{
"quantized": false,
"model": "120b",
"kernels": false,
"attn_impl": "eager",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue, I am a language model, and I can help you with your request.\n\nSure",
"How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here"
]
},
{
"quantized": false,
"model": "120b",
"kernels": true,
"attn_impl": "eager",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a",
"How are you? Tell me the name of the president of the United States?\n\nAs an AI language model, I do not have personal feelings or emotions,"
]
},
{
"quantized": false,
"model": "120b",
"kernels": true,
"attn_impl": "eager",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue, I am a language model, and I can help you with your request.\n\nSure",
"How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here"
]
},
{
"quantized": false,
"model": "20b",
"kernels": false,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio",
"How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
]
},
{
"quantized": false,
"model": "20b",
"kernels": false,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue\" (makes sense). But the phrase \"the answer is 3\" is not a",
"How are you? Tell me the name of the president of the United States.\" The answer to that is \"Joe Biden.\" The user is asking for the name"
]
},
{
"quantized": false,
"model": "20b",
"kernels": true,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio",
"How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
]
},
{
"quantized": false,
"model": "20b",
"kernels": true,
"attn_impl": "ft-hf-o-c/vllm-flash-attn3",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue\" (makes sense). But the phrase \"the answer is 3\" is not a",
"How are you? Tell me the name of the president of the United States.\" The answer to that is \"Joe Biden.\" The user is asking for the name"
]
},
{
"quantized": false,
"model": "20b",
"kernels": false,
"attn_impl": "eager",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio",
"How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
]
},
{
"quantized": false,
"model": "20b",
"kernels": false,
"attn_impl": "eager",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue.\" -> from which we can derive a rule: if we have a red object that is",
"How are you? Tell me the name of the president of the United States.\n\nI am an AI language model and I do not have a personal life or"
]
},
{
"quantized": false,
"model": "20b",
"kernels": true,
"attn_impl": "eager",
"mode": "eval",
"outputs": [
".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio",
"How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
]
},
{
"quantized": false,
"model": "20b",
"kernels": true,
"attn_impl": "eager",
"mode": "train",
"outputs": [
".....Roses are red, violets are blue.\" -> from which we can derive a rule: if we have a red object that is",
"How are you? Tell me the name of the president of the United States.\n\nI am an AI language model and I do not have a personal life or"
]
}
]


@@ -0,0 +1,523 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch GptOss model."""
import inspect
import json
import os
import subprocess
import tempfile
import unittest
from pathlib import Path

import pytest
from parameterized import parameterized

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GptOssConfig,
    is_torch_available,
)
from transformers.testing_utils import (
    cleanup,
    require_read_token,
    require_torch,
    require_torch_accelerator,
    slow,
    torch_device,
)

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
from ...test_configuration_common import ConfigTester


if is_torch_available():
    import torch

    from transformers import (
        GptOssForCausalLM,
        GptOssModel,
    )

    NUM_GPUS = torch.cuda.device_count()

class GptOssModelTester(CausalLMModelTester):
    if is_torch_available():
        config_class = GptOssConfig
        base_model_class = GptOssModel
        causal_lm_class = GptOssForCausalLM

    pipeline_model_mapping = (
        {
            "feature-extraction": GptOssModel,
            "text-generation": GptOssForCausalLM,
        }
        if is_torch_available()
        else {}
    )

@require_torch
class GptOssModelTest(CausalLMModelTest, unittest.TestCase):
    all_model_classes = (GptOssModel, GptOssForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": GptOssModel,
            "text-generation": GptOssForCausalLM,
        }
        if is_torch_available()
        else {}
    )

    test_headmasking = False
    test_pruning = False
    _is_stateful = True
    model_split_percents = [0.5, 0.6]
    model_tester_class = GptOssModelTester

    def setUp(self):
        self.model_tester = GptOssModelTester(self)
        self.config_tester = ConfigTester(self, config_class=GptOssConfig, hidden_size=37)

    @unittest.skip("Failing because of unique cache (HybridCache)")
    def test_model_outputs_equivalence(self, **kwargs):
        pass

    @unittest.skip("GptOss's forcefully disables sdpa due to Sink")
    def test_sdpa_can_dispatch_non_composite_models(self):
        pass

    @unittest.skip("GptOss's eager attn/sdpa attn outputs are expected to be different")
    def test_eager_matches_sdpa_generate(self):
        pass

    @parameterized.expand([("random",), ("same",)])
    @pytest.mark.generate
    @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
        pass

    @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding")
    def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
        pass

    @pytest.mark.generate
    @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_sample(self):
        pass

    @unittest.skip("GptOss has HybridCache which is not compatible with dola decoding")
    def test_dola_decoding_sample(self):
        pass

    @unittest.skip("GptOss has HybridCache and doesn't support continue from past kv")
    def test_generate_continue_from_past_key_values(self):
        pass

    @unittest.skip("GptOss has HybridCache and doesn't support contrastive generation")
    def test_contrastive_generate(self):
        pass

    @unittest.skip("GptOss has HybridCache and doesn't support contrastive generation")
    def test_contrastive_generate_dict_outputs_use_cache(self):
        pass

    @unittest.skip("GptOss has HybridCache and doesn't support contrastive generation")
    def test_contrastive_generate_low_memory(self):
        pass

    @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
    def test_generate_with_static_cache(self):
        pass

    @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
    def test_generate_continue_from_inputs_embeds(self):
        pass

    @unittest.skip(
        reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
        " as in Dynamic Cache doesn't work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
    )
    def test_multi_gpu_data_parallel_forward(self):
        pass

    @unittest.skip("GptOss has HybridCache which auto-compiles. Compile and FA2 don't work together.")
    def test_eager_matches_fa2_generate(self):
        pass

    @unittest.skip("GptOss eager/FA2 attention outputs are expected to be different")
    def test_flash_attn_2_equivalence(self):
        pass

    @unittest.skip("Most probably because of the MOE, the moe and router does not ignore padding tokens")
    def test_eager_padding_matches_padding_free_with_position_ids(self):
        pass

    @unittest.skip("GptOss does not support flex officially")
    def test_flex_attention_with_grads(self):
        pass

RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json"


# ------------------------
# Worker function for distributed torchrun
# ------------------------
def distributed_worker(quantized, model_size, kernels, attn_impl, mode):
    """This is the function that will be executed by torchrun workers."""
    import os

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.testing_utils import torch_device

    input_text = [
        "Roses are red, violets",
        "How are you? Tell me the name of the president of",
    ]

    # Convert args
    quantized = quantized.lower() == "true"
    kernels = kernels.lower() == "true"

    # Distributed model loading
    model_id = f"/fsx/vb/new-oai/gpt-oss-{model_size}-trfs"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        tp_plan="auto",  # distributed inference
        use_kernels=kernels,
    ).to(torch_device)
    model.set_attn_implementation(attn_impl)
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

    # Inference
    inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device)
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    output_texts = tokenizer.batch_decode(output, skip_special_tokens=False)

    # Only rank 0 writes results
    if int(os.environ.get("RANK", "0")) == 0:
        result_entry = {
            "quantized": quantized,
            "model": model_size,
            "kernels": kernels,
            "attn_impl": attn_impl,
            "mode": mode,
            "outputs": output_texts,
        }

        if os.path.exists(RESULTS_PATH):
            with open(RESULTS_PATH, "r") as f:
                results = json.load(f)
        else:
            results = []
        results.append(result_entry)

        with open(RESULTS_PATH, "w") as f:
            json.dump(results, f, indent=2)

@slow
@require_torch_accelerator
class GptOssIntegrationTest(unittest.TestCase):
    input_text = [
        "Roses are red, violets",
        "How are you? Tell me the name of the president of",
    ]

    def setUp(self):
        cleanup(torch_device, gc_collect=True)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    # ------------------------
    # Non-distributed inference
    # ------------------------
    @staticmethod
    def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwargs):
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation=attn_implementation,
            **pretrained_kwargs,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

        inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(model.device)
        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
        return output_text

    # ------------------------
    # Distributed inference using inspect
    # ------------------------
    @staticmethod
    def run_distributed_test(quantized, model, kernels, attn_impl, mode):
        """Launch torchrun using a temporary worker file generated from inspect.getsource()."""
        import textwrap

        # Extract worker function source dynamically
        worker_src = inspect.getsource(distributed_worker)

        # Create a temp file that calls the worker
        script_code = f"""
import sys
import json

RESULTS_PATH = "{RESULTS_PATH}"

{worker_src}

if __name__ == "__main__":
    distributed_worker("{quantized}", "{model}", "{kernels}", "{attn_impl}", "{mode}")
"""
        # Dedent for proper formatting
        script_code = textwrap.dedent(script_code)

        # Write to temp file
        with tempfile.NamedTemporaryFile("w", suffix="_worker.py", delete=False) as tmp:
            tmp.write(script_code)
            tmp_path = tmp.name

        # Launch torchrun
        cmd = [
            "torchrun",
            f"--nproc_per_node={NUM_GPUS}",
            tmp_path,
        ]
        subprocess.run(cmd, check=True)

        # Cleanup
        os.remove(tmp_path)

    # ------------------------
    # Shared parameterization
    # ------------------------
    PARAMETERS = [
        (False, "120b", False, "eager", "eval"),
        (False, "120b", False, "eager", "train"),
        (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"),
        (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"),
        (False, "120b", True, "eager", "eval"),
        (False, "120b", True, "eager", "train"),
        (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"),
        (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"),
        (True, "120b", False, "eager", "eval"),
        (True, "120b", False, "eager", "train"),
        (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"),
        (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"),
        (True, "120b", True, "eager", "eval"),
        (True, "120b", True, "eager", "train"),
        (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"),
        (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"),
        (False, "20b", False, "eager", "eval"),
        (False, "20b", False, "eager", "train"),
        (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"),
        (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"),
        (False, "20b", True, "eager", "eval"),
        (False, "20b", True, "eager", "train"),
        (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"),
        (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"),
        (True, "20b", False, "eager", "eval"),
        (True, "20b", False, "eager", "train"),
        (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"),
        (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"),
        (True, "20b", True, "eager", "eval"),
        (True, "20b", True, "eager", "train"),
        (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"),
        (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"),
    ]

    # ------------------------
    # Non-distributed test
    # ------------------------
    @parameterized.expand(PARAMETERS)
    @require_read_token
    def test_model_outputs(self, quantized, model, kernels, attn_impl, mode):
        model_id = f"/fsx/vb/new-oai/gpt-oss-{model}-trfs"
        output_texts = self.load_and_forward(
            model_id,
            attn_impl,
            self.input_text,
            use_kernels=kernels,
        )
        result_entry = {
            "quantized": quantized,
            "model": model,
            "kernels": kernels,
            "attn_impl": attn_impl,
            "mode": mode,
            "outputs": output_texts,
        }

        if os.path.exists(RESULTS_PATH):
            with open(RESULTS_PATH, "r") as f:
                results = json.load(f)
        else:
            results = []
        results.append(result_entry)

        with open(RESULTS_PATH, "w") as f:
            json.dump(results, f, indent=2)

        self.assertIsInstance(output_texts, list)
        self.assertTrue(all(isinstance(x, str) for x in output_texts))

    # ------------------------
    # Distributed test
    # ------------------------
    @parameterized.expand(PARAMETERS)
    @require_read_token
    def test_model_outputs_distributed(self, quantized, model, kernels, attn_impl, mode):
        self.run_distributed_test(quantized, model, kernels, attn_impl, mode)

    def test_model_matches_original_20b(self):
        input_text = "Roses are red, violets"

        original_output = "Roses are red, violets are blue, I love you, and I love you too."
        original_logprobs = torch.tensor(
            [
                -0.037353515625,
                -0.08154296875,
                -1.21875,
                -1.953125,
                -2.234375,
                -0.96875,
                -1.546875,
                -1.640625,
                -0.93359375,
                -1.609375,
                -1.625,
                -0.85546875,
                -1.7265625,
                -0.7421875,
                -2.078125,
                -0.006561279296875,
                -0.10498046875,
                -0.1767578125,
                -0.1240234375,
                -0.099609375,
            ]
        )

        model_id = "/fsx/vb/new-oai/gpt-oss-20b-trfs"

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="eager",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokens = tokenizer(input_text)["input_ids"]

        num_generated_tokens = 0
        with torch.no_grad():
            for i in range(12):
                tensors = torch.as_tensor(tokens, dtype=torch.int32, device=model.device).unsqueeze(0)
                logits = model(tensors).logits[0]

                predicted_token = torch.argmax(logits[-1, :], dim=-1).item()
                logprobs = torch.log_softmax(logits[-1, :], dim=-1)
                selected_logprobs = logprobs[predicted_token]

                tokens.append(predicted_token)
                num_generated_tokens += 1
                decoded_token = tokenizer.decode([predicted_token])
                logprob_differences = selected_logprobs - original_logprobs[i]

                print(
                    f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}"
                )
                torch.testing.assert_close(
                    selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1
                )

        decoded_string = tokenizer.decode(tokens)
        self.assertTrue(original_output.startswith(decoded_string))

    def test_model_matches_original_120b(self):
        input_text = "Roses are red, violets"

        original_output = """Roses are red, violets are blue,
I am a language model, not a human being"""
        original_logprobs = torch.tensor(
            [
                -0.90234375,
                -0.66015625,
                -1.546875,
                -2.703125,
                -2.078125,
                -1.21875,
                -2.484375,
                -0.031982421875,
                -0.84765625,
                -1.890625,
                -0.1923828125,
                -2.046875,
                -1.65625,
                -1.3515625,
                -1.1640625,
                -0.3671875,
                -1.9921875,
                -1.5390625,
                -1.46875,
                -0.85546875,
            ]
        )

        model_id = "/fsx/vb/new-oai/gpt-oss-120b-trfs"

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="eager",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokens = tokenizer(input_text)["input_ids"]

        num_generated_tokens = 0
        with torch.no_grad():
            for i in range(12):
                tensors = torch.as_tensor(tokens, dtype=torch.int32, device=model.device).unsqueeze(0)
                logits = model(tensors).logits[0]

                predicted_token = torch.argmax(logits[-1, :], dim=-1).item()
                logprobs = torch.log_softmax(logits[-1, :], dim=-1)
                selected_logprobs = logprobs[predicted_token]

                tokens.append(predicted_token)
                num_generated_tokens += 1
                decoded_token = tokenizer.decode([predicted_token])
                logprob_differences = selected_logprobs - original_logprobs[i]

                print(
                    f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}"
                )
                torch.testing.assert_close(
                    selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1
                )

        decoded_string = tokenizer.decode(tokens)
        self.assertTrue(original_output.startswith(decoded_string))

@@ -0,0 +1,420 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from unittest.mock import patch
from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config
from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_torch_large_gpu,
require_triton,
require_triton_kernels,
slow,
)
from transformers.utils import (
is_torch_available,
)
if is_torch_available():
import torch
class Mxfp4ConfigTest(unittest.TestCase):
def test_basic_config_creation(self):
"""Test basic configuration creation with default values"""
config = Mxfp4Config()
self.assertEqual(config.quant_method.value, "mxfp4")
self.assertIsNone(config.modules_to_not_convert)
self.assertFalse(config.dequantize)
def test_config_with_modules_to_not_convert(self):
"""Test configuration with modules to not convert"""
modules = ["model.layers.*.self_attn", "lm_head"]
config = Mxfp4Config(modules_to_not_convert=modules)
self.assertEqual(config.modules_to_not_convert, modules)
def test_config_with_dequantize(self):
"""Test configuration with dequantize enabled"""
config = Mxfp4Config(dequantize=True)
self.assertTrue(config.dequantize)
def test_get_loading_attributes(self):
"""Test get_loading_attributes method"""
config = Mxfp4Config(dequantize=True)
attrs = config.get_loading_attributes()
self.assertEqual(attrs, {"dequantize": True})
def test_to_dict(self):
"""Test configuration serialization to dict"""
config = Mxfp4Config(modules_to_not_convert=["lm_head"], dequantize=True)
config_dict = config.to_dict()
self.assertEqual(config_dict["quant_method"], "mxfp4")
self.assertEqual(config_dict["modules_to_not_convert"], ["lm_head"])
self.assertTrue(config_dict["dequantize"])
def test_from_dict(self):
"""Test configuration creation from dict"""
config_dict = {"quant_method": "mxfp4", "modules_to_not_convert": ["lm_head"], "dequantize": True}
config = Mxfp4Config.from_dict(config_dict)
self.assertEqual(config.modules_to_not_convert, ["lm_head"])
self.assertTrue(config.dequantize)
class Mxfp4QuantizerTest(unittest.TestCase):
"""Test the Mxfp4HfQuantizer class"""
def setUp(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def test_quantizer_validation_no_torch(self):
"""Test quantizer validation when torch is not available"""
with patch("transformers.quantizers.quantizer_mxfp4.is_torch_available", return_value=False):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
with self.assertRaises(ImportError):
quantizer.validate_environment()
def test_quantizer_validation_no_cuda(self):
"""Test quantizer validation when CUDA is not available"""
with patch("torch.cuda.is_available", return_value=False):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
with self.assertRaises(RuntimeError):
quantizer.validate_environment()
def test_quantizer_validation_low_compute_capability(self):
"""Test quantizer validation with low compute capability"""
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
with self.assertRaises(ValueError):
quantizer.validate_environment()
def test_quantizer_validation_low_compute_capability_with_dequantize(self):
"""Test quantizer validation with low compute capability but dequantize enabled"""
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config(dequantize=True)
quantizer = Mxfp4HfQuantizer(config)
# Should not raise error with dequantize=True
try:
quantizer.validate_environment()
except ValueError as e:
if "compute capability" in str(e):
self.fail("Should not raise compute capability error when dequantize=True")
def test_quantizer_validation_missing_triton(self):
"""Test quantizer validation when triton is not available"""
with (
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
quantizer.pre_quantized = False
with self.assertRaises(ValueError):
quantizer.validate_environment()
def test_quantizer_validation_missing_triton_pre_quantized_no_dequantize(self):
"""Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False"""
with (
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
quantizer.pre_quantized = True
# Should automatically set dequantize=True and warn
quantizer.validate_environment()
self.assertTrue(quantizer.quantization_config.dequantize)
def test_update_torch_dtype(self):
"""Test torch dtype updating"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
# Should default to bfloat16
result_dtype = quantizer.update_torch_dtype(None)
self.assertEqual(result_dtype, torch.bfloat16)
# Should preserve existing dtype
result_dtype = quantizer.update_torch_dtype(torch.float32)
self.assertEqual(result_dtype, torch.float32)
def test_update_expected_keys(self):
"""Test expected keys updating for quantized models"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
expected_keys = [
"model.layers.0.mlp.experts.gate_up_proj",
"model.layers.0.mlp.experts.down_proj",
"model.embed_tokens.weight",
]
updated_keys = quantizer.update_expected_keys(None, expected_keys, [])
expected_updated = [
"model.layers.0.mlp.experts.gate_up_proj_blocks",
"model.layers.0.mlp.experts.gate_up_proj_scales",
"model.layers.0.mlp.experts.down_proj_blocks",
"model.layers.0.mlp.experts.down_proj_scales",
"model.embed_tokens.weight",
]
self.assertEqual(set(updated_keys), set(expected_updated))
def test_update_param_name_dequantize(self):
"""Test parameter name updating when dequantizing"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config(dequantize=True)
quantizer = Mxfp4HfQuantizer(config)
# Should remove _blocks suffix
param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks"
updated_name = quantizer.update_param_name(param_name)
self.assertEqual(updated_name, "model.layers.0.mlp.experts.gate_up_proj")
# Should remove _scales suffix
param_name = "model.layers.0.mlp.experts.down_proj_scales"
updated_name = quantizer.update_param_name(param_name)
self.assertEqual(updated_name, "model.layers.0.mlp.experts.down_proj")
# Should not change other names
param_name = "model.embed_tokens.weight"
updated_name = quantizer.update_param_name(param_name)
self.assertEqual(updated_name, "model.embed_tokens.weight")
def test_update_param_name_no_dequantize(self):
"""Test parameter name updating when not dequantizing"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config(dequantize=False)
quantizer = Mxfp4HfQuantizer(config)
param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks"
updated_name = quantizer.update_param_name(param_name)
self.assertEqual(updated_name, param_name)
def test_is_serializable(self):
"""Test serialization capability"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
# MXFP4 is not serializable with safetensors
self.assertFalse(quantizer.is_serializable())
def test_is_trainable(self):
"""Test trainability"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
# MXFP4 is not trainable
self.assertFalse(quantizer.is_trainable)
class Mxfp4IntegrationTest(unittest.TestCase):
"""Test mxfp4 integration functions"""
def test_should_convert_module(self):
"""Test module conversion decision logic"""
from transformers.integrations.mxfp4 import should_convert_module
# Should convert by default
self.assertTrue(should_convert_module(["model", "layers", "0", "mlp"], []))
# Should not convert if in exclusion list
patterns = ["model.layers.*.self_attn", "lm_head"]
self.assertFalse(should_convert_module(["model", "layers", "0", "self_attn"], patterns))
self.assertFalse(should_convert_module(["lm_head"], patterns))
# Should convert if not in exclusion list
self.assertTrue(should_convert_module(["model", "layers", "0", "mlp", "experts"], patterns))
@require_torch
def test_convert_moe_packed_tensors(self):
"""Test unpacking of quantized tensors"""
from transformers.integrations.mxfp4 import convert_moe_packed_tensors
# Create dummy packed tensors
blocks = torch.randint(0, 256, (2, 4, 8), dtype=torch.uint8)  # high is exclusive, so 256 covers the full uint8 range
scales = torch.randint(100, 150, (2, 4), dtype=torch.uint8)
result = convert_moe_packed_tensors(blocks, scales, dtype=torch.bfloat16)
# Check output shape - should be [2, 4 * 16]: each packed byte unpacks to 2 values (8 * 2 = 16), flattened across the 4 groups
self.assertEqual(result.shape, (2, 4 * 16))
self.assertEqual(result.dtype, torch.bfloat16)
@require_triton(min_version="3.4.0")
@require_triton_kernels
@require_torch_gpu
@require_torch
def test_quantize_to_mxfp4(self):
"""Test quantization function"""
from transformers.integrations.mxfp4 import quantize_to_mxfp4
# Create dummy weight tensor
w = torch.randn(32, 64, 128, dtype=torch.bfloat16, device=torch.device("cuda"))
quantized_w, flex_data, mx_ctx = quantize_to_mxfp4(w, None, None)
# Check that shapes are reasonable
self.assertEqual(quantized_w.dtype, torch.uint8)
self.assertIsNotNone(flex_data)
self.assertIsNotNone(mx_ctx)
@require_torch
@require_torch_large_gpu
@slow
class Mxfp4ModelTest(unittest.TestCase):
"""Test mxfp4 with actual models (requires specific model and hardware)"""
# These should be paths to real OpenAI MoE models for proper testing
model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed" # TODO: Use real packed quantized model
input_text = "Once upon a time"
# Expected outputs for generation tests
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Once upon a time, in a small village, there lived a young")
def setUp(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def tearDown(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def check_inference_correctness_quantized(self, model, tokenizer):
# Check that inference pass works on the model
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(model.device)
# Set pad token if not set
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
with torch.no_grad():
output_sequences = model.generate(
**encoded_input,
max_new_tokens=10,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
use_cache=False,
)
generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
self.assertIn(generated_text, self.EXPECTED_OUTPUTS)
def test_gpt_oss_model_loading_quantized_with_device_map(self):
"""Test loading OpenAI MoE model with mxfp4 quantization and device_map"""
quantization_config = Mxfp4Config(dequantize=False)
# Test that config is properly set up
self.assertFalse(quantization_config.dequantize)
model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
self.check_inference_correctness_quantized(model, tokenizer)
def test_gpt_oss_model_loading_dequantized_with_device_map(self):
"""Test loading OpenAI MoE model with mxfp4 dequantization and device_map"""
quantization_config = Mxfp4Config(dequantize=True)
# Test that config is properly set up
self.assertTrue(quantization_config.dequantize)
model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
self.check_inference_correctness_quantized(model, tokenizer)
def test_model_device_map_validation(self):
"""Test device map validation"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
quantizer.pre_quantized = False
# Test with CPU in device map (should raise error for non-pre-quantized)
with self.assertRaises(ValueError):
quantizer.validate_environment(device_map={"": "cpu"})
def test_memory_footprint_comparison(self):
"""Test memory footprint differences between quantized and unquantized models"""
# Expected: quantized < dequantized memory usage (only these two variants are loaded and compared here)
quantization_config = Mxfp4Config(dequantize=True)
quantized_model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
torch_dtype=torch.bfloat16,
device_map="auto",
)
dequantized_model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=quantization_config,
)
quantized_mem = quantized_model.get_memory_footprint()
dequantized_mem = dequantized_model.get_memory_footprint()
self.assertLess(quantized_mem, dequantized_mem)