Skip to content

Commit 65cd347

Browse files
Update retrieve_online_documents_v2 to also return non-vector fields for the SQLite online store
Signed-off-by: Francisco Javier Arceo <[email protected]>
1 parent ed63895 commit 65cd347

File tree

5 files changed

+249
-84
lines changed

5 files changed

+249
-84
lines changed

sdk/python/feast/feature_view.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ def __init__(
191191
else:
192192
features.append(field)
193193

194+
assert len([f for f in features if f.vector_index]) < 2, (
195+
f"Only one vector feature is allowed per feature view. Please update {self.name}."
196+
)
197+
194198
# TODO(felixwang9817): Add more robust validation of features.
195199
cols = [field.name for field in schema]
196200
for col in cols:

sdk/python/feast/infra/online_stores/sqlite.py

Lines changed: 129 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
import sys
1919
from datetime import date, datetime
2020
from pathlib import Path
21-
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
21+
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
2222

2323
from pydantic import StrictStr
2424

2525
from feast import Entity
2626
from feast.feature_view import FeatureView
27+
from feast.field import Field
2728
from feast.infra.infra_object import SQLITE_INFRA_OBJECT_CLASS_TYPE, InfraObject
2829
from feast.infra.key_encoding_utils import (
2930
deserialize_entity_key,
@@ -38,7 +39,13 @@
3839
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
3940
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
4041
from feast.repo_config import FeastConfigBaseModel, RepoConfig
41-
from feast.utils import _build_retrieve_online_document_record, to_naive_utc
42+
from feast.type_map import feast_value_type_to_python_type
43+
from feast.types import FEAST_VECTOR_TYPES
44+
from feast.utils import (
45+
_build_retrieve_online_document_record,
46+
_serialize_vector_to_float_list,
47+
to_naive_utc,
48+
)
4249

4350

4451
def adapt_date_iso(val: date):
@@ -94,6 +101,7 @@ class SqliteOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
94101

95102
vector_enabled: bool = False
96103
vector_len: Optional[int] = None
104+
text_search_enabled: bool = False
97105

98106

99107
class SqliteOnlineStore(OnlineStore):
@@ -144,9 +152,8 @@ def online_write_batch(
144152
progress: Optional[Callable[[int], Any]],
145153
) -> None:
146154
conn = self._get_conn(config)
147-
148155
project = config.project
149-
156+
feature_type_dict = {f.name: f.dtype for f in table.features}
150157
with conn:
151158
for entity_key, values, timestamp, created_ts in data:
152159
entity_key_bin = serialize_entity_key(
@@ -160,71 +167,51 @@ def online_write_batch(
160167
table_name = _table_id(project, table)
161168
for feature_name, val in values.items():
162169
if config.online_store.vector_enabled:
163-
vector_bin = serialize_f32(
164-
val.float_list_val.val, config.online_store.vector_len
165-
) # type: ignore
166-
conn.execute(
167-
f"""
168-
UPDATE {table_name}
169-
SET value = ?, vector_value = ?, event_ts = ?, created_ts = ?
170-
WHERE (entity_key = ? AND feature_name = ?)
171-
""",
172-
(
173-
# SET
174-
val.SerializeToString(),
175-
vector_bin,
176-
timestamp,
177-
created_ts,
178-
# WHERE
179-
entity_key_bin,
180-
feature_name,
181-
),
182-
)
170+
if feature_type_dict[feature_name] in FEAST_VECTOR_TYPES:
171+
val_bin = serialize_f32(
172+
val.float_list_val.val, config.online_store.vector_len
173+
) # type: ignore
183174

175+
else:
176+
val_bin = feast_value_type_to_python_type(val)
184177
conn.execute(
185-
f"""INSERT OR IGNORE INTO {table_name}
186-
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
187-
VALUES (?, ?, ?, ?, ?, ?)""",
178+
f"""
179+
INSERT INTO {table_name} (entity_key, feature_name, value, vector_value, event_ts, created_ts)
180+
VALUES (?, ?, ?, ?, ?, ?)
181+
ON CONFLICT(entity_key, feature_name) DO UPDATE SET
182+
value = excluded.value,
183+
vector_value = excluded.vector_value,
184+
event_ts = excluded.event_ts,
185+
created_ts = excluded.created_ts;
186+
""",
188187
(
189-
entity_key_bin,
190-
feature_name,
191-
val.SerializeToString(),
192-
vector_bin,
193-
timestamp,
194-
created_ts,
188+
entity_key_bin, # entity_key
189+
feature_name, # feature_name
190+
val.SerializeToString(), # value
191+
val_bin, # vector_value
192+
timestamp, # event_ts
193+
created_ts, # created_ts
195194
),
196195
)
197-
198196
else:
199197
conn.execute(
200198
f"""
201-
UPDATE {table_name}
202-
SET value = ?, event_ts = ?, created_ts = ?
203-
WHERE (entity_key = ? AND feature_name = ?)
199+
INSERT INTO {table_name} (entity_key, feature_name, value, event_ts, created_ts)
200+
VALUES (?, ?, ?, ?, ?)
201+
ON CONFLICT(entity_key, feature_name) DO UPDATE SET
202+
value = excluded.value,
203+
event_ts = excluded.event_ts,
204+
created_ts = excluded.created_ts;
204205
""",
205206
(
206-
# SET
207-
val.SerializeToString(),
208-
timestamp,
209-
created_ts,
210-
# WHERE
211-
entity_key_bin,
212-
feature_name,
207+
entity_key_bin, # entity_key
208+
feature_name, # feature_name
209+
val.SerializeToString(), # value
210+
timestamp, # event_ts
211+
created_ts, # created_ts
213212
),
214213
)
215214

216-
conn.execute(
217-
f"""INSERT OR IGNORE INTO {table_name}
218-
(entity_key, feature_name, value, event_ts, created_ts)
219-
VALUES (?, ?, ?, ?, ?)""",
220-
(
221-
entity_key_bin,
222-
feature_name,
223-
val.SerializeToString(),
224-
timestamp,
225-
created_ts,
226-
),
227-
)
228215
if progress:
229216
progress(1)
230217

@@ -482,13 +469,21 @@ def retrieve_online_documents_v2(
482469
conn = self._get_conn(config)
483470
cur = conn.cursor()
484471

485-
online_store = config.online_store
486-
if not isinstance(online_store, SqliteOnlineStoreConfig):
487-
raise ValueError("online_store must be SqliteOnlineStoreConfig")
488472
if not online_store.vector_len:
489473
raise ValueError("vector_len is not configured in the online store config")
474+
490475
query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore
491476
table_name = _table_id(config.project, table)
477+
vector_fields: List[Field] = [
478+
f for f in table.features if getattr(f, "vector_index", None)
479+
]
480+
assert len(vector_fields) > 0, (
481+
f"No vector field found, please update feature view = {table.name} to declare a vector field"
482+
)
483+
assert len(vector_fields) < 2, (
484+
f"Only one vector field is supported, please update feature view = {table.name} to declare one vector field"
485+
)
486+
vector_field: str = vector_fields[0].name
492487

493488
cur.execute(
494489
f"""
@@ -500,17 +495,19 @@ def retrieve_online_documents_v2(
500495

501496
cur.execute(
502497
f"""
503-
INSERT INTO vec_table(rowid, vector_value)
498+
INSERT INTO vec_table (rowid, vector_value)
504499
select rowid, vector_value from {table_name}
500+
where feature_name = "{vector_field}"
505501
"""
506502
)
507503

508504
cur.execute(
509505
f"""
510506
select
511-
fv.entity_key,
512-
fv.feature_name,
513-
fv.value,
507+
fv2.entity_key,
508+
fv2.feature_name,
509+
fv2.value,
510+
fv.vector_value,
514511
f.distance,
515512
fv.event_ts,
516513
fv.created_ts
@@ -526,38 +523,80 @@ def retrieve_online_documents_v2(
526523
) f
527524
left join {table_name} fv
528525
on f.rowid = fv.rowid
529-
where fv.feature_name in ({",".join(["?" for _ in requested_features])})
526+
left join {table_name} fv2
527+
on fv.entity_key = fv2.entity_key
528+
where fv2.feature_name != "{vector_field}"
530529
""",
531530
(
532531
query_embedding_bin,
533532
top_k,
534-
*[f.split(":")[-1] for f in requested_features],
535533
),
536534
)
537535

538536
rows = cur.fetchall()
539-
result: List[
537+
results: List[
540538
Tuple[
541539
Optional[datetime],
542540
Optional[EntityKeyProto],
543541
Optional[Dict[str, ValueProto]],
544542
]
545543
] = []
546544

547-
for entity_key, feature_name, value_bin, distance, event_ts, created_ts in rows:
548-
val = ValueProto()
549-
val.ParseFromString(value_bin)
550-
entity_key_proto = None
551-
if entity_key:
552-
entity_key_proto = deserialize_entity_key(
553-
entity_key,
554-
entity_key_serialization_version=config.entity_key_serialization_version,
545+
entity_dict: Dict[
546+
str, Dict[str, Union[str, ValueProto, EntityKeyProto, datetime]]
547+
] = {}
548+
for (
549+
entity_key,
550+
feature_name,
551+
value_bin,
552+
vector_value,
553+
distance,
554+
event_ts,
555+
created_ts,
556+
) in rows:
557+
entity_key_proto = deserialize_entity_key(
558+
entity_key,
559+
entity_key_serialization_version=config.entity_key_serialization_version,
560+
)
561+
if entity_key not in entity_dict:
562+
entity_dict[entity_key] = {}
563+
564+
feature_val = ValueProto()
565+
feature_val.ParseFromString(value_bin)
566+
entity_dict[entity_key]["entity_key_proto"] = entity_key_proto
567+
entity_dict[entity_key][feature_name] = feature_val
568+
entity_dict[entity_key][vector_field] = _serialize_vector_to_float_list(
569+
vector_value
570+
)
571+
entity_dict[entity_key]["distance"] = ValueProto(float_val=distance)
572+
entity_dict[entity_key]["event_ts"] = event_ts
573+
entity_dict[entity_key]["created_ts"] = created_ts
574+
575+
for entity_key_value in entity_dict:
576+
res_event_ts: Optional[datetime] = None
577+
res_entity_key_proto: Optional[EntityKeyProto] = None
578+
if isinstance(entity_dict[entity_key_value]["event_ts"], datetime):
579+
res_event_ts = entity_dict[entity_key_value]["event_ts"] # type: ignore[assignment]
580+
581+
if isinstance(
582+
entity_dict[entity_key_value]["entity_key_proto"], EntityKeyProto
583+
):
584+
res_entity_key_proto = entity_dict[entity_key_value]["entity_key_proto"] # type: ignore[assignment]
585+
586+
res_dict: Dict[str, ValueProto] = {
587+
k: v
588+
for k, v in entity_dict[entity_key_value].items()
589+
if isinstance(v, ValueProto) and isinstance(k, str)
590+
}
591+
592+
results.append(
593+
(
594+
res_event_ts,
595+
res_entity_key_proto,
596+
res_dict,
555597
)
556-
res = {feature_name: val}
557-
res["distance"] = ValueProto(float_val=distance)
558-
result.append((event_ts, entity_key_proto, res))
559-
560-
return result
598+
)
599+
return results
561600

562601

563602
def _initialize_conn(
@@ -640,7 +679,17 @@ def update(self):
640679
except ModuleNotFoundError:
641680
logging.warning("Cannot use sqlite_vec for vector search")
642681
self.conn.execute(
643-
f"CREATE TABLE IF NOT EXISTS {self.name} (entity_key BLOB, feature_name TEXT, value BLOB, vector_value BLOB, event_ts timestamp, created_ts timestamp, PRIMARY KEY(entity_key, feature_name))"
682+
f"""
683+
CREATE TABLE IF NOT EXISTS {self.name} (
684+
entity_key BLOB,
685+
feature_name TEXT,
686+
value BLOB,
687+
vector_value BLOB,
688+
event_ts timestamp,
689+
created_ts timestamp,
690+
PRIMARY KEY(entity_key, feature_name)
691+
)
692+
"""
644693
)
645694
self.conn.execute(
646695
f"CREATE INDEX IF NOT EXISTS {self.name}_ek ON {self.name} (entity_key);"

sdk/python/feast/types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from abc import ABC, abstractmethod
1515
from datetime import datetime, timezone
1616
from enum import Enum
17-
from typing import Dict, Union
17+
from typing import Dict, List, Union
1818

1919
import pyarrow
2020

@@ -196,6 +196,17 @@ def __str__(self):
196196
UnixTimestamp: pyarrow.timestamp("us", tz=_utc_now().tzname()),
197197
}
198198

199+
# List-valued types that are treated as vector types. The list starts with the
# ValueType members and is then extended with the Feast type mapped to each of
# them, so a membership test succeeds for either representation.
FEAST_VECTOR_TYPES: List[Union[ValueType, PrimitiveFeastType, ComplexFeastType]] = [
    ValueType.BYTES_LIST,
    ValueType.INT32_LIST,
    ValueType.INT64_LIST,
    ValueType.FLOAT_LIST,
    ValueType.BOOL_LIST,
]
for k, _feast_type in VALUE_TYPES_TO_FEAST_TYPES.items():
    # Only value types already registered above contribute a Feast type.
    if k not in FEAST_VECTOR_TYPES:
        continue
    FEAST_VECTOR_TYPES.append(_feast_type)
209+
199210

200211
def from_feast_to_pyarrow_type(feast_type: FeastType) -> pyarrow.DataType:
201212
"""

sdk/python/tests/example_repos/example_feature_repo_1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@
125125
vector_search_metric="L2",
126126
),
127127
Field(name="item_id", dtype=String),
128+
Field(name="content", dtype=String),
129+
Field(name="title", dtype=String),
128130
],
129131
source=rag_documents_source,
130132
ttl=timedelta(hours=24),

0 commit comments

Comments
 (0)