Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
deserialize_entity_key,
serialize_entity_key,
)
from feast.infra.online_stores.helpers import compute_table_id
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.online_stores.vector_store import VectorStoreConfig
from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
Expand Down Expand Up @@ -164,7 +165,9 @@ def _get_or_create_collection(
) -> Dict[str, Any]:
self.client = self._connect(config)
vector_field_dict = {k.name: k for k in table.schema if k.vector_index}
collection_name = _table_id(config.project, table)
collection_name = _table_id(
config.project, table, config.registry.enable_online_feature_view_versioning
)
if collection_name not in self._collections:
# Create a composite key by combining entity fields
composite_key_name = _get_composite_key_name(table)
Expand Down Expand Up @@ -346,7 +349,9 @@ def online_read(
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
self.client = self._connect(config)
collection_name = _table_id(config.project, table)
collection_name = _table_id(
config.project, table, config.registry.enable_online_feature_view_versioning
)
collection = self._get_or_create_collection(config, table)

composite_key_name = _get_composite_key_name(table)
Expand Down Expand Up @@ -493,11 +498,12 @@ def update(
for table in tables_to_keep:
self._get_or_create_collection(config, table)

# Always drop the base collection plus any "_v{N}" siblings, regardless of
# the current versioning flag. This handles mixed-state repos where
# versioning was toggled on/off across applies and would otherwise leave
# orphan collections behind in Milvus.
for table in tables_to_delete:
collection_name = _table_id(config.project, table)
if self._collections.get(collection_name, None):
self.client.drop_collection(collection_name)
self._collections.pop(collection_name, None)
self._drop_all_version_collections(config.project, table)

def plan(
self, config: RepoConfig, desired_registry_proto: RegistryProto
Expand All @@ -511,11 +517,9 @@ def teardown(
entities: Sequence[Entity],
):
self.client = self._connect(config)
# See update(): drop base + all "_v{N}" siblings to handle mixed-state repos.
for table in tables:
collection_name = _table_id(config.project, table)
if self._collections.get(collection_name, None):
self.client.drop_collection(collection_name)
self._collections.pop(collection_name, None)
self._drop_all_version_collections(config.project, table)

def retrieve_online_documents_v2(
self,
Expand Down Expand Up @@ -551,7 +555,9 @@ def retrieve_online_documents_v2(
k.name: k.dtype for k in table.entity_columns
}
self.client = self._connect(config)
collection_name = _table_id(config.project, table)
collection_name = _table_id(
config.project, table, config.registry.enable_online_feature_view_versioning
)
collection = self._get_or_create_collection(config, table)
if not config.online_store.vector_enabled:
raise ValueError("Vector search is not enabled in the online store config")
Expand Down Expand Up @@ -748,9 +754,28 @@ def retrieve_online_documents_v2(
result_list.append((res_ts, entity_key_proto, res if res else None))
return result_list

def _drop_all_version_collections(self, project: str, table: FeatureView) -> None:
"""Drop the base collection and every ``_v{N}`` versioned sibling.

Mirrors the ``_drop_all_version_tables`` helpers in the MySQL/PostgreSQL
online stores. Always called from ``update`` and ``teardown`` so a
repo that toggles versioning on and off does not leave orphan
collections behind in Milvus.
"""
base = f"{project}_{table.name}"
versioned_prefix = f"{base}_v"
assert self.client is not None, "Milvus client is not initialized"
for collection_name in self.client.list_collections():
if collection_name == base or (
collection_name.startswith(versioned_prefix)
and collection_name[len(versioned_prefix) :].isdigit()
):
self.client.drop_collection(collection_name)
self._collections.pop(collection_name, None)


def _table_id(project: str, table: FeatureView) -> str:
return f"{project}_{table.name}"
def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str:
return compute_table_id(project, table, enable_versioning)


def _get_composite_key_name(table: FeatureView) -> str:
Expand Down
8 changes: 8 additions & 0 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,14 @@ def _check_versioned_read_support(self, grouped_refs):
supported_types.append(DynamoDBOnlineStore)
except Exception:
pass
try:
from feast.infra.online_stores.milvus_online_store.milvus import (
MilvusOnlineStore,
)

supported_types.append(MilvusOnlineStore)
except ImportError:
pass

if isinstance(self, tuple(supported_types)):
return
Expand Down
180 changes: 180 additions & 0 deletions sdk/python/tests/unit/infra/online_store/test_milvus_versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""Unit tests for Milvus online store feature view versioning."""

from datetime import timedelta
from unittest.mock import MagicMock

from feast import Entity, FeatureView
from feast.field import Field
from feast.types import Float32
from feast.value_type import ValueType


def _make_feature_view(name="driver_stats", version_number=None, version_tag=None):
entity = Entity(
name="driver_id",
join_keys=["driver_id"],
value_type=ValueType.INT64,
)
fv = FeatureView(
name=name,
entities=[entity],
ttl=timedelta(days=1),
schema=[Field(name="trips_today", dtype=Float32)],
)
if version_number is not None:
fv.current_version_number = version_number
if version_tag is not None:
fv.projection.version_tag = version_tag
return fv


def _make_config(project="test_project", versioning=False):
config = MagicMock()
config.project = project
config.entity_key_serialization_version = 2
config.registry.enable_online_feature_view_versioning = versioning
return config


class TestTableId:
"""Test _table_id with versioning enabled/disabled."""

def test_no_versioning(self):
from feast.infra.online_stores.milvus_online_store.milvus import _table_id

fv = _make_feature_view()
config = _make_config(versioning=False)
assert _table_id(config.project, fv) == "test_project_driver_stats"

def test_versioning_enabled_with_version(self):
from feast.infra.online_stores.milvus_online_store.milvus import _table_id

fv = _make_feature_view(version_number=2)
config = _make_config(versioning=True)
assert (
_table_id(config.project, fv, enable_versioning=True)
== "test_project_driver_stats_v2"
)

def test_projection_version_tag_takes_priority(self):
from feast.infra.online_stores.milvus_online_store.milvus import _table_id

fv = _make_feature_view(version_number=1, version_tag=3)
config = _make_config(versioning=True)
assert (
_table_id(config.project, fv, enable_versioning=True)
== "test_project_driver_stats_v3"
)

def test_version_zero_no_suffix(self):
from feast.infra.online_stores.milvus_online_store.milvus import _table_id

fv = _make_feature_view(version_number=0)
config = _make_config(versioning=True)
assert (
_table_id(config.project, fv, enable_versioning=True)
== "test_project_driver_stats"
)

def test_versioning_enabled_no_version_set(self):
from feast.infra.online_stores.milvus_online_store.milvus import _table_id

fv = _make_feature_view()
config = _make_config(versioning=True)
assert (
_table_id(config.project, fv, enable_versioning=True)
== "test_project_driver_stats"
)

def test_versioning_disabled_ignores_version(self):
from feast.infra.online_stores.milvus_online_store.milvus import _table_id

fv = _make_feature_view(version_number=5)
config = _make_config(versioning=False)
assert _table_id(config.project, fv) == "test_project_driver_stats"


class TestMilvusVersionedReadSupport:
"""Test that MilvusOnlineStore passes _check_versioned_read_support."""

def test_allowed_with_version_tag(self):
from feast.infra.online_stores.milvus_online_store.milvus import (
MilvusOnlineStore,
)

store = MilvusOnlineStore()
fv = _make_feature_view()
fv.projection.version_tag = 2
store._check_versioned_read_support([(fv, ["trips_today"])])

def test_allowed_without_version_tag(self):
from feast.infra.online_stores.milvus_online_store.milvus import (
MilvusOnlineStore,
)

store = MilvusOnlineStore()
fv = _make_feature_view()
store._check_versioned_read_support([(fv, ["trips_today"])])


class TestTeardownDropsAllVersions:
"""Teardown should drop the base collection AND all versioned collections."""

def _build_store_with_collections(self, existing_collections):
from feast.infra.online_stores.milvus_online_store.milvus import (
MilvusOnlineStore,
)

store = MilvusOnlineStore()
store.client = MagicMock()
store.client.list_collections.return_value = existing_collections
store._connect = MagicMock(return_value=store.client)
store._collections = {name: MagicMock() for name in existing_collections}
return store

def test_teardown_drops_base_and_all_versioned_collections(self):
fv = _make_feature_view()
config = _make_config(versioning=True)
existing = [
"test_project_driver_stats",
"test_project_driver_stats_v1",
"test_project_driver_stats_v2",
"test_project_other_view", # unrelated, must not be dropped
]
store = self._build_store_with_collections(existing)

store.teardown(config, [fv], [])

dropped = {call.args[0] for call in store.client.drop_collection.call_args_list}
assert dropped == {
"test_project_driver_stats",
"test_project_driver_stats_v1",
"test_project_driver_stats_v2",
}
assert "test_project_other_view" not in dropped

def test_update_drops_all_versions_for_deleted_table(self):
fv = _make_feature_view()
config = _make_config(versioning=True)
existing = [
"test_project_driver_stats",
"test_project_driver_stats_v3",
"test_project_driver_stats_v4",
]
store = self._build_store_with_collections(existing)

store.update(
config=config,
tables_to_delete=[fv],
tables_to_keep=[],
entities_to_delete=[],
entities_to_keep=[],
partial=False,
)

dropped = {call.args[0] for call in store.client.drop_collection.call_args_list}
assert dropped == {
"test_project_driver_stats",
"test_project_driver_stats_v3",
"test_project_driver_stats_v4",
}
Loading