Skip to content

Commit

Permalink
feat: Add metadata_filter query param to list endpoints
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Mar 13, 2024
1 parent 9703100 commit 009a772
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 16 deletions.
16 changes: 14 additions & 2 deletions agents-api/agents_api/models/agent/list_agents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import json
from typing import Any
from uuid import UUID


def list_agents_query(developer_id: UUID, limit: int = 100, offset: int = 0):
def list_agents_query(
developer_id: UUID,
limit: int = 100,
offset: int = 0,
metadata_filter: dict[str, Any] = {},
):
metadata_filter_str = ", ".join(
[f'metadata->"{json.dumps(k)}" == {v}' for k, v in metadata_filter.items()]
)

return f"""
{{
input[developer_id] <- [[to_uuid("{developer_id}")]]
Expand All @@ -24,7 +35,8 @@ def list_agents_query(developer_id: UUID, limit: int = 100, offset: int = 0):
created_at,
updated_at,
metadata,
}}
}},
{metadata_filter_str}
:limit {limit}
:offset {offset}
Expand Down
14 changes: 13 additions & 1 deletion agents-api/agents_api/models/session/list_sessions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import json
from typing import Any
from uuid import UUID


def list_sessions_query(developer_id: UUID, limit: int = 100, offset: int = 0):
def list_sessions_query(
developer_id: UUID,
limit: int = 100,
offset: int = 0,
metadata_filter: dict[str, Any] = {},
):
metadata_filter_str = ", ".join(
[f'metadata->"{json.dumps(k)}" == {v}' for k, v in metadata_filter.items()]
)

return f"""
input[developer_id] <- [[
to_uuid("{developer_id}"),
Expand All @@ -28,6 +39,7 @@ def list_sessions_query(developer_id: UUID, limit: int = 100, offset: int = 0):
metadata,
@ "NOW"
}},
{metadata_filter_str}
*session_lookup{{
agent_id,
user_id,
Expand Down
20 changes: 16 additions & 4 deletions agents-api/agents_api/models/user/list_users.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import json
from typing import Any
from uuid import UUID


def list_users_query(developer_id: UUID, limit: int = 100, offset: int = 0):
def list_users_query(
developer_id: UUID,
limit: int = 100,
offset: int = 0,
metadata_filter: dict[str, Any] = {},
):
metadata_filter_str = ", ".join(
[f'metadata->"{json.dumps(k)}" == {v}' for k, v in metadata_filter.items()]
)

return f"""
input[developer_id] <- [[to_uuid("{developer_id}")]]
Expand All @@ -13,7 +24,7 @@ def list_users_query(developer_id: UUID, limit: int = 100, offset: int = 0):
updated_at,
metadata,
] :=
input[developer_id],
input[developer_id],
*users {{
user_id: id,
developer_id,
Expand All @@ -22,8 +33,9 @@ def list_users_query(developer_id: UUID, limit: int = 100, offset: int = 0):
created_at,
updated_at,
metadata,
}}
}},
{metadata_filter_str}
:limit {limit}
:offset {offset}
:sort -created_at
Expand Down
37 changes: 33 additions & 4 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from fastapi import APIRouter, HTTPException, status, Depends
import json
from pydantic import UUID4, BaseModel
from starlette.status import HTTP_201_CREATED, HTTP_202_ACCEPTED
from json import JSONDecodeError
from typing import Annotated
from uuid import uuid4

from fastapi import APIRouter, HTTPException, status, Depends
from pycozo.client import QueryException
from pydantic import UUID4, BaseModel
from starlette.status import HTTP_201_CREATED, HTTP_202_ACCEPTED

from agents_api.clients.cozo import client
from agents_api.clients.embed import embed
Expand Down Expand Up @@ -271,7 +273,16 @@ async def list_agents(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
limit: int = 100,
offset: int = 0,
metadata_filter: str = "{}",
) -> AgentList:
try:
metadata_filter = json.loads(metadata_filter)
except JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="metadata_filter is not a valid JSON",
)

return AgentList(
items=[
Agent(**row.to_dict())
Expand All @@ -280,6 +291,7 @@ async def list_agents(
developer_id=x_developer_id,
limit=limit,
offset=offset,
metadata_filter=metadata_filter,
)
).iterrows()
]
Expand Down Expand Up @@ -326,7 +338,24 @@ async def create_docs(agent_id: UUID4, request: CreateDoc) -> ResourceCreatedRes


@router.get("/agents/{agent_id}/docs", tags=["agents"])
async def list_docs(agent_id: UUID4, limit: int = 100, offset: int = 0) -> DocsList:
async def list_docs(
agent_id: UUID4, limit: int = 100, offset: int = 0, metadata_filter: str = "{}"
) -> DocsList:
try:
metadata_filter = json.loads(metadata_filter)
except JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="metadata_filter is not a valid JSON",
)

# TODO: Implement metadata filter
if metadata_filter:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="metadata_filter is not implemented",
)

resp = client.run(
list_docs_snippets_by_owner_query(
owner_type="agent",
Expand Down
15 changes: 14 additions & 1 deletion agents-api/agents_api/routers/sessions/routers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
from json import JSONDecodeError
from typing import Annotated
from uuid import uuid4

Expand Down Expand Up @@ -96,12 +98,23 @@ async def list_sessions(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
limit: int = 100,
offset: int = 0,
metadata_filter: str = "{}",
) -> SessionList:
try:
metadata_filter = json.loads(metadata_filter)
except JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="metadata_filter is not a valid JSON",
)

return SessionList(
items=[
Session(**row.to_dict())
for _, row in client.run(
list_sessions_query(x_developer_id, limit, offset),
list_sessions_query(
x_developer_id, limit, offset, metadata_filter=metadata_filter
),
).iterrows()
]
)
Expand Down
30 changes: 29 additions & 1 deletion agents-api/agents_api/routers/users/routers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
from json import JSONDecodeError
from typing import Annotated
from uuid import uuid4

Expand Down Expand Up @@ -166,7 +168,16 @@ async def list_users(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
limit: int = 100,
offset: int = 0,
metadata_filter: str = "{}",
) -> UserList:
try:
metadata_filter = json.loads(metadata_filter)
except JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="metadata_filter is not a valid JSON",
)

return UserList(
items=[
User(**row.to_dict())
Expand Down Expand Up @@ -221,7 +232,24 @@ async def create_docs(user_id: UUID4, request: CreateDoc) -> ResourceCreatedResp


@router.get("/users/{user_id}/docs", tags=["users"])
async def list_docs(user_id: UUID4, limit: int = 100, offset: int = 0) -> DocsList:
async def list_docs(
user_id: UUID4, limit: int = 100, offset: int = 0, metadata_filter: str = "{}"
) -> DocsList:
try:
metadata_filter = json.loads(metadata_filter)
except JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="metadata_filter is not a valid JSON",
)

# TODO: Implement metadata filter
if metadata_filter:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="metadata_filter is not implemented",
)

resp = client.run(
list_docs_snippets_by_owner_query(
owner_type="user",
Expand Down
7 changes: 4 additions & 3 deletions agents-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 009a772

Please sign in to comment.