Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/modelagent/configmap_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ func (c *ConfigMapReconciler) FindMatchedModelFromConfigMap(configMap *corev1.Co
searchingError = fmt.Errorf("parentPath value for %q is not a string", k)
continue
}
if strings.ToLower(k) == strings.ToLower(currentModelTypeAndNodeName) {
if strings.EqualFold(k, currentModelTypeAndNodeName) {
continue
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/modelagent/gopher.go
Original file line number Diff line number Diff line change
Expand Up @@ -1189,13 +1189,13 @@ func (s *Gopher) handelReuseArtifactIfNecessary(ctx context.Context, baseModelSp
var err error
// Prioritize searching for the parent path in ClusterBaseModel,
// in the hope that base models in different namespaces can be linked to the same parent path, lowering the chance of downloading the artifact.
if strings.ToLower(modelType) == strings.ToLower(constants.ClusterBaseModel) || strings.ToLower(modelType) == strings.ToLower(constants.BaseModel) {
if strings.EqualFold(modelType, constants.ClusterBaseModel) || strings.EqualFold(modelType, constants.BaseModel) {
matchedModelTypeAndModelName, matchedParentPath, err = s.configMapReconciler.getModelDataByArtifactSha(ctx, shaStr, constants.LowerCaseClusterBaseModel, currentModelTypeAndNodeName)
if err != nil {
s.logger.Warnf("get error when finding matched model in configmap for model : %s: %s", modelName, err)
}
}
if strings.ToLower(modelType) == strings.ToLower(constants.BaseModel) && matchedModelTypeAndModelName == "" {
if strings.EqualFold(modelType, constants.BaseModel) && matchedModelTypeAndModelName == "" {
// build namespaced model type
namespacedModelType := fmt.Sprintf("%s.%s", namespace, constants.LowerCaseBaseModel)
matchedModelTypeAndModelName, matchedParentPath, err = s.configMapReconciler.getModelDataByArtifactSha(ctx, shaStr, namespacedModelType, currentModelTypeAndNodeName)
Expand Down Expand Up @@ -1412,7 +1412,7 @@ Parameters:
*/
func (s *Gopher) removeChildPathFromParentConfigMapIfNecessary(ctx context.Context, hasChildren bool, parentName string, modelTypeAndModelName string, destPath string) {
// If the model has no children and is not its own parent, its path must be removed from the parent entry.
if !hasChildren && strings.ToLower(parentName) != strings.ToLower(modelTypeAndModelName) {
if !hasChildren && !strings.EqualFold(parentName, modelTypeAndModelName) {
err := s.configMapReconciler.updateConfigMapWithRemovedChildPath(ctx, parentName, destPath)
if err != nil {
s.logger.Errorf("failed to remove model %s child path %s from parentName %s", modelTypeAndModelName, destPath, parentName)
Expand Down
57 changes: 0 additions & 57 deletions pkg/modelagent/gopher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@ import (
"fmt"
"testing"

"k8s.io/apimachinery/pkg/runtime/schema"

"go.uber.org/zap/zaptest"

"github.com/stretchr/testify/assert"
"go.uber.org/zap"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
k8sfake "k8s.io/client-go/kubernetes/fake"
Expand Down Expand Up @@ -1295,60 +1292,6 @@ func TestRemoveChildPathFromParentConfigMapIfNecessary_ErrorWhenParentMissing_No
assert.Equal(t, 0, countConfigMapUpdates(client), "no ConfigMap update should occur when parent key missing")
}

// Mocks specialized for isRemoveParentArtifactDirectory tests
// testBaseModelNamespaceLister is an in-memory test double for a
// namespace-scoped BaseModel lister. It serves objects from a plain map.
type testBaseModelNamespaceLister struct {
	// models holds the BaseModel objects returned by List and Get, keyed by name.
	models map[string]*v1beta1.BaseModel
}

// List returns every stored BaseModel. The selector argument is ignored, and
// the result order follows Go's unspecified map iteration order.
func (l testBaseModelNamespaceLister) List(selector labels.Selector) ([]*v1beta1.BaseModel, error) {
	all := make([]*v1beta1.BaseModel, 0, len(l.models))
	for name := range l.models {
		all = append(all, l.models[name])
	}
	return all, nil
}

// Get looks up a BaseModel by name. When no entry exists it returns a
// NotFound API error for the "basemodels" resource in group "ome.io".
func (l testBaseModelNamespaceLister) Get(name string) (*v1beta1.BaseModel, error) {
	model, found := l.models[name]
	if !found {
		return nil, apierrors.NewNotFound(schema.GroupResource{Group: "ome.io", Resource: "basemodels"}, name)
	}
	return model, nil
}

// testBaseModelLister is a test double that hands out per-namespace
// BaseModel listers from a fixed map.
type testBaseModelLister struct {
	// byNS maps a namespace to the lister that BaseModels returns for it.
	byNS map[string]testBaseModelNamespaceLister
}

// List always yields a nil slice and a nil error; the selector is not consulted.
func (l testBaseModelLister) List(selector labels.Selector) ([]*v1beta1.BaseModel, error) {
	var none []*v1beta1.BaseModel
	return none, nil
}

// BaseModels returns the lister registered for namespace. When the namespace
// has no entry, it returns a lister backed by an empty map instead.
func (l testBaseModelLister) BaseModels(namespace string) omev1beta1lister.BaseModelNamespaceLister {
	registered, found := l.byNS[namespace]
	if !found {
		return testBaseModelNamespaceLister{models: map[string]*v1beta1.BaseModel{}}
	}
	return registered
}

// testClusterBaseModelLister is an in-memory test double that serves
// ClusterBaseModel objects from a plain map.
type testClusterBaseModelLister struct {
	// models holds the ClusterBaseModel objects returned by List and Get, keyed by name.
	models map[string]*v1beta1.ClusterBaseModel
}

// List returns every stored ClusterBaseModel. The selector argument is
// ignored, and the result order follows Go's unspecified map iteration order.
func (l testClusterBaseModelLister) List(selector labels.Selector) ([]*v1beta1.ClusterBaseModel, error) {
	all := make([]*v1beta1.ClusterBaseModel, 0, len(l.models))
	for name := range l.models {
		all = append(all, l.models[name])
	}
	return all, nil
}

// Get looks up a ClusterBaseModel by name. When no entry exists it returns a
// NotFound API error for the "clusterbasemodels" resource in group "ome.io".
func (l testClusterBaseModelLister) Get(name string) (*v1beta1.ClusterBaseModel, error) {
	model, found := l.models[name]
	if !found {
		return nil, apierrors.NewNotFound(schema.GroupResource{Group: "ome.io", Resource: "clusterbasemodels"}, name)
	}
	return model, nil
}

func TestIsRemoveParentArtifactDirectory_HasChildren_False(t *testing.T) {
cm := makeConfigMap("node-1", map[string]string{})
g, client := newGopherAndClientWithConfigMap(cm, t)
Expand Down
13 changes: 11 additions & 2 deletions pkg/modelagent/hf_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ var (

// FetchAttributeFromHfModelMetaData retrieves a single top-level attribute from the Hugging Face model metadata endpoint for the provided modelId.
func FetchAttributeFromHfModelMetaData(ctx context.Context, modelId string, attribute string) (interface{}, error) {
modelMetaDataUrl, err := hfModelMetaDataUrl(modelId)
return fetchAttributeFromHfModelMetaDataWithEndpoint(ctx, modelId, attribute, DefaultEndpoint)
}

// fetchAttributeFromHfModelMetaDataWithEndpoint is the internal implementation that accepts a configurable base endpoint.
func fetchAttributeFromHfModelMetaDataWithEndpoint(ctx context.Context, modelId string, attribute string, endpoint string) (interface{}, error) {
modelMetaDataUrl, err := hfModelMetaDataUrlWithEndpoint(modelId, endpoint)
if err != nil {
return nil, fmt.Errorf("failed to build model metadata URL: %s", err)
}
Expand Down Expand Up @@ -253,6 +258,10 @@ func GetHTTPClient() *http.Client {
// Resulting URL format:
// https://huggingface.co/api/models/{modelId}
func hfModelMetaDataUrl(modelId string) (string, error) {
// Delegate to the endpoint-aware builder, passing DefaultEndpoint as the base endpoint.
return hfModelMetaDataUrlWithEndpoint(modelId, DefaultEndpoint)
}

func hfModelMetaDataUrlWithEndpoint(modelId string, endpoint string) (string, error) {
if modelId == "" {
return "", fmt.Errorf("no model name has been specified")
}
Expand All @@ -262,6 +271,6 @@ func hfModelMetaDataUrl(modelId string) (string, error) {
return "", fmt.Errorf("invalid model name %q: expected format <namespace>/<model>", modelId)
}

baseUrl := fmt.Sprintf("%s/%s", DefaultEndpoint, HfAPI)
baseUrl := fmt.Sprintf("%s/%s", endpoint, HfAPI)
return fmt.Sprintf("%s/models/%s", baseUrl, modelId), nil
}
30 changes: 17 additions & 13 deletions pkg/modelagent/hf_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,23 @@ func TestFetchAttributeFromHfModelMetaData(t *testing.T) {
statusCode: 404,
wantErr: true,
expectedValue: "",
errMessageStr: "failed to invoke HuggingFace endpoint https://huggingface.co/api/models/deepseek-ai/DeepSeek-V3-unknown: response status code",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if !tt.wantErr || tt.statusCode > 0 {
server := createMockHfMetaDataServer(tt.statusCode, tt.attribute, tt.modelId)
defer server.Close()
}
server := createMockHfMetaDataServer(tt.statusCode, tt.attribute, tt.modelId)
defer server.Close()

ctx := context.Background()

value, err := FetchAttributeFromHfModelMetaData(ctx, tt.modelId, tt.attribute)
value, err := fetchAttributeFromHfModelMetaDataWithEndpoint(ctx, tt.modelId, tt.attribute, server.URL)

if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMessageStr)
if tt.errMessageStr != "" {
assert.Contains(t, err.Error(), tt.errMessageStr)
}
} else {
require.NoError(t, err)
assert.Equal(t, tt.expectedValue, value)
Expand All @@ -79,20 +78,25 @@ func createMockHfMetaDataServer(statusCode int, attribute string, modelId string
if statusCode == 200 {
if modelId == "deepseek-ai/DeepSeek-V3" {
if attribute == "sha" {
var data map[string]interface{}
data[attribute] = "e815299b0bcbac849fa540c768ef21845365c9eb"
data := map[string]interface{}{
attribute: "e815299b0bcbac849fa540c768ef21845365c9eb",
}
bytes, _ := json.Marshal(data)
writer.Write(bytes)
} else {
writer.Write(make([]byte, 100))
// Return valid JSON without the requested attribute
data := map[string]interface{}{
"sha": "e815299b0bcbac849fa540c768ef21845365c9eb",
}
bytes, _ := json.Marshal(data)
writer.Write(bytes)
}

}
} else if statusCode == 404 {
} else {
writer.WriteHeader(statusCode)
writer.Write([]byte("Repository not found"))
}
}

}))

}
Expand Down
Loading