Skip to content

Commit 3ea7138

Browse files
bug fix for addition of new features
Signed-off-by: Francisco Javier Arceo <[email protected]>
1 parent 198cb20 commit 3ea7138

File tree

2 files changed

+28
-17
lines changed

2 files changed

+28
-17
lines changed

sdk/python/feast/infra/online_stores/sqlite.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def retrieve_online_documents(
355355
# Convert the embedding to a binary format instead of using SerializeToString()
356356
query_embedding_bin = serialize_f32(embedding, config.online_store.vector_len)
357357
table_name = _table_id(project, table)
358+
vector_field = _get_vector_field(table)
358359

359360
cur.execute(
360361
f"""
@@ -369,14 +370,15 @@ def retrieve_online_documents(
369370
f"""
370371
INSERT INTO vec_table(rowid, vector_value)
371372
select rowid, vector_value from {table_name}
373+
where feature_name = "{vector_field}"
372374
"""
373375
)
374376
cur.execute(
377+
f"""
378+
CREATE VIRTUAL TABLE IF NOT EXISTS vec_table using vec0(
379+
vector_value float[{config.online_store.vector_len}]
380+
);
375381
"""
376-
INSERT INTO vec_table(rowid, vector_value)
377-
VALUES (?, ?)
378-
""",
379-
(0, query_embedding_bin),
380382
)
381383

382384
# Have to join this with the {table_name} to get the feature name and entity_key
@@ -473,16 +475,7 @@ def retrieve_online_documents_v2(
473475

474476
query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore
475477
table_name = _table_id(config.project, table)
476-
vector_fields: List[Field] = [
477-
f for f in table.features if getattr(f, "vector_index", None)
478-
]
479-
assert len(vector_fields) > 0, (
480-
f"No vector field found, please update feature view = {table.name} to declare a vector field"
481-
)
482-
assert len(vector_fields) < 2, (
483-
"Only one vector field is supported, please update feature view = {table.name} to declare one vector field"
484-
)
485-
vector_field: str = vector_fields[0].name
478+
vector_field = _get_vector_field(table)
486479

487480
cur.execute(
488481
f"""
@@ -696,3 +689,20 @@ def update(self):
696689

697690
def teardown(self):
698691
self.conn.execute(f"DROP TABLE IF EXISTS {self.name}")
692+
693+
694+
def _get_vector_field(table: FeatureView) -> str:
695+
"""
696+
Get the vector field from the feature view. There can be only one.
697+
"""
698+
vector_fields: List[Field] = [
699+
f for f in table.features if getattr(f, "vector_index", None)
700+
]
701+
assert len(vector_fields) > 0, (
702+
f"No vector field found, please update feature view = {table.name} to declare a vector field"
703+
)
704+
assert len(vector_fields) < 2, (
705+
"Only one vector field is supported, please update feature view = {table.name} to declare one vector field"
706+
)
707+
vector_field: str = vector_fields[0].name
708+
return vector_field

sdk/python/tests/unit/online_store/test_online_retrieval.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -713,9 +713,10 @@ def test_sqlite_get_online_documents() -> None:
713713
)
714714
assert record_count == len(data) * len(document_embeddings_fv.features)
715715

716-
query_embedding = np.random.random(
717-
vector_length,
718-
)
716+
# query_embedding = np.random.random(
717+
# vector_length,
718+
# )
719+
query_embedding = [float(x) for x in np.random.random(vector_length)]
719720
result = store.retrieve_online_documents(
720721
feature="document_embeddings:Embeddings", query=query_embedding, top_k=3
721722
).to_dict()

0 commit comments

Comments
 (0)