Skip to content

Commit f005d2f

Browse files
patelchaitanyntkathole
authored andcommitted
Added test for testing the new OpenAI api
Signed-off-by: Chaitany patel <[email protected]>
1 parent 4fbeb6e commit f005d2f

15 files changed

Lines changed: 1184 additions & 413 deletions

File tree

sdk/python/feast/feature_store.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@
6868
)
6969
from feast.feast_object import FeastObject
7070
from feast.feature_service import FeatureService
71-
from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_NAME, FeatureView
71+
from feast.feature_view import (
72+
DUMMY_ENTITY,
73+
DUMMY_ENTITY_ID,
74+
DUMMY_ENTITY_NAME,
75+
FeatureView,
76+
)
7277
from feast.filter_models import ComparisonFilter, CompoundFilter, convert_dict_to_filter
7378
from feast.inference import (
7479
update_data_sources_with_inferred_event_timestamp_col,
@@ -2956,8 +2961,8 @@ def retrieve_online_documents_v2(
29562961
top_k,
29572962
distance_metric,
29582963
query_string,
2959-
include_feature_view_version_metadata,
29602964
filters,
2965+
include_feature_view_version_metadata,
29612966
)
29622967

29632968
def retrieve_online_documents_openai(
@@ -3076,6 +3081,12 @@ def retrieve_online_documents_openai(
30763081

30773082
response_dict = response.to_dict()
30783083

3084+
entity_key_names = {
3085+
col.name
3086+
for col in feature_view.entity_columns
3087+
if col.name != DUMMY_ENTITY_ID
3088+
}
3089+
30793090
result_data = []
30803091
if response_dict:
30813092
first_key = next(iter(response_dict))
@@ -3089,14 +3100,24 @@ def retrieve_online_documents_openai(
30893100
val = values[i] if i < len(values) else None
30903101
if key == "distance":
30913102
score = float(val) if val is not None else 0.0
3092-
else:
3103+
elif key not in entity_key_names:
30933104
attributes[key] = val
30943105
if isinstance(val, str):
30953106
content_parts.append({"type": "text", "text": val})
30963107

3108+
if entity_key_names:
3109+
key_parts = [
3110+
str(response_dict[k][i])
3111+
for k in sorted(entity_key_names)
3112+
if k in response_dict and i < len(response_dict[k])
3113+
]
3114+
file_id = f"{vector_store_id}_{'_'.join(key_parts)}"
3115+
else:
3116+
file_id = f"{vector_store_id}_{i}"
3117+
30973118
result_data.append(
30983119
{
3099-
"file_id": f"{vector_store_id}_{i}",
3120+
"file_id": file_id,
31003121
"filename": vector_store_id,
31013122
"score": score,
31023123
"attributes": attributes,
@@ -3191,7 +3212,7 @@ def _retrieve_from_online_store_v2(
31913212
config=self.config,
31923213
table=table,
31933214
requested_features=requested_features,
3194-
query=query,
3215+
embedding=query,
31953216
top_k=top_k,
31963217
distance_metric=distance_metric,
31973218
query_string=query_string,

sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py

Lines changed: 148 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,61 @@ class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
4646

4747
# The number of rows to write in a single batch
4848
write_batch_size: Optional[int] = 40
49+
enable_openai_compatible_store: Optional[bool] = False
50+
51+
52+
logger = logging.getLogger(__name__)
53+
54+
_NUMERIC_COMPARISON_OPS = {"gt", "gte", "lt", "lte"}
55+
56+
57+
def _filters_contain_numeric_comparison(
58+
filter_obj: Union[ComparisonFilter, CompoundFilter],
59+
) -> bool:
60+
if isinstance(filter_obj, ComparisonFilter):
61+
return filter_obj.type in _NUMERIC_COMPARISON_OPS and isinstance(
62+
filter_obj.value, (int, float)
63+
)
64+
if isinstance(filter_obj, CompoundFilter):
65+
return any(_filters_contain_numeric_comparison(f) for f in filter_obj.filters)
66+
return False
4967

5068

5169
class ElasticSearchOnlineStore(OnlineStore):
5270
_client: Optional[Elasticsearch] = None
71+
_index_value_num_cache: Optional[Dict[str, bool]] = None
72+
73+
def _index_has_value_num(self, config: RepoConfig, index_name: str) -> bool:
74+
"""Check the actual ES index mapping for the value_num field.
75+
76+
Caches the result per index so we only hit ES once.
77+
"""
78+
if self._index_value_num_cache is None:
79+
self._index_value_num_cache = {}
80+
if index_name in self._index_value_num_cache:
81+
return self._index_value_num_cache[index_name]
82+
try:
83+
mapping = self._get_client(config).indices.get_mapping(index=index_name)
84+
templates = (
85+
mapping.get(index_name, {})
86+
.get("mappings", {})
87+
.get("dynamic_templates", [])
88+
)
89+
for tmpl in templates:
90+
for _, tmpl_body in tmpl.items():
91+
props = tmpl_body.get("mapping", {}).get("properties", {})
92+
if "value_num" in props:
93+
self._index_value_num_cache[index_name] = True
94+
return True
95+
except Exception as e:
96+
logging.warning(
97+
"Failed to check index mapping for value_num on '%s': %s: %s",
98+
index_name,
99+
type(e).__name__,
100+
e,
101+
)
102+
self._index_value_num_cache[index_name] = False
103+
return False
53104

54105
def _get_client(self, config: RepoConfig) -> Elasticsearch:
55106
online_store_config = config.online_store
@@ -95,6 +146,7 @@ def online_write_batch(
95146
progress: Optional[Callable[[int], Any]],
96147
) -> None:
97148
insert_values = []
149+
include_value_num = self._index_has_value_num(config, table.name)
98150
grouped_docs: dict[str, dict[str, Any]] = defaultdict(
99151
lambda: {
100152
"features": {},
@@ -116,7 +168,7 @@ def online_write_batch(
116168
doc_key = f"{encoded_entity_key}_{timestamp}"
117169

118170
for feature_name, value in values.items():
119-
doc = _encode_feature_value(value)
171+
doc = _encode_feature_value(value, include_value_num=include_value_num)
120172
grouped_docs[doc_key]["features"][feature_name] = doc
121173
grouped_docs[doc_key]["timestamp"] = timestamp
122174
grouped_docs[doc_key]["created_ts"] = created_ts
@@ -210,6 +262,20 @@ def create_index(self, config: RepoConfig, table: FeatureView):
210262
_get_feature_view_vector_field_metadata(table), "vector_length", 512
211263
)
212264

265+
feature_properties: Dict[str, Any] = {
266+
"feature_value": {"type": "binary"},
267+
"value_text": {"type": "text"},
268+
"vector_value": {
269+
"type": "dense_vector",
270+
"dims": vector_field_length,
271+
"index": True,
272+
"similarity": config.online_store.similarity,
273+
},
274+
}
275+
276+
if getattr(config.online_store, "enable_openai_compatible_store", False):
277+
feature_properties["value_num"] = {"type": "double"}
278+
213279
index_mapping = {
214280
"dynamic_templates": [
215281
{
@@ -218,16 +284,7 @@ def create_index(self, config: RepoConfig, table: FeatureView):
218284
"match": "*",
219285
"mapping": {
220286
"type": "object",
221-
"properties": {
222-
"feature_value": {"type": "binary"},
223-
"value_text": {"type": "text"},
224-
"vector_value": {
225-
"type": "dense_vector",
226-
"dims": vector_field_length,
227-
"index": True,
228-
"similarity": config.online_store.similarity,
229-
},
230-
},
287+
"properties": feature_properties,
231288
},
232289
}
233290
}
@@ -343,6 +400,7 @@ def retrieve_online_documents(
343400
def _translate_filters(
344401
self,
345402
filters: Optional[Union[ComparisonFilter, CompoundFilter]],
403+
has_value_num: bool = False,
346404
) -> List[Dict[str, Any]]:
347405
"""Translate filter objects into Elasticsearch Query DSL filter clauses.
348406
@@ -352,62 +410,75 @@ def _translate_filters(
352410
"""
353411
if filters is None:
354412
return []
355-
return [self._translate_single_filter(filters)]
413+
return [self._translate_single_filter(filters, has_value_num=has_value_num)]
356414

357415
def _translate_single_filter(
358416
self,
359417
filter_obj: Union[ComparisonFilter, CompoundFilter],
418+
has_value_num: bool = False,
360419
) -> Dict[str, Any]:
361420
if isinstance(filter_obj, ComparisonFilter):
362-
return self._translate_comparison_filter(filter_obj)
421+
return self._translate_comparison_filter(
422+
filter_obj, has_value_num=has_value_num
423+
)
363424
elif isinstance(filter_obj, CompoundFilter):
364-
return self._translate_compound_filter(filter_obj)
425+
return self._translate_compound_filter(
426+
filter_obj, has_value_num=has_value_num
427+
)
365428
raise ValueError(f"Unknown filter type: {type(filter_obj)}")
366429

367430
def _translate_comparison_filter(
368431
self,
369432
f: ComparisonFilter,
433+
has_value_num: bool = False,
370434
) -> Dict[str, Any]:
371-
"""Translate a ComparisonFilter to an ES Query DSL clause.
435+
"""Translate a ComparisonFilter to an ES Query DSL clause."""
436+
is_numeric = isinstance(f.value, (int, float)) and not isinstance(f.value, bool)
437+
is_numeric_list = (
438+
isinstance(f.value, list)
439+
and f.value
440+
and isinstance(f.value[0], (int, float))
441+
and not isinstance(f.value[0], bool)
442+
)
372443

373-
Feature values in Elasticsearch are stored under
374-
``<feature_name>.value_text``, so filters target that nested path.
375-
"""
376-
field = f"{f.key}.value_text"
444+
if has_value_num and (is_numeric or is_numeric_list):
445+
field = f"{f.key}.value_num"
446+
fmt_val = f.value
447+
fmt_list = f.value if is_numeric_list else None
448+
else:
449+
field = f"{f.key}.value_text"
450+
fmt_val = str(f.value)
451+
fmt_list = [str(v) for v in f.value] if isinstance(f.value, list) else None
377452

378453
if f.type == "eq":
379-
return {"term": {field: str(f.value)}}
454+
return {"term": {field: fmt_val}}
380455
elif f.type == "ne":
381-
return {"bool": {"must_not": [{"term": {field: str(f.value)}}]}}
382-
elif f.type == "gt":
383-
return {"range": {field: {"gt": f.value}}}
384-
elif f.type == "gte":
385-
return {"range": {field: {"gte": f.value}}}
386-
elif f.type == "lt":
387-
return {"range": {field: {"lt": f.value}}}
388-
elif f.type == "lte":
389-
return {"range": {field: {"lte": f.value}}}
456+
return {"bool": {"must_not": [{"term": {field: fmt_val}}]}}
457+
elif f.type in ("gt", "gte", "lt", "lte"):
458+
return {"range": {field: {f.type: fmt_val}}}
390459
elif f.type == "in":
391460
if not isinstance(f.value, list):
392461
raise ValueError(
393462
f"'in' filter requires a list value, got {type(f.value)}"
394463
)
395-
return {"terms": {field: [str(v) for v in f.value]}}
464+
return {"terms": {field: fmt_list}}
396465
elif f.type == "nin":
397466
if not isinstance(f.value, list):
398467
raise ValueError(
399468
f"'nin' filter requires a list value, got {type(f.value)}"
400469
)
401-
return {
402-
"bool": {"must_not": [{"terms": {field: [str(v) for v in f.value]}}]}
403-
}
470+
return {"bool": {"must_not": [{"terms": {field: fmt_list}}]}}
404471
raise ValueError(f"Unsupported comparison operator: {f.type}")
405472

406473
def _translate_compound_filter(
407474
self,
408475
f: CompoundFilter,
476+
has_value_num: bool = False,
409477
) -> Dict[str, Any]:
410-
clauses = [self._translate_single_filter(sub) for sub in f.filters]
478+
clauses = [
479+
self._translate_single_filter(sub, has_value_num=has_value_num)
480+
for sub in f.filters
481+
]
411482
if f.type == "and":
412483
return {"bool": {"must": clauses}}
413484
else:
@@ -422,8 +493,8 @@ def retrieve_online_documents_v2(
422493
top_k: int,
423494
distance_metric: Optional[str] = None,
424495
query_string: Optional[str] = None,
425-
include_feature_view_version_metadata: bool = False,
426496
filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None,
497+
include_feature_view_version_metadata: bool = False,
427498
) -> List[
428499
Tuple[
429500
Optional[datetime],
@@ -458,7 +529,24 @@ def retrieve_online_documents_v2(
458529
source_fields += composite_key_name
459530
body["_source"] = source_fields
460531

461-
metadata_filters = self._translate_filters(filters)
532+
has_value_num = self._index_has_value_num(config, es_index)
533+
534+
if (
535+
filters
536+
and _filters_contain_numeric_comparison(filters)
537+
and not has_value_num
538+
):
539+
logger.warning(
540+
"Numeric comparison filters (gt, gte, lt, lte) are being used "
541+
"but this index does not have a 'value_num' field. Numeric "
542+
"fields are stored as text, which causes lexicographic "
543+
"comparison instead of numeric comparison (e.g. '9' > '100'). "
544+
"To fix this, set 'enable_openai_compatible_store: true' in "
545+
"your online_store config, then teardown and re-apply your "
546+
"feature store to recreate indices with the value_num field."
547+
)
548+
549+
metadata_filters = self._translate_filters(filters, has_value_num=has_value_num)
462550

463551
if embedding:
464552
similarity = (distance_metric or config.online_store.similarity).lower()
@@ -554,14 +642,14 @@ def _to_value_proto(value: Any) -> ValueProto:
554642
val_proto = ValueProto()
555643
if isinstance(value, ValueProto):
556644
return value
557-
if isinstance(value, float):
645+
if isinstance(value, bool):
646+
val_proto.bool_val = value
647+
elif isinstance(value, float):
558648
val_proto.float_val = value
559649
elif isinstance(value, str):
560650
val_proto.string_val = value
561651
elif isinstance(value, int):
562652
val_proto.int64_val = value
563-
elif isinstance(value, bool):
564-
val_proto.bool_val = value
565653
elif isinstance(value, list) and all(isinstance(v, float) for v in value):
566654
val_proto.float_list_val.val.extend(value)
567655
elif isinstance(value, dict) and "feature_value" in value:
@@ -575,12 +663,15 @@ def _to_value_proto(value: Any) -> ValueProto:
575663
return val_proto
576664

577665

578-
def _encode_feature_value(value: ValueProto) -> Dict[str, Any]:
666+
def _encode_feature_value(
667+
value: ValueProto,
668+
include_value_num: bool = False,
669+
) -> Dict[str, Any]:
579670
"""
580671
Encode a ValueProto into a dictionary for Elasticsearch storage.
581672
"""
582673
encoded_value = base64.b64encode(value.SerializeToString()).decode("utf-8")
583-
result = {"feature_value": encoded_value}
674+
result: Dict[str, Any] = {"feature_value": encoded_value}
584675
vector_val = get_list_val_str(value)
585676

586677
if vector_val:
@@ -591,8 +682,24 @@ def _encode_feature_value(value: ValueProto) -> Dict[str, Any]:
591682
result["value_text"] = value.bytes_val.decode("utf-8")
592683
elif value.HasField("int64_val"):
593684
result["value_text"] = str(value.int64_val)
685+
if include_value_num:
686+
result["value_num"] = value.int64_val
687+
elif value.HasField("int32_val"):
688+
result["value_text"] = str(value.int32_val)
689+
if include_value_num:
690+
result["value_num"] = value.int32_val
594691
elif value.HasField("double_val"):
595692
result["value_text"] = str(value.double_val)
693+
if include_value_num:
694+
result["value_num"] = value.double_val
695+
elif value.HasField("float_val"):
696+
result["value_text"] = str(value.float_val)
697+
if include_value_num:
698+
result["value_num"] = value.float_val
699+
elif value.HasField("bool_val"):
700+
result["value_text"] = str(value.bool_val)
701+
if include_value_num:
702+
result["value_num"] = 1.0 if value.bool_val else 0.0
596703
return result
597704

598705

0 commit comments

Comments
 (0)