Skip to content

Commit

Permalink
feat(agents-api): Add doc search routes
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 5, 2024
1 parent c9c9a06 commit 3b6e81d
Show file tree
Hide file tree
Showing 29 changed files with 287 additions and 510 deletions.
39 changes: 27 additions & 12 deletions agents-api/agents_api/autogen/Docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,7 @@ class BaseDocSearchRequest(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
confidence: Annotated[float, Field(0.5, ge=0.0, le=1.0)]
"""
The confidence cutoff level
"""
alpha: Annotated[float, Field(0.75, ge=0.0, le=1.0)]
"""
The weight to apply to BM25 vs Vector search results. 0 => pure BM25; 1 => pure vector;
"""
mmr: bool = False
"""
Whether to include the MMR algorithm in the search. Optimizes for diversity in search results.
"""
limit: Annotated[int, Field(10, ge=1, le=100)]
lang: Literal["en-US"] = "en-US"
"""
The language to be used for text-only search. Support for other languages coming soon.
Expand Down Expand Up @@ -105,6 +94,20 @@ class DocReference(BaseModel):
distance: float | None = None


# Result of a document search: the matched document references plus the time
# the search took.
class DocSearchResponse(BaseModel):
    # populate_by_name lets the model be populated using field names as well
    # as any declared aliases.
    model_config = ConfigDict(
        populate_by_name=True,
    )
    docs: list[DocReference]
    """
    The documents that were found
    """
    # Constrained to be strictly positive (gt=0.0): a zero or negative duration
    # fails validation, so producers must measure elapsed time accordingly.
    time: Annotated[float, Field(gt=0.0)]
    """
    The time taken to search in seconds
    """


class EmbedQueryRequest(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
Expand All @@ -129,6 +132,14 @@ class HybridDocSearchRequest(BaseDocSearchRequest):
model_config = ConfigDict(
populate_by_name=True,
)
confidence: Annotated[float, Field(0.5, ge=0.0, le=1.0)]
"""
The confidence cutoff level
"""
alpha: Annotated[float, Field(0.75, ge=0.0, le=1.0)]
"""
The weight to apply to BM25 vs Vector search results. 0 => pure BM25; 1 => pure vector;
"""
text: str
"""
Text to use in the search. In `hybrid` search mode, either `text` or both `text` and `vector` fields are required.
Expand Down Expand Up @@ -161,6 +172,10 @@ class VectorDocSearchRequest(BaseDocSearchRequest):
model_config = ConfigDict(
populate_by_name=True,
)
confidence: Annotated[float, Field(0.5, ge=0.0, le=1.0)]
"""
The confidence cutoff level
"""
vector: list[float]
"""
Vector to use in the search. Must be the same dimensions as the embedding model or else an error will be thrown.
Expand Down
3 changes: 2 additions & 1 deletion agents-api/agents_api/models/docs/search_docs_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def search_docs_hybrid(
query: str,
query_embedding: list[float],
k: int = 3,
alpha: float = 0.7, # Weight of the embedding search results (this is a good default)
embed_search_options: dict = {},
text_search_options: dict = {},
**kwargs,
Expand All @@ -122,4 +123,4 @@ def search_docs_hybrid(
**kwargs,
)

return dbsf_fuse(text_results, embedding_results)[:k]
return dbsf_fuse(text_results, embedding_results, alpha)[:k]
1 change: 1 addition & 0 deletions agents-api/agents_api/routers/docs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .get_doc import get_doc
from .list_docs import list_agent_docs, list_user_docs
from .router import router
from .search_docs import search_agent_docs, search_user_docs
111 changes: 111 additions & 0 deletions agents-api/agents_api/routers/docs/search_docs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import time
from typing import Annotated

from fastapi import Depends
from pydantic import UUID4

from ...autogen.openapi_model import (
DocSearchResponse,
HybridDocSearchRequest,
TextOnlyDocSearchRequest,
VectorDocSearchRequest,
)
from ...dependencies.developer_id import get_developer_id
from ...models.docs.search_docs_by_embedding import search_docs_by_embedding
from ...models.docs.search_docs_by_text import search_docs_by_text
from ...models.docs.search_docs_hybrid import search_docs_hybrid
from .router import router


def get_search_fn_and_params(search_params):
search_fn, params = None, None

match search_params:
case TextOnlyDocSearchRequest(text=query, limit=k):
search_fn = search_docs_by_text
params = dict(
query=query,
k=k,
)

case VectorDocSearchRequest(
vector=query_embedding, limit=k, confidence=confidence
):
search_fn = search_docs_by_embedding
params = dict(
query_embedding=query_embedding,
k=k,
confidence=confidence,
)

case HybridDocSearchRequest(
text=query,
vector=query_embedding,
limit=k,
confidence=confidence,
alpha=alpha,
):
search_fn = search_docs_hybrid
params = dict(
query=query,
query_embedding=query_embedding,
k=k,
embed_search_options=dict(confidence=confidence),
alpha=alpha,
)

return search_fn, params


@router.post("/users/{user_id}/search", tags=["docs"])
async def search_user_docs(
    x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
    search_params: (
        TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest
    ),
    user_id: UUID4,
) -> DocSearchResponse:
    """Search the docs owned by a user and report how long the search took.

    The request body selects the search mode (text-only, vector, or hybrid);
    the matching search function is called with ``owner_type="user"``.
    """
    search_fn, params = get_search_fn_and_params(search_params)

    # Use the monotonic, high-resolution perf_counter() rather than the wall
    # clock: time.time() can step backwards or tick coarsely, and a zero or
    # negative duration would be rejected by DocSearchResponse (time > 0).
    # NOTE(review): search_fn is invoked synchronously inside an async route;
    # if it blocks on I/O it will stall the event loop -- confirm.
    start = time.perf_counter()
    docs = search_fn(
        developer_id=x_developer_id,
        owner_type="user",
        owner_id=user_id,
        **params,
    )
    time_taken = time.perf_counter() - start

    return DocSearchResponse(
        docs=docs,
        time=time_taken,
    )


@router.post("/agents/{agent_id}/search", tags=["docs"])
async def search_agent_docs(
    x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
    search_params: (
        TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest
    ),
    agent_id: UUID4,
) -> DocSearchResponse:
    """Search the docs owned by an agent and report how long the search took.

    The request body selects the search mode (text-only, vector, or hybrid);
    the matching search function is called with ``owner_type="agent"``.
    """
    search_fn, params = get_search_fn_and_params(search_params)

    # Use the monotonic, high-resolution perf_counter() rather than the wall
    # clock: time.time() can step backwards or tick coarsely, and a zero or
    # negative duration would be rejected by DocSearchResponse (time > 0).
    # NOTE(review): search_fn is invoked synchronously inside an async route;
    # if it blocks on I/O it will stall the event loop -- confirm.
    start = time.perf_counter()
    docs = search_fn(
        developer_id=x_developer_id,
        owner_type="agent",
        owner_id=agent_id,
        **params,
    )
    time_taken = time.perf_counter() - start

    return DocSearchResponse(
        docs=docs,
        time=time_taken,
    )
14 changes: 2 additions & 12 deletions sdks/python/julep/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
AgentsCreateAgentRequestDefaultSettings,
AgentsCreateAgentRequestInstructions,
AgentsDocsSearchRouteSearchRequestBody,
AgentsDocsSearchRouteSearchRequestDirection,
AgentsDocsSearchRouteSearchRequestSortBy,
AgentsDocsSearchRouteSearchResponse,
AgentsPatchAgentRequestDefaultSettings,
AgentsPatchAgentRequestInstructions,
AgentsRouteListRequestDirection,
Expand Down Expand Up @@ -70,6 +67,7 @@
DocsDocOwner,
DocsDocOwnerRole,
DocsDocReference,
DocsDocSearchResponse,
DocsEmbedQueryRequest,
DocsEmbedQueryRequestText,
DocsEmbedQueryResponse,
Expand Down Expand Up @@ -210,9 +208,6 @@
UserDocsRouteListRequestSortBy,
UserDocsRouteListResponse,
UserDocsSearchRouteSearchRequestBody,
UserDocsSearchRouteSearchRequestDirection,
UserDocsSearchRouteSearchRequestSortBy,
UserDocsSearchRouteSearchResponse,
UsersRouteListRequestDirection,
UsersRouteListRequestSortBy,
UsersRouteListResponse,
Expand All @@ -235,9 +230,6 @@
"AgentsCreateAgentRequestDefaultSettings",
"AgentsCreateAgentRequestInstructions",
"AgentsDocsSearchRouteSearchRequestBody",
"AgentsDocsSearchRouteSearchRequestDirection",
"AgentsDocsSearchRouteSearchRequestSortBy",
"AgentsDocsSearchRouteSearchResponse",
"AgentsPatchAgentRequestDefaultSettings",
"AgentsPatchAgentRequestInstructions",
"AgentsRouteListRequestDirection",
Expand Down Expand Up @@ -291,6 +283,7 @@
"DocsDocOwner",
"DocsDocOwnerRole",
"DocsDocReference",
"DocsDocSearchResponse",
"DocsEmbedQueryRequest",
"DocsEmbedQueryRequestText",
"DocsEmbedQueryResponse",
Expand Down Expand Up @@ -432,9 +425,6 @@
"UserDocsRouteListRequestSortBy",
"UserDocsRouteListResponse",
"UserDocsSearchRouteSearchRequestBody",
"UserDocsSearchRouteSearchRequestDirection",
"UserDocsSearchRouteSearchRequestSortBy",
"UserDocsSearchRouteSearchResponse",
"UsersRouteListRequestDirection",
"UsersRouteListRequestSortBy",
"UsersRouteListResponse",
Expand Down
Loading

0 comments on commit 3b6e81d

Please sign in to comment.