Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions auto_round/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,6 +1975,10 @@ def rename_weights_files(path: str, prefix="diffusion_pytorch_model"):
# rename safetensors
files = sorted(glob.glob(f"{path}/*.safetensors"))
total = len(files)
if total == 1:
new = f"{prefix}.safetensors"
os.rename(files[0], os.path.join(path, new))
Comment thread
xin3he marked this conversation as resolved.
return
Comment thread
xin3he marked this conversation as resolved.

for i, f in enumerate(files, 1):
new = f"{prefix}-{i:05d}-of-{total:05d}.safetensors"
Expand Down
31 changes: 31 additions & 0 deletions test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,37 @@ def tiny_flux_model_path():
shutil.rmtree(tiny_model_path, ignore_errors=True)


@pytest.fixture(scope="session")
def tiny_z_image_model_path():
    """Session-scoped fixture yielding the path of a tiny Z-Image model.

    Builds a shrunk Tongyi-MAI/Z-Image model from config (via
    ``save_tiny_model`` with ``from_config=True`` — presumably constructing
    the model from overridden config rather than downloading full weights;
    confirm against save_tiny_model) and removes the directory on teardown.
    """
    # Miniature dimensions so the model is cheap to build and run in tests.
    overrides = {
        "dim": 256,
        "n_heads": 2,
        "n_kv_heads": 2,
        "n_layers": 1,
        "n_refiner_layers": 1,
        "cap_feat_dim": 512,
        "in_channels": 16,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "attention_head_dim": 128,
        "joint_attention_dim": 256,
        "pooled_projection_dim": 256,
        "hidden_size": 512,
        "intermediate_size": 256,
    }
    path = save_tiny_model(
        "Tongyi-MAI/Z-Image",
        "./tmp/tiny_z_image_model_path",
        num_layers=1,
        is_diffusion=True,
        from_config=True,
        config_overrides=overrides,
    )
    yield path
    # Teardown: best-effort cleanup of the generated model directory.
    shutil.rmtree(path, ignore_errors=True)
Comment thread
xin3he marked this conversation as resolved.


@pytest.fixture(scope="session")
def tiny_untied_qwen_model_path():
model_name_or_path = qwen_name_or_path
Expand Down
38 changes: 2 additions & 36 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,41 +110,6 @@ def get_model_path(model_name: str) -> str:
return model_name


def get_captions_dataset_path() -> str:
    """Find captions_source.tsv locally or download it to tmp.

    Checks /dataset/, /tf_dataset/, and test/tmp/ for the file.
    If not found, downloads from the mlcommons URL to test/tmp/.

    Returns:
        str: The path to captions_source.tsv.
    """
    import urllib.request

    filename = "captions_source.tsv"
    url = (
        "https://raw.githubusercontent.com/mlcommons/inference/refs/heads/master/"
        "text_to_image/coco2014/captions/captions_source.tsv"
    )

    # Search the known dataset mount points first, then the local test tmp dir.
    # BUG FIX: the first two candidates previously interpolated a garbled
    # "(unknown)" literal instead of the actual filename, so the /dataset and
    # /tf_dataset lookups could never match the captions file.
    local_candidates = [
        f"/dataset/{filename}",
        f"/tf_dataset/{filename}",
        os.path.join(os.path.dirname(__file__), "tmp", filename),
    ]
    for path in local_candidates:
        if os.path.exists(path):
            return path

    # Not found locally: download once into test/tmp (created if missing).
    tmp_dir = os.path.join(os.path.dirname(__file__), "tmp")
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_path = os.path.join(tmp_dir, filename)
    print(f"[Helper] Downloading {filename} from {url} to {tmp_path}")
    urllib.request.urlretrieve(url, tmp_path)
    return tmp_path


opt_name_or_path = get_model_path("facebook/opt-125m")
qwen_name_or_path = get_model_path("Qwen/Qwen3-0.6B")
lamini_name_or_path = get_model_path("MBZUAI/LaMini-GPT-124M")
Expand Down Expand Up @@ -315,7 +280,8 @@ def _get_module(cls_name, mod_name, folder_name):
and isinstance(v, list)
and v[0] in ["diffusers", "transformers"]
):
_reduce_config_layers(getattr(model, k).config, num_layers, num_experts)
tiny_module = _get_module(v[0], v[1], k)
setattr(model, k, tiny_module)
return model
else:
trust_remote_code = kwargs.get("trust_remote_code", True)
Expand Down
27 changes: 19 additions & 8 deletions test/test_cuda/models/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from packaging import version
from PIL import Image

from auto_round import AutoRoundDiffusion
from auto_round import AutoRound

from ...envs import multi_card, require_gptqmodel, require_optimum, require_vlm_env
from ...helpers import get_captions_dataset_path, get_model_path, transformers_version
from ...helpers import get_model_path, transformers_version


class TestAutoRound:
Expand All @@ -37,18 +37,29 @@ def test_diffusion_rtn(self, tiny_flux_model_path):
pipe.transformer.single_transformer_blocks = pipe.transformer.single_transformer_blocks[:2]

## quantize the model
autoround = AutoRoundDiffusion(
autoround = AutoRound(
pipe,
tokenizer=None,
scheme="MXFP4",
iters=0,
disable_opt_rtn=True,
num_inference_steps=2,
dataset=get_captions_dataset_path(),
dataset="coco2014",
)
# skip model saving since it takes much time
autoround.quantize()

def test_z_image_tune(self, tiny_z_image_model_path, tmp_path):
    """Smoke-test a minimal one-iteration tuning run on the tiny Z-Image model."""
    # Keep everything as small as possible: one sample, one iteration,
    # two inference steps, calibrating on the coco2014 captions dataset.
    tune_kwargs = {
        "iters": 1,
        "nsamples": 1,
        "num_inference_steps": 2,
        "dataset": "coco2014",
    }
    tuner = AutoRound(tiny_z_image_model_path, **tune_kwargs)
    tuner.quantize_and_save(tmp_path)

@require_optimum
def test_diffusion_tune(self, tiny_flux_model_path, tmp_path):
from diffusers import AutoPipelineForText2Image
Expand All @@ -71,15 +82,15 @@ def test_diffusion_tune(self, tiny_flux_model_path, tmp_path):

## quantize the model
# https://raw.githubusercontent.com/mlcommons/inference/refs/heads/master/text_to_image/coco2014/captions/captions_source.tsv
autoround = AutoRoundDiffusion(
autoround = AutoRound(
pipe,
tokenizer=None,
scheme="MXFP4",
iters=1,
nsamples=1,
num_inference_steps=2,
layer_config=layer_config,
dataset=get_captions_dataset_path(),
dataset="coco2014",
)
# skip model saving since it takes much time
autoround.quantize_and_save(tmp_path)
Expand Down Expand Up @@ -116,15 +127,15 @@ def test_diffusion_tune_on_multi_cards(self, tiny_flux_model_path, tmp_path):

## quantize the model
# https://raw.githubusercontent.com/mlcommons/inference/refs/heads/master/text_to_image/coco2014/captions/captions_source.tsv
autoround = AutoRoundDiffusion(
autoround = AutoRound(
pipe,
tokenizer=None,
scheme="MXFP4",
iters=1,
nsamples=1,
num_inference_steps=2,
layer_config=layer_config,
dataset=get_captions_dataset_path(),
dataset="coco2014",
device_map="0,1",
)
# skip model saving since it takes much time
Expand Down
Loading