Skip to content

Commit 9b2bfa2

Browse files
Kangyan-Zhou and claude committed
feat: skip re-downloading models on shared storage
When multiple nodes mount the same filesystem (e.g., GPFS/NFS at /storage/models), the model-agent on each node would independently re-download from HuggingFace or OCI, causing rate-limiting and hours of unnecessary I/O. Add isModelAlreadyDownloaded() that checks: 1. config.json exists 2. If model.safetensors.index.json exists, ALL expected shards present 3. Otherwise, at least one weight file (.safetensors/.bin/.pt/.gguf) Only applies to fresh Download tasks (not DownloadOverride) so spec updates and failed retries still re-evaluate. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent aac400b commit 9b2bfa2

File tree

2 files changed

+232
-0
lines changed

2 files changed

+232
-0
lines changed

pkg/modelagent/gopher.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package modelagent
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"os"
78
"path/filepath"
@@ -324,6 +325,25 @@ func (s *Gopher) processTask(task *GopherTask) error {
324325
s.logger.Errorf("Failed to get target directory path for model %s: %v", modelInfo, err)
325326
return err
326327
}
328+
329+
// Check if the model is already present on shared storage.
330+
// Only for fresh Download tasks — DownloadOverride indicates a spec change
331+
// or failed retry that must re-evaluate the model files.
332+
if task.TaskType == Download && s.isModelAlreadyDownloaded(destPath) {
333+
s.logger.Infof("Model %s already exists at %s (shared storage), skipping OCI download", modelInfo, destPath)
334+
var baseModel *v1beta1.BaseModel
335+
var clusterBaseModel *v1beta1.ClusterBaseModel
336+
if task.BaseModel != nil {
337+
baseModel = task.BaseModel
338+
} else if task.ClusterBaseModel != nil {
339+
clusterBaseModel = task.ClusterBaseModel
340+
}
341+
if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil {
342+
s.logger.Errorf("Failed to parse and update model config for pre-existing model: %v", err)
343+
}
344+
break
345+
}
346+
327347
err = utils.Retry(s.downloadRetry, 100*time.Millisecond, func() error {
328348
downloadErr := s.downloadModel(ctx, osUri, destPath, task)
329349
if downloadErr != nil {
@@ -971,6 +991,29 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask,
971991
// Create destination path
972992
destPath := getDestPath(&baseModelSpec, s.modelRootDir)
973993

994+
// Check if the model is already present on shared storage (e.g., another node
995+
// already downloaded it to the same NFS/shared filesystem path). When storage is
996+
// shared across nodes, each model-agent would otherwise independently re-download
997+
// from HuggingFace, causing rate-limiting and hours of unnecessary I/O.
998+
// Only for fresh Download tasks — DownloadOverride indicates a spec change
999+
// or failed retry that must re-evaluate the model files.
1000+
if task.TaskType == Download && s.isModelAlreadyDownloaded(destPath) {
1001+
s.logger.Infof("Model %s already exists at %s (shared storage), skipping HuggingFace download", modelInfo, destPath)
1002+
1003+
var baseModel *v1beta1.BaseModel
1004+
var clusterBaseModel *v1beta1.ClusterBaseModel
1005+
if task.BaseModel != nil {
1006+
baseModel = task.BaseModel
1007+
} else if task.ClusterBaseModel != nil {
1008+
clusterBaseModel = task.ClusterBaseModel
1009+
}
1010+
1011+
if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil {
1012+
s.logger.Errorf("Failed to parse and update model config for pre-existing model: %v", err)
1013+
}
1014+
return nil
1015+
}
1016+
9741017
// fetch sha value based on model ID from Huggingface model API
9751018
shaStr, isShaAvailable := s.fetchSha(ctx, hfComponents.ModelID, name)
9761019
isReuseEligible, matchedModelTypeAndModeName, parentPath := s.isEligibleForOptimization(ctx, task, baseModelSpec, modelType, namespace, isShaAvailable, shaStr, name)
@@ -1440,3 +1483,96 @@ func (s *Gopher) isRemoveParentArtifactDirectory(ctx context.Context, hasChildre
14401483
s.logger.Infof("parent entry %s:%s exists on node configmap: %v", parentName, parentDir, exists)
14411484
return !exists
14421485
}
1486+
1487+
// isModelAlreadyDownloaded checks whether the model files are already present at
1488+
// destPath. This handles the shared-storage case: when multiple nodes mount the
1489+
// same filesystem (e.g., NFS at /storage/models), the first node that finishes an
1490+
// HF download writes the files once. Subsequent nodes should detect the existing
1491+
// files and skip re-downloading.
1492+
//
1493+
// The check is deliberately conservative — it requires:
1494+
// 1. The directory exists
1495+
// 2. A config.json file exists
1496+
// 3. At least one model weight file exists (.safetensors, .bin, .pt, .gguf)
1497+
// 4. If model.safetensors.index.json exists, ALL expected shards must be present
1498+
//
1499+
// This avoids false positives from partially-downloaded directories.
1500+
func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool {
1501+
// Check if directory exists
1502+
info, err := os.Stat(destPath)
1503+
if err != nil || !info.IsDir() {
1504+
return false
1505+
}
1506+
1507+
// Check for config.json (primary indicator of a complete HF download).
1508+
// Use err != nil (not os.IsNotExist) so that NFS I/O errors and permission
1509+
// errors are treated conservatively as "not present" rather than silently
1510+
// falling through as "exists".
1511+
configPath := filepath.Join(destPath, "config.json")
1512+
if _, err := os.Stat(configPath); err != nil {
1513+
return false
1514+
}
1515+
1516+
// Read directory entries once for all subsequent checks
1517+
entries, err := os.ReadDir(destPath)
1518+
if err != nil {
1519+
return false
1520+
}
1521+
1522+
// Build a set of filenames for fast lookup
1523+
fileSet := make(map[string]bool, len(entries))
1524+
for _, entry := range entries {
1525+
if !entry.IsDir() {
1526+
fileSet[entry.Name()] = true
1527+
}
1528+
}
1529+
1530+
// If model.safetensors.index.json exists, verify ALL expected shards are present.
1531+
// This is the strongest completeness check — it catches partial downloads.
1532+
indexPath := filepath.Join(destPath, "model.safetensors.index.json")
1533+
if _, err := os.Stat(indexPath); err == nil {
1534+
indexData, err := os.ReadFile(indexPath)
1535+
if err != nil {
1536+
s.logger.Warnf("Failed to read model index file %s: %v", indexPath, err)
1537+
return false
1538+
}
1539+
1540+
var index struct {
1541+
WeightMap map[string]string `json:"weight_map"`
1542+
}
1543+
if err := json.Unmarshal(indexData, &index); err != nil {
1544+
s.logger.Warnf("Failed to parse model index file %s: %v", indexPath, err)
1545+
return false
1546+
}
1547+
1548+
// Collect unique shard filenames from weight_map values
1549+
expectedShards := make(map[string]bool)
1550+
for _, shard := range index.WeightMap {
1551+
expectedShards[shard] = true
1552+
}
1553+
1554+
// Verify every expected shard exists on disk
1555+
for shard := range expectedShards {
1556+
if !fileSet[shard] {
1557+
s.logger.Infof("Model at %s is missing shard %s (expected %d shards), not treating as complete",
1558+
destPath, shard, len(expectedShards))
1559+
return false
1560+
}
1561+
}
1562+
1563+
s.logger.Infof("Model at %s has all %d expected shards from index", destPath, len(expectedShards))
1564+
return true
1565+
}
1566+
1567+
// No index file — fall back to checking for at least one weight file
1568+
weightExtensions := []string{".safetensors", ".bin", ".pt", ".gguf"}
1569+
for name := range fileSet {
1570+
for _, ext := range weightExtensions {
1571+
if strings.HasSuffix(name, ext) {
1572+
return true
1573+
}
1574+
}
1575+
}
1576+
1577+
return false
1578+
}

pkg/modelagent/gopher_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"os"
9+
"path/filepath"
810
"testing"
911

1012
"k8s.io/apimachinery/pkg/runtime/schema"
@@ -961,6 +963,100 @@ func TestIsEligibleForOptimization_NoMatch(t *testing.T) {
961963
assert.Empty(t, parent)
962964
}
963965

966+
func TestIsModelAlreadyDownloaded(t *testing.T) {
967+
logger, _ := zap.NewDevelopment()
968+
sugaredLogger := logger.Sugar()
969+
defer func() { _ = sugaredLogger.Sync() }()
970+
971+
gopher := &Gopher{logger: sugaredLogger}
972+
973+
t.Run("nonexistent directory returns false", func(t *testing.T) {
974+
assert.False(t, gopher.isModelAlreadyDownloaded("/nonexistent/path/that/does/not/exist"))
975+
})
976+
977+
t.Run("empty directory returns false", func(t *testing.T) {
978+
dir := t.TempDir()
979+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
980+
})
981+
982+
t.Run("directory with only config.json returns false (no weights)", func(t *testing.T) {
983+
dir := t.TempDir()
984+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
985+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
986+
})
987+
988+
t.Run("directory with only weights returns false (no config.json)", func(t *testing.T) {
989+
dir := t.TempDir()
990+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("weight data"), 0644))
991+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
992+
})
993+
994+
t.Run("directory with config.json and safetensors returns true", func(t *testing.T) {
995+
dir := t.TempDir()
996+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
997+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("weight data"), 0644))
998+
assert.True(t, gopher.isModelAlreadyDownloaded(dir))
999+
})
1000+
1001+
t.Run("directory with config.json and .bin returns true", func(t *testing.T) {
1002+
dir := t.TempDir()
1003+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
1004+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "pytorch_model.bin"), []byte("weight data"), 0644))
1005+
assert.True(t, gopher.isModelAlreadyDownloaded(dir))
1006+
})
1007+
1008+
t.Run("directory with config.json and .pt returns true", func(t *testing.T) {
1009+
dir := t.TempDir()
1010+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
1011+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.pt"), []byte("weight data"), 0644))
1012+
assert.True(t, gopher.isModelAlreadyDownloaded(dir))
1013+
})
1014+
1015+
t.Run("directory with config.json and .gguf returns true", func(t *testing.T) {
1016+
dir := t.TempDir()
1017+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
1018+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.gguf"), []byte("weight data"), 0644))
1019+
assert.True(t, gopher.isModelAlreadyDownloaded(dir))
1020+
})
1021+
1022+
t.Run("file path instead of directory returns false", func(t *testing.T) {
1023+
dir := t.TempDir()
1024+
filePath := filepath.Join(dir, "somefile")
1025+
assert.NoError(t, os.WriteFile(filePath, []byte("data"), 0644))
1026+
assert.False(t, gopher.isModelAlreadyDownloaded(filePath))
1027+
})
1028+
1029+
// Shard completeness tests using model.safetensors.index.json
1030+
1031+
t.Run("index with all shards present returns true", func(t *testing.T) {
1032+
dir := t.TempDir()
1033+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
1034+
index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}`
1035+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644))
1036+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644))
1037+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("data"), 0644))
1038+
assert.True(t, gopher.isModelAlreadyDownloaded(dir))
1039+
})
1040+
1041+
t.Run("index with missing shard returns false", func(t *testing.T) {
1042+
dir := t.TempDir()
1043+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
1044+
index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}`
1045+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644))
1046+
// Only write shard 1, shard 2 is missing
1047+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644))
1048+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
1049+
})
1050+
1051+
t.Run("malformed index file returns false", func(t *testing.T) {
1052+
dir := t.TempDir()
1053+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
1054+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(`{invalid json`), 0644))
1055+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00001.safetensors"), []byte("data"), 0644))
1056+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
1057+
})
1058+
}
1059+
9641060
func TestIsEligibleForOptimization_AlwaysDownloadNotEligible(t *testing.T) {
9651061
nodeName := "node-1"
9661062
sha := "123abc"

0 commit comments

Comments
 (0)