Skip to content

Commit f0552c7

Browse files
committed
Added test for testing the new OpenAI api
Signed-off-by: Chaitany patel <[email protected]>
1 parent e96a116 commit f0552c7

9 files changed

Lines changed: 1203 additions & 334 deletions

File tree

sdk/python/feast/feature_store.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@
6868
)
6969
from feast.feast_object import FeastObject
7070
from feast.feature_service import FeatureService
71-
from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_NAME, FeatureView
71+
from feast.feature_view import (
72+
DUMMY_ENTITY,
73+
DUMMY_ENTITY_ID,
74+
DUMMY_ENTITY_NAME,
75+
FeatureView,
76+
)
7277
from feast.filter_models import ComparisonFilter, CompoundFilter, convert_dict_to_filter
7378
from feast.inference import (
7479
update_data_sources_with_inferred_event_timestamp_col,
@@ -2950,6 +2955,12 @@ def retrieve_online_documents_openai(
29502955

29512956
response_dict = response.to_dict()
29522957

2958+
entity_key_names = {
2959+
col.name
2960+
for col in feature_view.entity_columns
2961+
if col.name != DUMMY_ENTITY_ID
2962+
}
2963+
29532964
result_data = []
29542965
if response_dict:
29552966
first_key = next(iter(response_dict))
@@ -2963,14 +2974,24 @@ def retrieve_online_documents_openai(
29632974
val = values[i] if i < len(values) else None
29642975
if key == "distance":
29652976
score = float(val) if val is not None else 0.0
2966-
else:
2977+
elif key not in entity_key_names:
29672978
attributes[key] = val
29682979
if isinstance(val, str):
29692980
content_parts.append({"type": "text", "text": val})
29702981

2982+
if entity_key_names:
2983+
key_parts = [
2984+
str(response_dict[k][i])
2985+
for k in sorted(entity_key_names)
2986+
if k in response_dict and i < len(response_dict[k])
2987+
]
2988+
file_id = f"{vector_store_id}_{'_'.join(key_parts)}"
2989+
else:
2990+
file_id = f"{vector_store_id}_{i}"
2991+
29712992
result_data.append(
29722993
{
2973-
"file_id": f"{vector_store_id}_{i}",
2994+
"file_id": file_id,
29742995
"filename": vector_store_id,
29752996
"score": score,
29762997
"attributes": attributes,

sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py

Lines changed: 139 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,56 @@ class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
4646

4747
# The number of rows to write in a single batch
4848
write_batch_size: Optional[int] = 40
49+
enable_openai_compatible_store: Optional[bool] = False
50+
51+
52+
logger = logging.getLogger(__name__)
53+
54+
_NUMERIC_COMPARISON_OPS = {"gt", "gte", "lt", "lte"}
55+
56+
57+
def _filters_contain_numeric_comparison(
    filter_obj: Union[ComparisonFilter, CompoundFilter],
) -> bool:
    """Report whether a filter tree contains a range comparison on a number.

    Args:
        filter_obj: A single comparison filter or a compound filter whose
            ``filters`` are searched recursively.

    Returns:
        True if any comparison filter uses one of the operators in
        ``_NUMERIC_COMPARISON_OPS`` with an ``int`` or ``float`` value;
        False otherwise, including for objects of any other type.
    """
    if isinstance(filter_obj, ComparisonFilter):
        value = filter_obj.value
        # bool is a subclass of int; it is excluded here so this check agrees
        # with how comparison filters are routed to the numeric field.
        is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
        return filter_obj.type in _NUMERIC_COMPARISON_OPS and is_number
    if isinstance(filter_obj, CompoundFilter):
        return any(
            _filters_contain_numeric_comparison(sub) for sub in filter_obj.filters
        )
    return False
4967

5068

5169
class ElasticSearchOnlineStore(OnlineStore):
5270
_client: Optional[Elasticsearch] = None
71+
_index_value_num_cache: Optional[Dict[str, bool]] = None
72+
73+
def _index_has_value_num(self, config: RepoConfig, index_name: str) -> bool:
    """Check the actual ES index mapping for the ``value_num`` field.

    The index's ``dynamic_templates`` are inspected for a template whose
    mapping properties include ``value_num``.

    A successful lookup is cached per index name so Elasticsearch is only
    queried once per index. A failed lookup is logged and NOT cached, so a
    transient error or a not-yet-created index cannot permanently disable
    numeric handling for this store instance.

    Args:
        config: Repo configuration used to obtain the Elasticsearch client.
        index_name: Name of the index whose mapping is inspected.

    Returns:
        True if a dynamic template declares ``value_num``; False if none
        does or if the mapping could not be read.
    """
    if self._index_value_num_cache is None:
        self._index_value_num_cache = {}

    cached = self._index_value_num_cache.get(index_name)
    if cached is not None:
        return cached

    try:
        mapping = self._get_client(config).indices.get_mapping(index=index_name)
        templates = (
            mapping.get(index_name, {})
            .get("mappings", {})
            .get("dynamic_templates", [])
        )
        has_value_num = any(
            "value_num" in tmpl_body.get("mapping", {}).get("properties", {})
            for tmpl in templates
            for tmpl_body in tmpl.values()
        )
    except Exception:
        # Best effort: report "absent" for this call, but leave the cache
        # untouched so the next call retries instead of reusing a stale False.
        logger.warning(
            "Could not inspect mapping of index '%s' for 'value_num'; "
            "treating it as absent for this call.",
            index_name,
            exc_info=True,
        )
        return False

    self._index_value_num_cache[index_name] = has_value_num
    return has_value_num
5399

54100
def _get_client(self, config: RepoConfig) -> Elasticsearch:
55101
online_store_config = config.online_store
@@ -95,6 +141,7 @@ def online_write_batch(
95141
progress: Optional[Callable[[int], Any]],
96142
) -> None:
97143
insert_values = []
144+
include_value_num = self._index_has_value_num(config, table.name)
98145
grouped_docs: dict[str, dict[str, Any]] = defaultdict(
99146
lambda: {
100147
"features": {},
@@ -116,7 +163,7 @@ def online_write_batch(
116163
doc_key = f"{encoded_entity_key}_{timestamp}"
117164

118165
for feature_name, value in values.items():
119-
doc = _encode_feature_value(value)
166+
doc = _encode_feature_value(value, include_value_num=include_value_num)
120167
grouped_docs[doc_key]["features"][feature_name] = doc
121168
grouped_docs[doc_key]["timestamp"] = timestamp
122169
grouped_docs[doc_key]["created_ts"] = created_ts
@@ -210,6 +257,20 @@ def create_index(self, config: RepoConfig, table: FeatureView):
210257
_get_feature_view_vector_field_metadata(table), "vector_length", 512
211258
)
212259

260+
feature_properties: Dict[str, Any] = {
261+
"feature_value": {"type": "binary"},
262+
"value_text": {"type": "text"},
263+
"vector_value": {
264+
"type": "dense_vector",
265+
"dims": vector_field_length,
266+
"index": True,
267+
"similarity": config.online_store.similarity,
268+
},
269+
}
270+
271+
if getattr(config.online_store, "enable_openai_compatible_store", False):
272+
feature_properties["value_num"] = {"type": "double"}
273+
213274
index_mapping = {
214275
"dynamic_templates": [
215276
{
@@ -218,16 +279,7 @@ def create_index(self, config: RepoConfig, table: FeatureView):
218279
"match": "*",
219280
"mapping": {
220281
"type": "object",
221-
"properties": {
222-
"feature_value": {"type": "binary"},
223-
"value_text": {"type": "text"},
224-
"vector_value": {
225-
"type": "dense_vector",
226-
"dims": vector_field_length,
227-
"index": True,
228-
"similarity": config.online_store.similarity,
229-
},
230-
},
282+
"properties": feature_properties,
231283
},
232284
}
233285
}
@@ -344,6 +396,7 @@ def retrieve_online_documents(
344396
def _translate_filters(
345397
self,
346398
filters: Optional[Union[ComparisonFilter, CompoundFilter]],
399+
has_value_num: bool = False,
347400
) -> List[Dict[str, Any]]:
348401
"""Translate filter objects into Elasticsearch Query DSL filter clauses.
349402
@@ -353,62 +406,75 @@ def _translate_filters(
353406
"""
354407
if filters is None:
355408
return []
356-
return [self._translate_single_filter(filters)]
409+
return [self._translate_single_filter(filters, has_value_num=has_value_num)]
357410

358411
def _translate_single_filter(
    self,
    filter_obj: Union[ComparisonFilter, CompoundFilter],
    has_value_num: bool = False,
) -> Dict[str, Any]:
    """Translate one filter node by dispatching on its concrete type.

    Args:
        filter_obj: The comparison or compound filter to translate.
        has_value_num: Forwarded unchanged to the type-specific translator.

    Returns:
        The clause produced by the comparison or compound translator.

    Raises:
        ValueError: If ``filter_obj`` is neither a ``ComparisonFilter`` nor
            a ``CompoundFilter``.
    """
    if isinstance(filter_obj, ComparisonFilter):
        comparison_clause = self._translate_comparison_filter(
            filter_obj, has_value_num=has_value_num
        )
        return comparison_clause

    if isinstance(filter_obj, CompoundFilter):
        compound_clause = self._translate_compound_filter(
            filter_obj, has_value_num=has_value_num
        )
        return compound_clause

    raise ValueError(f"Unknown filter type: {type(filter_obj)}")
367425

368426
def _translate_comparison_filter(
369427
self,
370428
f: ComparisonFilter,
429+
has_value_num: bool = False,
371430
) -> Dict[str, Any]:
372-
"""Translate a ComparisonFilter to an ES Query DSL clause.
431+
"""Translate a ComparisonFilter to an ES Query DSL clause."""
432+
is_numeric = isinstance(f.value, (int, float)) and not isinstance(f.value, bool)
433+
is_numeric_list = (
434+
isinstance(f.value, list)
435+
and f.value
436+
and isinstance(f.value[0], (int, float))
437+
and not isinstance(f.value[0], bool)
438+
)
373439

374-
Feature values in Elasticsearch are stored under
375-
``<feature_name>.value_text``, so filters target that nested path.
376-
"""
377-
field = f"{f.key}.value_text"
440+
if has_value_num and (is_numeric or is_numeric_list):
441+
field = f"{f.key}.value_num"
442+
fmt_val = f.value
443+
fmt_list = f.value if is_numeric_list else None
444+
else:
445+
field = f"{f.key}.value_text"
446+
fmt_val = str(f.value)
447+
fmt_list = [str(v) for v in f.value] if isinstance(f.value, list) else None
378448

379449
if f.type == "eq":
380-
return {"term": {field: str(f.value)}}
450+
return {"term": {field: fmt_val}}
381451
elif f.type == "ne":
382-
return {"bool": {"must_not": [{"term": {field: str(f.value)}}]}}
383-
elif f.type == "gt":
384-
return {"range": {field: {"gt": f.value}}}
385-
elif f.type == "gte":
386-
return {"range": {field: {"gte": f.value}}}
387-
elif f.type == "lt":
388-
return {"range": {field: {"lt": f.value}}}
389-
elif f.type == "lte":
390-
return {"range": {field: {"lte": f.value}}}
452+
return {"bool": {"must_not": [{"term": {field: fmt_val}}]}}
453+
elif f.type in ("gt", "gte", "lt", "lte"):
454+
return {"range": {field: {f.type: fmt_val}}}
391455
elif f.type == "in":
392456
if not isinstance(f.value, list):
393457
raise ValueError(
394458
f"'in' filter requires a list value, got {type(f.value)}"
395459
)
396-
return {"terms": {field: [str(v) for v in f.value]}}
460+
return {"terms": {field: fmt_list}}
397461
elif f.type == "nin":
398462
if not isinstance(f.value, list):
399463
raise ValueError(
400464
f"'nin' filter requires a list value, got {type(f.value)}"
401465
)
402-
return {
403-
"bool": {"must_not": [{"terms": {field: [str(v) for v in f.value]}}]}
404-
}
466+
return {"bool": {"must_not": [{"terms": {field: fmt_list}}]}}
405467
raise ValueError(f"Unsupported comparison operator: {f.type}")
406468

407469
def _translate_compound_filter(
408470
self,
409471
f: CompoundFilter,
472+
has_value_num: bool = False,
410473
) -> Dict[str, Any]:
411-
clauses = [self._translate_single_filter(sub) for sub in f.filters]
474+
clauses = [
475+
self._translate_single_filter(sub, has_value_num=has_value_num)
476+
for sub in f.filters
477+
]
412478
if f.type == "and":
413479
return {"bool": {"must": clauses}}
414480
else:
@@ -458,7 +524,24 @@ def retrieve_online_documents_v2(
458524
source_fields += composite_key_name
459525
body["_source"] = source_fields
460526

461-
metadata_filters = self._translate_filters(filters)
527+
has_value_num = self._index_has_value_num(config, es_index)
528+
529+
if (
530+
filters
531+
and _filters_contain_numeric_comparison(filters)
532+
and not has_value_num
533+
):
534+
logger.warning(
535+
"Numeric comparison filters (gt, gte, lt, lte) are being used "
536+
"but this index does not have a 'value_num' field. Numeric "
537+
"fields are stored as text, which causes lexicographic "
538+
"comparison instead of numeric comparison (e.g. '9' > '100'). "
539+
"To fix this, set 'enable_openai_compatible_store: true' in "
540+
"your online_store config, then teardown and re-apply your "
541+
"feature store to recreate indices with the value_num field."
542+
)
543+
544+
metadata_filters = self._translate_filters(filters, has_value_num=has_value_num)
462545

463546
if embedding:
464547
similarity = (distance_metric or config.online_store.similarity).lower()
@@ -575,12 +658,15 @@ def _to_value_proto(value: Any) -> ValueProto:
575658
return val_proto
576659

577660

578-
def _encode_feature_value(value: ValueProto) -> Dict[str, Any]:
661+
def _encode_feature_value(
662+
value: ValueProto,
663+
include_value_num: bool = False,
664+
) -> Dict[str, Any]:
579665
"""
580666
Encode a ValueProto into a dictionary for Elasticsearch storage.
581667
"""
582668
encoded_value = base64.b64encode(value.SerializeToString()).decode("utf-8")
583-
result = {"feature_value": encoded_value}
669+
result: Dict[str, Any] = {"feature_value": encoded_value}
584670
vector_val = get_list_val_str(value)
585671

586672
if vector_val:
@@ -591,8 +677,24 @@ def _encode_feature_value(value: ValueProto) -> Dict[str, Any]:
591677
result["value_text"] = value.bytes_val.decode("utf-8")
592678
elif value.HasField("int64_val"):
593679
result["value_text"] = str(value.int64_val)
680+
if include_value_num:
681+
result["value_num"] = value.int64_val
682+
elif value.HasField("int32_val"):
683+
result["value_text"] = str(value.int32_val)
684+
if include_value_num:
685+
result["value_num"] = value.int32_val
594686
elif value.HasField("double_val"):
595687
result["value_text"] = str(value.double_val)
688+
if include_value_num:
689+
result["value_num"] = value.double_val
690+
elif value.HasField("float_val"):
691+
result["value_text"] = str(value.float_val)
692+
if include_value_num:
693+
result["value_num"] = value.float_val
694+
elif value.HasField("bool_val"):
695+
result["value_text"] = str(value.bool_val)
696+
if include_value_num:
697+
result["value_num"] = 1.0 if value.bool_val else 0.0
596698
return result
597699

598700

0 commit comments

Comments
 (0)