diff --git a/cmd/model-agent/main.go b/cmd/model-agent/main.go index 08909beb..6e1b03da 100644 --- a/cmd/model-agent/main.go +++ b/cmd/model-agent/main.go @@ -43,6 +43,7 @@ type config struct { numDownloadWorker int namespace string logLevel string + downloadTimeout time.Duration } // Logger type alias for zap.SugaredLogger @@ -73,6 +74,7 @@ func init() { rootCmd.PersistentFlags().IntVar(&cfg.numDownloadWorker, "num-download-worker", 5, "Number of download workers") rootCmd.PersistentFlags().StringVar(&cfg.namespace, "namespace", "ome", "Kubernetes namespace to use") rootCmd.PersistentFlags().StringVar(&cfg.logLevel, "log-level", "info", "Log level (debug, info, warn, error)") + rootCmd.PersistentFlags().DurationVar(&cfg.downloadTimeout, "download-timeout", 6*time.Hour, "Maximum time allowed for a single model download before it is cancelled") _ = v.BindPFlags(rootCmd.PersistentFlags()) v.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) @@ -268,6 +270,9 @@ func initializeComponents( logger, baseModelInformer.Lister(), clusterBaseModelInformer.Lister(), + cfg.nodeName, + cfg.namespace, + cfg.downloadTimeout, ) if err != nil { return nil, nil, fmt.Errorf("failed to create gopher: %w", err) diff --git a/cmd/model-agent/main_test.go b/cmd/model-agent/main_test.go index db035df7..6b5b0507 100644 --- a/cmd/model-agent/main_test.go +++ b/cmd/model-agent/main_test.go @@ -5,6 +5,7 @@ import ( "net/http/httptest" "os" "testing" + "time" "github.com/prometheus/client_golang/prometheus" "github.com/spf13/cobra" @@ -113,6 +114,7 @@ func TestDefaultConfig(t *testing.T) { testCmd.Flags().StringVar(&cfg.downloadAuthType, "download-auth-type", "instance-principal", "authentication method for model download") testCmd.Flags().IntVar(&cfg.numDownloadWorker, "num-download-worker", 3, "number of download workers") testCmd.Flags().StringVar(&cfg.namespace, "namespace", "ome", "the namespace of the ome model agents daemon set") + testCmd.Flags().DurationVar(&cfg.downloadTimeout, 
"download-timeout", 6*time.Hour, "maximum time for a single model download") // Call initConfig to set cfg.nodeName initConfig(nil, nil) @@ -127,6 +129,7 @@ func TestDefaultConfig(t *testing.T) { assert.Equal(t, "instance-principal", cfg.downloadAuthType) assert.Equal(t, 3, cfg.numDownloadWorker) assert.Equal(t, "ome", cfg.namespace) + assert.Equal(t, 6*time.Hour, cfg.downloadTimeout) } func TestInitializeLogger(t *testing.T) { diff --git a/pkg/modelagent/gopher.go b/pkg/modelagent/gopher.go index ff48c041..05fe2e66 100644 --- a/pkg/modelagent/gopher.go +++ b/pkg/modelagent/gopher.go @@ -2,18 +2,23 @@ package modelagent import ( "context" + "encoding/json" "fmt" + "math/rand" "os" "path/filepath" "strings" "sync" "sync/atomic" + "syscall" "time" "k8s.io/apimachinery/pkg/labels" "github.com/oracle/oci-go-sdk/v65/objectstorage" "go.uber.org/zap" + coordinationv1 "k8s.io/api/coordination/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" @@ -63,6 +68,18 @@ type Gopher struct { // Track active downloads for cancellation activeDownloads map[string]context.CancelFunc // key: model UID activeDownloadsMutex sync.RWMutex + + // Shared storage coordination: when modelRootDir is on a shared filesystem + // (NFS, GPFS, CephFS, Lustre), only the download leader should download. + // Other agents wait with jitter and recheck for files on disk. + isSharedStorage bool + nodeName string + namespace string + + // downloadTimeout is the maximum time allowed for a single model download. + // Prevents stuck downloads (e.g., xet retrying 403 errors forever) from + // blocking workers indefinitely. 
+ downloadTimeout time.Duration } const ( @@ -83,11 +100,22 @@ func NewGopher( metrics *Metrics, logger *zap.SugaredLogger, baseModelLister omev1beta1lister.BaseModelLister, - clusterBaseModelLister omev1beta1lister.ClusterBaseModelLister) (*Gopher, error) { + clusterBaseModelLister omev1beta1lister.ClusterBaseModelLister, + nodeName string, + namespace string, + downloadTimeout time.Duration) (*Gopher, error) { if xetConfig == nil { return nil, fmt.Errorf("xet hugging face config cannot be nil") } + if downloadTimeout <= 0 { + return nil, fmt.Errorf("downloadTimeout must be positive, got %v", downloadTimeout) + } + + shared := isSharedFilesystem(modelRootDir, logger) + if shared { + logger.Infof("Detected shared filesystem at %s — download leader election enabled", modelRootDir) + } return &Gopher{ modelConfigParser: modelConfigParser, @@ -105,6 +133,10 @@ func NewGopher( activeDownloads: make(map[string]context.CancelFunc), baseModelLister: baseModelLister, clusterBaseModelLister: clusterBaseModelLister, + isSharedStorage: shared, + nodeName: nodeName, + namespace: namespace, + downloadTimeout: downloadTimeout, }, nil } @@ -275,8 +307,9 @@ func (s *Gopher) processTask(task *GopherTask) error { // Continue with download anyway } - // Create a cancellable context for this download - ctx, cancel = context.WithCancel(context.Background()) + // Create a context with timeout for this download to prevent stuck + // downloads (e.g., xet retrying 403 errors) from blocking workers forever. + ctx, cancel = context.WithTimeout(context.Background(), s.downloadTimeout) // Register the cancel function s.activeDownloadsMutex.Lock() @@ -324,10 +357,26 @@ func (s *Gopher) processTask(task *GopherTask) error { s.logger.Errorf("Failed to get target directory path for model %s: %v", modelInfo, err) return err } + + // Check if the model is already present on shared storage. 
+ // Only for fresh Download tasks — DownloadOverride indicates a spec change + // or failed retry that must re-evaluate the model files. + if task.TaskType == Download && s.isModelAlreadyDownloaded(destPath) { + s.logger.Infof("Model %s already exists at %s (shared storage), skipping OCI download", modelInfo, destPath) + if err := s.skipDownloadAndUpdateConfig(destPath, task); err != nil { + return err + } + break + } + err = utils.Retry(s.downloadRetry, 100*time.Millisecond, func() error { + // Short-circuit if context is already done (timeout or cancel) + if ctx.Err() != nil { + return ctx.Err() + } downloadErr := s.downloadModel(ctx, osUri, destPath, task) if downloadErr != nil { - // Check if context was cancelled + // Check if context was cancelled during download if ctx.Err() != nil { s.logger.Infof("Download cancelled for model %s: %v", modelInfo, ctx.Err()) return ctx.Err() @@ -338,12 +387,19 @@ func (s *Gopher) processTask(task *GopherTask) error { return downloadErr }) if err != nil { - s.logger.Errorf("All download attempts failed for model %s: %v", modelInfo, err) - - // Record download failure in metrics + // Record download failure in metrics with specific error classification errorType := "download_error" - if strings.Contains(err.Error(), "MD5") { + if ctx.Err() == context.Canceled { + errorType = "download_cancelled" + s.logger.Infof("Download cancelled for OCI model %s: %v", modelInfo, err) + } else if ctx.Err() == context.DeadlineExceeded { + errorType = "download_timeout" + s.logger.Errorf("Download timed out for OCI model %s after %v: %v", modelInfo, s.downloadTimeout, err) + } else if strings.Contains(err.Error(), "MD5") { errorType = "md5_verification_error" + s.logger.Errorf("All download attempts failed for model %s: %v", modelInfo, err) + } else { + s.logger.Errorf("All download attempts failed for model %s: %v", modelInfo, err) } s.metrics.RecordFailedDownload(modelType, namespace, name, errorType) @@ -971,6 +1027,33 @@ func (s *Gopher) 
processHuggingFaceModel(ctx context.Context, task *GopherTask, // Create destination path destPath := getDestPath(&baseModelSpec, s.modelRootDir) + // Check if the model is already present on shared storage (e.g., another node + // already downloaded it to the same NFS/shared filesystem path). When storage is + // shared across nodes, each model-agent would otherwise independently re-download + // from HuggingFace, causing rate-limiting and hours of unnecessary I/O. + // Only for fresh Download tasks — DownloadOverride indicates a spec change + // or failed retry that must re-evaluate the model files. + if task.TaskType == Download { + if s.isModelAlreadyDownloaded(destPath) { + s.logger.Infof("Model %s already exists at %s (shared storage), skipping HuggingFace download", modelInfo, destPath) + return s.skipDownloadAndUpdateConfig(destPath, task) + } + + // On shared storage, only the download leader should proceed. Non-leaders + // wait for the leader to finish and then recheck for files on disk. 
+ if s.isSharedStorage && !s.isDownloadLeader(ctx, modelInfo) { + if s.waitForSharedStorageModel(ctx, destPath, modelInfo) { + s.logger.Infof("Model %s appeared on shared storage at %s after waiting for leader", modelInfo, destPath) + return s.skipDownloadAndUpdateConfig(destPath, task) + } + if ctx.Err() != nil { + return fmt.Errorf("download cancelled while waiting for shared storage leader: %w", ctx.Err()) + } + // Timed out waiting — fall through to download as a fallback + s.logger.Warnf("Model %s not found after waiting for leader, proceeding with own download", modelInfo) + } + } + // fetch sha value based on model ID from Huggingface model API shaStr, isShaAvailable := s.fetchSha(ctx, hfComponents.ModelID, name) isReuseEligible, matchedModelTypeAndModeName, parentPath := s.isEligibleForOptimization(ctx, task, baseModelSpec, modelType, namespace, isShaAvailable, shaStr, name) @@ -1129,9 +1212,25 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, // when status becomes Ready/Failed, ensuring the controller sees the final progress downloadPath, err := xet.SnapshotDownloadWithProgress(ctx, config, progressHandler, progressThrottle) + // Always release the download leader lease after a download attempt + // (success or failure). On failure, this lets non-leaders detect "no lease" + // and try their own download instead of waiting the full 5-minute expiry. + // Use a fresh context since the download context may be near its deadline. 
+ if s.isSharedStorage { + leaseCtx, leaseCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer leaseCancel() + s.releaseDownloadLease(leaseCtx, modelInfo) + } + if err != nil { // Check error type for better handling - if strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "rate limit") { + if ctx.Err() == context.Canceled { + s.logger.Infof("Download cancelled for HuggingFace model %s: %v", modelInfo, err) + s.metrics.RecordFailedDownload(modelType, namespace, name, "download_cancelled") + } else if ctx.Err() == context.DeadlineExceeded { + s.logger.Errorf("Download timed out for HuggingFace model %s after %v: %v", modelInfo, s.downloadTimeout, err) + s.metrics.RecordFailedDownload(modelType, namespace, name, "download_timeout") + } else if strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "rate limit") { s.logger.Warnf("Rate limited while downloading HuggingFace model %s: %v", modelInfo, err) s.metrics.RecordRateLimit(modelType, namespace, name, 30*time.Second) // Estimate s.metrics.RecordFailedDownload(modelType, namespace, name, "rate_limit_error") @@ -1146,6 +1245,7 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, s.logger.Infof("Successfully downloaded HuggingFace model %s to %s", modelInfo, downloadPath) + artifact = s.modelConfigParser.buildArtifactAttribute(shaStr, s.configMapReconciler.getModelConfigMapKey(task.BaseModel, task.ClusterBaseModel), destPath, childrenPaths) } @@ -1440,3 +1540,442 @@ func (s *Gopher) isRemoveParentArtifactDirectory(ctx context.Context, hasChildre s.logger.Infof("parent entry %s:%s exists on node configmap: %v", parentName, parentDir, exists) return !exists } + +// skipDownloadAndUpdateConfig handles the case where model files already exist +// at the destination path (e.g., downloaded by another node on shared storage, +// or left from a previous run). 
It parses the model config and updates the +// ConfigMap, bypassing the download step. +func (s *Gopher) skipDownloadAndUpdateConfig(destPath string, task *GopherTask) error { + var baseModel *v1beta1.BaseModel + var clusterBaseModel *v1beta1.ClusterBaseModel + if task.BaseModel != nil { + baseModel = task.BaseModel + } else if task.ClusterBaseModel != nil { + clusterBaseModel = task.ClusterBaseModel + } + if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil { + return fmt.Errorf("model files exist at %s but config update failed: %w", destPath, err) + } + return nil +} + +// isModelAlreadyDownloaded checks whether the model files are already present at +// destPath. This handles the shared-storage case: when multiple nodes mount the +// same filesystem (e.g., NFS at /storage/models), the first node that finishes an +// HF download writes the files once. Subsequent nodes should detect the existing +// files and skip re-downloading. +// +// Supports three model layouts: +// 1. Sharded safetensors: model.safetensors.index.json lists all expected shards. +// 2. Diffusion pipelines: model_index.json lists component subdirectories, each +// containing its own config and weight files. +// 3. Single-file models: no index file, but config.json + at least one weight +// file (.safetensors, .bin, .pt, .gguf) present. Note: this fallback cannot +// verify shard completeness for multi-shard models without an index file. +// +// All filesystem checks treat errors conservatively as "not present" so that +// NFS I/O or permission errors fall through to the normal download path rather +// than silently skipping the download. 
+func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool { + // Check if directory exists + info, err := os.Stat(destPath) + if err != nil { + if os.IsNotExist(err) { + s.logger.Infof("isModelAlreadyDownloaded(%s): directory does not exist", destPath) + } else { + s.logger.Warnf("isModelAlreadyDownloaded(%s): failed to stat directory: %v", destPath, err) + } + return false + } + if !info.IsDir() { + s.logger.Warnf("isModelAlreadyDownloaded(%s): path exists but is not a directory", destPath) + return false + } + + // Try each layout in order of specificity. + // When an index file exists, its verdict is authoritative — we don't fall + // through to a weaker check that might accept an incomplete download. + + // 1. Sharded safetensors with model.safetensors.index.json + indexPath := filepath.Join(destPath, "model.safetensors.index.json") + if _, err := os.Stat(indexPath); err == nil { + return s.checkSafetensorsIndex(destPath) + } + + // 2. Diffusion pipeline with model_index.json + diffIndexPath := filepath.Join(destPath, "model_index.json") + if _, err := os.Stat(diffIndexPath); err == nil { + return s.checkDiffusionIndex(destPath) + } + + // 3. Fallback: config.json + at least one weight file + if s.checkConfigAndWeights(destPath) { + return true + } + + s.logger.Infof("isModelAlreadyDownloaded(%s): no known layout matched, will proceed with download", destPath) + return false +} + +// checkSafetensorsIndex verifies a sharded safetensors model by reading +// model.safetensors.index.json and ensuring every listed shard file exists. 
+func (s *Gopher) checkSafetensorsIndex(destPath string) bool { + indexPath := filepath.Join(destPath, "model.safetensors.index.json") + indexData, err := os.ReadFile(indexPath) + if err != nil { + s.logger.Warnf("checkSafetensorsIndex(%s): failed to read index file: %v", destPath, err) + return false + } + + var index struct { + WeightMap map[string]string `json:"weight_map"` + } + if err := json.Unmarshal(indexData, &index); err != nil { + s.logger.Warnf("checkSafetensorsIndex(%s): failed to parse index file (may be mid-write by another node on shared storage): %v", destPath, err) + return false + } + if len(index.WeightMap) == 0 { + s.logger.Infof("checkSafetensorsIndex(%s): index has empty weight_map", destPath) + return false + } + + entries, err := os.ReadDir(destPath) + if err != nil { + s.logger.Warnf("checkSafetensorsIndex(%s): failed to read directory: %v", destPath, err) + return false + } + fileSet := make(map[string]bool, len(entries)) + for _, entry := range entries { + if !entry.IsDir() { + fileSet[entry.Name()] = true + } + } + + // Verify every expected shard exists + expectedShards := make(map[string]bool) + for _, shard := range index.WeightMap { + expectedShards[shard] = true + } + for shard := range expectedShards { + if !fileSet[shard] { + s.logger.Infof("checkSafetensorsIndex(%s): missing shard %s (expected %d shards), not treating as complete", + destPath, shard, len(expectedShards)) + return false + } + } + + s.logger.Infof("checkSafetensorsIndex(%s): verified, all %d shards present", destPath, len(expectedShards)) + return true +} + +// checkDiffusionIndex verifies a diffusion pipeline model by reading +// model_index.json and ensuring every listed component subdirectory exists +// and is non-empty. 
+func (s *Gopher) checkDiffusionIndex(destPath string) bool { + indexPath := filepath.Join(destPath, "model_index.json") + indexData, err := os.ReadFile(indexPath) + if err != nil { + s.logger.Warnf("checkDiffusionIndex(%s): failed to read index file: %v", destPath, err) + return false + } + + var index map[string]interface{} + if err := json.Unmarshal(indexData, &index); err != nil { + s.logger.Warnf("checkDiffusionIndex(%s): failed to parse index file (may be mid-write by another node on shared storage): %v", destPath, err) + return false + } + + // Components are top-level keys that don't start with "_" (e.g. "transformer", + // "vae", "text_encoder"). Keys like "_class_name" and "_diffusers_version" are + // metadata. + componentCount := 0 + for key, val := range index { + if len(key) > 0 && key[0] == '_' { + continue + } + // Components are arrays like ["diffusers", "ClassName"]. Skip: + // - null values (disabled components) + // - non-array values like floats (e.g. "boundary_ratio": 0.9) + // - arrays of nulls (e.g. 
"image_encoder": [null, null]) + arr, isArray := val.([]interface{}) + if !isArray || len(arr) < 2 { + continue + } + // Check that the first element is a non-null string (library name) + if _, isStr := arr[0].(string); !isStr { + continue + } + // Guard against path traversal from untrusted JSON keys + if filepath.Base(key) != key { + s.logger.Warnf("checkDiffusionIndex(%s): skipping suspicious component key %q", destPath, key) + return false + } + componentCount++ + compDir := filepath.Join(destPath, key) + dirInfo, err := os.Stat(compDir) + if err != nil { + s.logger.Warnf("checkDiffusionIndex(%s): failed to stat component directory %s: %v", destPath, key, err) + return false + } + if !dirInfo.IsDir() { + s.logger.Warnf("checkDiffusionIndex(%s): component %s exists but is not a directory", destPath, key) + return false + } + // Check that the component directory is not empty + compEntries, err := os.ReadDir(compDir) + if err != nil { + s.logger.Warnf("checkDiffusionIndex(%s): failed to read component directory %s: %v", destPath, key, err) + return false + } + if len(compEntries) == 0 { + s.logger.Infof("checkDiffusionIndex(%s): component directory %s is empty", destPath, key) + return false + } + } + + if componentCount == 0 { + s.logger.Infof("checkDiffusionIndex(%s): model_index.json has no components", destPath) + return false + } + + s.logger.Infof("checkDiffusionIndex(%s): verified, all %d components present", destPath, componentCount) + return true +} + +// checkConfigAndWeights is a fallback check for single-file models or models +// without an index file. It verifies config.json exists and at least one +// weight file (.safetensors, .bin, .pt, .gguf) is present. 
+func (s *Gopher) checkConfigAndWeights(destPath string) bool { + configPath := filepath.Join(destPath, "config.json") + if _, err := os.Stat(configPath); err != nil { + s.logger.Infof("checkConfigAndWeights(%s): no config.json found", destPath) + return false + } + + entries, err := os.ReadDir(destPath) + if err != nil { + s.logger.Warnf("checkConfigAndWeights(%s): failed to read directory: %v", destPath, err) + return false + } + + weightExtensions := map[string]bool{ + ".safetensors": true, + ".bin": true, + ".pt": true, + ".gguf": true, + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + ext := filepath.Ext(entry.Name()) + if weightExtensions[ext] { + s.logger.Warnf("checkConfigAndWeights(%s): using fallback heuristic (no index file found). "+ + "config.json + weight file %s found, but shard completeness cannot be fully verified. "+ + "If the model fails to load, re-trigger a DownloadOverride to force re-download.", destPath, entry.Name()) + return true + } + } + + s.logger.Infof("checkConfigAndWeights(%s): config.json exists but no weight files found", destPath) + return false +} + +// isSharedFilesystem detects whether the given path is on a shared/network +// filesystem by checking the filesystem type via syscall.Statfs. +// Known shared filesystem types: NFS, GPFS, CephFS, Lustre, GlusterFS, FUSE. +// Note: filesystem type detection via magic numbers only works on Linux. +// On macOS/Darwin, Statfs_t has a different layout and this will return false. 
+func isSharedFilesystem(path string, logger *zap.SugaredLogger) bool { + var stat syscall.Statfs_t + if err := syscall.Statfs(path, &stat); err != nil { + logger.Warnf("isSharedFilesystem(%s): syscall.Statfs failed: %v — shared storage detection disabled", path, err) + return false + } + // Filesystem magic numbers (from linux/magic.h and kernel sources) + switch stat.Type { + case 0x6969: // NFS_SUPER_MAGIC + return true + case 0x47504653: // GPFS (IBM Spectrum Scale) + return true + case 0x00C36400: // CEPH_SUPER_MAGIC + return true + case 0x0BD00BD0: // LUSTRE_SUPER_MAGIC + return true + case 0x65735546: // FUSE_SUPER_MAGIC (commonly used for network mounts) + return true + case 0x6A656A62: // GlusterFS + return true + default: + return false + } +} + +const ( + // downloadLeaderLeasePrefix is the prefix for per-model K8s Leases used for + // leader election on shared storage. Each model gets its own lease so different + // models can be downloaded in parallel by different nodes. + downloadLeaderLeasePrefix = "model-download-" + + // downloadLeaderLeaseDuration is how long a leader holds a per-model lease. + downloadLeaderLeaseDuration = 5 * time.Minute + + // sharedStorageRecheckInterval is how often non-leaders recheck for files. + sharedStorageRecheckInterval = 30 * time.Second + + // sharedStorageMaxJitter is the max random jitter added before rechecks. + sharedStorageMaxJitter = 15 * time.Second +) + +// sanitizeLeaseName converts a model identifier (e.g., "google/gemma-4-31B-it") +// into a valid K8s resource name (lowercase, no slashes, max 253 chars).
+func sanitizeLeaseName(modelInfo string) string { + name := strings.ToLower(modelInfo) + name = strings.ReplaceAll(name, "/", "-") + name = strings.ReplaceAll(name, "_", "-") + name = strings.ReplaceAll(name, " ", "-") + name = strings.ReplaceAll(name, ".", "-") + // Strip any remaining non-alphanumeric/dash characters + filtered := make([]byte, 0, len(name)) + for i := 0; i < len(name); i++ { + c := name[i] + if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' { + filtered = append(filtered, c) + } + } + name = string(filtered) + // Trim leading/trailing dashes + name = strings.Trim(name, "-") + // K8s names must be <= 253 chars + if len(name) > 200 { + name = name[:200] + } + return downloadLeaderLeasePrefix + name +} + +// isDownloadLeader checks if this node currently holds the per-model download +// Lease. Each model gets its own lease so different models can be downloaded in +// parallel by different nodes. If the lease doesn't exist, it tries to create it +// (becoming leader). If held by this node, it renews and returns true. If held by +// another node, it checks for expiry and attempts to take over; otherwise returns false. 
+func (s *Gopher) isDownloadLeader(ctx context.Context, modelInfo string) bool { + leaseName := sanitizeLeaseName(modelInfo) + leasesClient := s.kubeClient.CoordinationV1().Leases(s.namespace) + now := metav1.NewMicroTime(time.Now()) + + lease, err := leasesClient.Get(ctx, leaseName, metav1.GetOptions{}) + if err != nil { + if !apierrors.IsNotFound(err) { + // API server error — coordination unavailable, download independently + s.logger.Warnf("Failed to check download leader lease (API error): %v — this node will download independently", err) + return true + } + // Lease doesn't exist — try to create it and become leader + leaseDuration := int32(downloadLeaderLeaseDuration.Seconds()) + newLease := &coordinationv1.Lease{ + ObjectMeta: metav1.ObjectMeta{ + Name: leaseName, + Namespace: s.namespace, + }, + Spec: coordinationv1.LeaseSpec{ + HolderIdentity: &s.nodeName, + LeaseDurationSeconds: &leaseDuration, + AcquireTime: &now, + RenewTime: &now, + }, + } + _, createErr := leasesClient.Create(ctx, newLease, metav1.CreateOptions{}) + if createErr != nil { + s.logger.Infof("Failed to acquire download leader lease for %s (another node won): %v", modelInfo, createErr) + return false + } + s.logger.Infof("Acquired download leader lease for %s — this node (%s) will download", modelInfo, s.nodeName) + return true + } + + // Lease exists — check if we hold it or if it's expired + if lease.Spec.HolderIdentity != nil && *lease.Spec.HolderIdentity == s.nodeName { + // We hold it — renew + lease.Spec.RenewTime = &now + _, err = leasesClient.Update(ctx, lease, metav1.UpdateOptions{}) + if err != nil { + if apierrors.IsConflict(err) || apierrors.IsNotFound(err) { + s.logger.Warnf("Lost download leader lease during renewal: %v — yielding leadership", err) + return false + } + s.logger.Warnf("Failed to renew download leader lease (transient error): %v — proceeding as leader", err) + } + return true + } + + // Another node holds it — check if expired + if lease.Spec.RenewTime != nil 
&& lease.Spec.LeaseDurationSeconds != nil { + expiry := lease.Spec.RenewTime.Time.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second) + if time.Now().After(expiry) { + // Expired — take over + lease.Spec.HolderIdentity = &s.nodeName + lease.Spec.AcquireTime = &now + lease.Spec.RenewTime = &now + _, err = leasesClient.Update(ctx, lease, metav1.UpdateOptions{}) + if err != nil { + s.logger.Infof("Failed to take over expired download leader lease: %v", err) + return false + } + s.logger.Infof("Took over expired download leader lease for %s — this node (%s) will download", modelInfo, s.nodeName) + return true + } + } + + holderID := "" + if lease.Spec.HolderIdentity != nil { + holderID = *lease.Spec.HolderIdentity + } + s.logger.Infof("Download leader lease for %s held by %s — this node (%s) will wait for shared storage files", + modelInfo, holderID, s.nodeName) + return false +} + +// releaseDownloadLease deletes the per-model download leader lease after a +// successful download. This allows waiting agents to immediately detect "no lease" +// → check files on disk → skip, instead of waiting for the lease to expire. +func (s *Gopher) releaseDownloadLease(ctx context.Context, modelInfo string) { + leaseName := sanitizeLeaseName(modelInfo) + err := s.kubeClient.CoordinationV1().Leases(s.namespace).Delete(ctx, leaseName, metav1.DeleteOptions{}) + if err != nil { + s.logger.Warnf("Failed to release download leader lease %s: %v (non-critical, lease will expire)", leaseName, err) + } else { + s.logger.Infof("Released download leader lease %s after successful download", leaseName) + } +} + +// waitForSharedStorageModel waits for a model to appear on shared storage, +// with jitter and periodic rechecks. Returns true if the model appeared (another +// node downloaded it), false if the context was cancelled or max wait exceeded. +// Maximum wait time is downloadLeaderLeaseDuration + 30s (currently 5m30s). 
+func (s *Gopher) waitForSharedStorageModel(ctx context.Context, destPath string, modelInfo string) bool { + maxWait := downloadLeaderLeaseDuration + 30*time.Second + deadline := time.Now().Add(maxWait) + + for time.Now().Before(deadline) { + // Add random jitter to avoid thundering herd on recheck + jitter := time.Duration(rand.Int63n(int64(sharedStorageMaxJitter))) + s.logger.Infof("Shared storage: waiting %v before rechecking %s for model %s", sharedStorageRecheckInterval+jitter, destPath, modelInfo) + + timer := time.NewTimer(sharedStorageRecheckInterval + jitter) + select { + case <-ctx.Done(): + timer.Stop() + s.logger.Infof("Shared storage: wait cancelled for model %s at %s: %v", modelInfo, destPath, ctx.Err()) + return false + case <-timer.C: + } + + if s.isModelAlreadyDownloaded(destPath) { + s.logger.Infof("Shared storage: model %s appeared at %s (downloaded by leader)", modelInfo, destPath) + return true + } + } + + s.logger.Warnf("Shared storage: timed out waiting for model %s at %s — will attempt own download", modelInfo, destPath) + return false +} diff --git a/pkg/modelagent/gopher_test.go b/pkg/modelagent/gopher_test.go index b6d56f36..c81ba48a 100644 --- a/pkg/modelagent/gopher_test.go +++ b/pkg/modelagent/gopher_test.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "fmt" + "os" + "path/filepath" "testing" "k8s.io/apimachinery/pkg/runtime/schema" @@ -961,6 +963,231 @@ func TestIsEligibleForOptimization_NoMatch(t *testing.T) { assert.Empty(t, parent) } +func TestIsModelAlreadyDownloaded(t *testing.T) { + logger, _ := zap.NewDevelopment() + sugaredLogger := logger.Sugar() + defer func() { _ = sugaredLogger.Sync() }() + + gopher := &Gopher{logger: sugaredLogger} + + t.Run("nonexistent directory returns false", func(t *testing.T) { + assert.False(t, gopher.isModelAlreadyDownloaded("/nonexistent/path/that/does/not/exist")) + }) + + t.Run("empty directory returns false", func(t *testing.T) { + dir := t.TempDir() + assert.False(t, 
gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("directory with only config.json returns false (no weights)", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("file path instead of directory returns false", func(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "somefile") + assert.NoError(t, os.WriteFile(filePath, []byte("data"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(filePath)) + }) + + // --- Sharded safetensors tests --- + + t.Run("safetensors index with all shards present returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("safetensors index with missing shard returns false", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644)) + assert.False(t, 
gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("malformed safetensors index returns false (no fallback)", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(`{invalid json`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00001.safetensors"), []byte("data"), 0644)) + // Index exists but is malformed → authoritative false, no fallback + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("safetensors index with empty weight_map returns false (no fallback)", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(`{"weight_map":{}}`), 0644)) + // Index exists but empty → authoritative false, no fallback + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + // --- Diffusion pipeline tests --- + + t.Run("diffusion model with all components returns true", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"QwenImagePipeline","_diffusers_version":"0.36.0","scheduler":["diffusers","FlowMatchEulerDiscreteScheduler"],"transformer":["diffusers","QwenTransformer2DModel"],"vae":["diffusers","AutoencoderKL"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + // Create component subdirectories with at least one file each + for _, comp := range []string{"scheduler", "transformer", "vae"} { + compDir := filepath.Join(dir, comp) + assert.NoError(t, os.MkdirAll(compDir, 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(compDir, "config.json"), []byte(`{}`), 0644)) + } + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion model with missing component 
returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"QwenImagePipeline","scheduler":["diffusers","Scheduler"],"transformer":["diffusers","Transformer"],"vae":["diffusers","VAE"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + // Only create scheduler and transformer, vae is missing + for _, comp := range []string{"scheduler", "transformer"} { + compDir := filepath.Join(dir, comp) + assert.NoError(t, os.MkdirAll(compDir, 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(compDir, "config.json"), []byte(`{}`), 0644)) + } + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion model with empty component directory returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","transformer":["diffusers","Model"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + assert.NoError(t, os.MkdirAll(filepath.Join(dir, "transformer"), 0755)) + // transformer dir exists but is empty + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion model with null component is skipped", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","transformer":["diffusers","Model"],"safety_checker":null}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + compDir := filepath.Join(dir, "transformer") + assert.NoError(t, os.MkdirAll(compDir, 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(compDir, "model.safetensors"), []byte("data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion model with float and null-array metadata is skipped", func(t *testing.T) { + dir := t.TempDir() + // Real-world pattern: boundary_ratio is a float, image_encoder is [null, null] + index := 
`{"_class_name":"Pipeline","boundary_ratio":0.9,"image_encoder":[null,null],"transformer":["diffusers","Model"],"vae":["diffusers","VAE"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + for _, comp := range []string{"transformer", "vae"} { + compDir := filepath.Join(dir, comp) + assert.NoError(t, os.MkdirAll(compDir, 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(compDir, "config.json"), []byte(`{}`), 0644)) + } + // boundary_ratio and image_encoder should NOT require directories + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + // --- Fallback: config.json + weight file --- + + t.Run("config.json and single safetensors file returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("weight data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("config.json and .bin weight file returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"bert"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "pytorch_model.bin"), []byte("weight data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("config.json and .gguf weight file returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-q4.gguf"), []byte("weight data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("config.json and .pt weight file returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, 
"model.pt"), []byte("weight data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("only weights without config.json returns false", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("weight data"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + // --- Strategy ordering tests --- + + t.Run("safetensors index takes priority over diffusion index", func(t *testing.T) { + dir := t.TempDir() + // Failing safetensors index (missing shard) + stIndex := `{"weight_map":{"w1":"shard-00001.safetensors","w2":"shard-00002.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(stIndex), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "shard-00001.safetensors"), []byte("data"), 0644)) + // shard-00002 is missing + + // Passing diffusion index + diffIndex := `{"_class_name":"Pipeline","encoder":["diffusers","Encoder"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(diffIndex), 0644)) + assert.NoError(t, os.MkdirAll(filepath.Join(dir, "encoder"), 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "encoder", "config.json"), []byte(`{}`), 0644)) + + // Safetensors check is authoritative — must return false despite valid diffusion layout + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("safetensors index with many weights mapping to same shard returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)) + index := `{"weight_map":{"layer1.weight":"model-00001-of-00001.safetensors","layer1.bias":"model-00001-of-00001.safetensors","layer2.weight":"model-00001-of-00001.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, 
"model-00001-of-00001.safetensors"), []byte("data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + // --- Additional diffusion edge cases --- + + t.Run("malformed diffusion index returns false", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{not valid json`), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion index with only metadata keys returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","_diffusers_version":"0.36.0"}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion component exists as file not directory returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","transformer":["diffusers","Model"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + // "transformer" is a file, not a directory + assert.NoError(t, os.WriteFile(filepath.Join(dir, "transformer"), []byte("not a directory"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion index with path traversal component key returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","../etc":["diffusers","Exploit"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) +} + func TestIsEligibleForOptimization_AlwaysDownloadNotEligible(t *testing.T) { nodeName := "node-1" sha := "123abc"