Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions air_llm/airllm/airllm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

from .profiler import LayeredProfiler

from optimum.bettertransformer import BetterTransformer
try:
from optimum.bettertransformer import BetterTransformer
bettertransformer_available = True
except (ImportError, ModuleNotFoundError):
bettertransformer_available = False

from .utils import clean_memory, load_layer, \
find_or_create_local_splitted_path
Expand Down Expand Up @@ -184,7 +188,7 @@ def init_model(self):
# Load meta model (no memory used)
self.model = None

if self.get_use_better_transformer():
if self.get_use_better_transformer() and bettertransformer_available:
try:
with init_empty_weights():
self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
Expand Down
20 changes: 18 additions & 2 deletions air_llm/airllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import torch
import torch.nn as nn
from safetensors import safe_open
from safetensors.torch import load_file, save_file

from .persist import ModelPersister
Expand Down Expand Up @@ -208,11 +209,26 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
if os.path.exists(checkpoint_path / 'pytorch_model.bin.index.json'):
with open(checkpoint_path / 'pytorch_model.bin.index.json', 'rb') as f:
index = json.load(f)['weight_map']
else:
elif os.path.exists(checkpoint_path / 'model.safetensors.index.json'):
safetensors_format = True
assert os.path.exists(checkpoint_path / 'model.safetensors.index.json'), f'model.safetensors.index.json should exist.'
with open(checkpoint_path / 'model.safetensors.index.json', 'rb') as f:
index = json.load(f)['weight_map']
elif os.path.exists(checkpoint_path / 'model.safetensors'):
# Single-file safetensors (no shard index) — common for small models <= ~7B
safetensors_format = True
with safe_open(checkpoint_path / 'model.safetensors', framework='pt', device='cpu') as f:
index = {k: 'model.safetensors' for k in f.keys()}
elif os.path.exists(checkpoint_path / 'pytorch_model.bin'):
# Single-file PyTorch bin (no shard index)
_state = torch.load(checkpoint_path / 'pytorch_model.bin', map_location='cpu')
index = {k: 'pytorch_model.bin' for k in _state.keys()}
del _state
else:
raise FileNotFoundError(
f"No model weights found in {checkpoint_path}. "
"Expected one of: pytorch_model.bin.index.json, model.safetensors.index.json, "
"model.safetensors, or pytorch_model.bin."
)

if layer_names is None:
n_layers = len(set([int(k.split('.')[2]) for k in index.keys() if 'model.layers' in k]))
Expand Down
112 changes: 112 additions & 0 deletions air_llm/tests/test_single_file_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Tests for single-file model support in split_and_save_layers.

Covers the case where a model ships as a single model.safetensors or
pytorch_model.bin file (no shard index), which is common for models <= ~7B.
"""
import json
import os
import tempfile
import shutil
import unittest

import torch
from safetensors.torch import save_file


class TestSingleFileModelSplit(unittest.TestCase):
    """Exercise split_and_save_layers on single-file checkpoints.

    Small models often ship as one model.safetensors or pytorch_model.bin
    with no shard index; these tests cover that path plus the indexed
    (sharded) layout as a regression check and the no-weights error path.
    """

    def setUp(self):
        # Fresh scratch directory per test; removed in tearDown.
        self.tmpdir = tempfile.mkdtemp(prefix="airllm_single_file_test_")

    def tearDown(self):
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _make_fake_model_state(self):
        """Minimal Llama-style state dict with 1 decoder layer."""
        hidden, inter, vocab, heads = 64, 128, 100, 4
        layer = "model.layers.0"
        # name -> tensor shape; tensors are random, only names/shapes matter.
        shapes = {
            "model.embed_tokens.weight": (vocab, hidden),
            f"{layer}.input_layernorm.weight": (hidden,),
            f"{layer}.self_attn.q_proj.weight": (hidden, hidden),
            f"{layer}.self_attn.k_proj.weight": (hidden // heads, hidden),
            f"{layer}.self_attn.v_proj.weight": (hidden // heads, hidden),
            f"{layer}.self_attn.o_proj.weight": (hidden, hidden),
            f"{layer}.mlp.gate_proj.weight": (inter, hidden),
            f"{layer}.mlp.up_proj.weight": (inter, hidden),
            f"{layer}.mlp.down_proj.weight": (hidden, inter),
            f"{layer}.post_attention_layernorm.weight": (hidden,),
            "model.norm.weight": (hidden,),
            "lm_head.weight": (vocab, hidden),
        }
        return {name: torch.randn(*shape) for name, shape in shapes.items()}

    # ------------------------------------------------------------------
    # single model.safetensors (no index)
    # ------------------------------------------------------------------
    def test_split_single_safetensors_file(self):
        from airllm.utils import split_and_save_layers

        save_file(self._make_fake_model_state(),
                  os.path.join(self.tmpdir, "model.safetensors"))

        split_path = split_and_save_layers(self.tmpdir)

        self.assertTrue(os.path.isdir(split_path))
        for fname in (
            "model.embed_tokens.safetensors",
            "model.layers.0.safetensors",
            "model.norm.safetensors",
            "lm_head.safetensors",
        ):
            self.assertTrue(
                os.path.exists(os.path.join(split_path, fname)),
                f"Expected shard file missing: {fname}",
            )

    # ------------------------------------------------------------------
    # single pytorch_model.bin (no index)
    # ------------------------------------------------------------------
    def test_split_single_pytorch_bin_file(self):
        from airllm.utils import split_and_save_layers

        torch.save(self._make_fake_model_state(),
                   os.path.join(self.tmpdir, "pytorch_model.bin"))

        split_path = split_and_save_layers(self.tmpdir)

        self.assertTrue(os.path.isdir(split_path))
        embed_shard = os.path.join(split_path, "model.embed_tokens.safetensors")
        self.assertTrue(os.path.exists(embed_shard))

    # ------------------------------------------------------------------
    # sharded model.safetensors.index.json still works (regression)
    # ------------------------------------------------------------------
    def test_split_sharded_safetensors_still_works(self):
        from airllm.utils import split_and_save_layers

        state = self._make_fake_model_state()
        shard_file = "model-00001-of-00001.safetensors"
        save_file(state, os.path.join(self.tmpdir, shard_file))

        # Hand-written index mapping every weight to the single shard.
        index = {"metadata": {}, "weight_map": {k: shard_file for k in state}}
        index_path = os.path.join(self.tmpdir, "model.safetensors.index.json")
        with open(index_path, "w") as f:
            json.dump(index, f)

        split_path = split_and_save_layers(self.tmpdir)

        embed_shard = os.path.join(split_path, "model.embed_tokens.safetensors")
        self.assertTrue(os.path.exists(embed_shard))

    # ------------------------------------------------------------------
    # no weights at all → FileNotFoundError
    # ------------------------------------------------------------------
    def test_raises_when_no_weights(self):
        from airllm.utils import split_and_save_layers

        with self.assertRaises(FileNotFoundError):
            split_and_save_layers(self.tmpdir)


# Allow running this test module directly (outside a pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()