Skip to content

Commit abb332b

Browse files
authored
feat: support pydantic data model (#50)
1 parent b7e5ce7 commit abb332b

File tree

9 files changed

+320
-24
lines changed

9 files changed

+320
-24
lines changed

docarray/array/mixins/pydantic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ def to_pydantic_model(self) -> 'PydanticDocumentArray':
2323

2424
@classmethod
2525
def from_pydantic_model(
26-
cls: Type['T'], model: List['BaseModel'], ndarray_as_list: bool = False
26+
cls: Type['T'],
27+
model: List['BaseModel'],
2728
) -> 'T':
2829
"""Convert a list of PydanticDocument into
2930
3031
:param model: the pydantic data model object that represents a DocumentArray
31-
:param ndarray_as_list: if set to True, `embedding` and `blob` are auto-casted to ndarray. :return:
3232
:return: a DocumentArray
3333
"""
3434
from ... import Document
3535

36-
return cls(Document.from_pydantic_model(m, ndarray_as_list) for m in model)
36+
return cls(Document.from_pydantic_model(m) for m in model)

docarray/document/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import overload, Dict, Optional, List, TYPE_CHECKING
1+
from typing import overload, Dict, Optional, List, TYPE_CHECKING, Union, Sequence
22

33
from .data import DocumentData, default_values
44
from .mixins import AllMixins
@@ -58,10 +58,10 @@ def __init__(
5858
location: Optional[List[float]] = None,
5959
embedding: Optional['ArrayType'] = None,
6060
modality: Optional[str] = None,
61-
evaluations: Optional[Dict[str, 'NamedScore']] = None,
62-
scores: Optional[Dict[str, 'NamedScore']] = None,
63-
chunks: Optional['DocumentArray'] = None,
64-
matches: Optional['DocumentArray'] = None,
61+
evaluations: Optional[Dict[str, Dict[str, 'StructValueType']]] = None,
62+
scores: Optional[Dict[str, Dict[str, 'StructValueType']]] = None,
63+
chunks: Optional[Sequence['Document']] = None,
64+
matches: Optional[Sequence['Document']] = None,
6565
):
6666
...
6767

docarray/document/mixins/pydantic.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,31 @@ class PydanticMixin:
1313
"""Provide helper functions to convert to/from a Pydantic model"""
1414

1515
@classmethod
16-
def json_schema(cls, indent: int = 2) -> str:
16+
def get_json_schema(cls, indent: int = 2) -> str:
1717
"""Return a JSON Schema of Document class."""
1818
from ..pydantic_model import PydanticDocument as DP
1919

20-
return DP.schema_json(indent=indent)
20+
from pydantic import schema_json_of
21+
22+
return schema_json_of(DP, title='Document Schema', indent=indent)
2123

2224
def to_pydantic_model(self) -> 'PydanticDocument':
2325
"""Convert a Document object into a Pydantic model."""
2426
from ..pydantic_model import PydanticDocument as DP
2527

26-
return DP(**{f: getattr(self, f) for f in self.non_empty_fields})
28+
_p_dict = {}
29+
for f in self.non_empty_fields:
30+
v = getattr(self, f)
31+
if f in ('matches', 'chunks'):
32+
_p_dict[f] = v.to_pydantic_model()
33+
elif f in ('scores', 'evaluations'):
34+
_p_dict[f] = {k: v.to_dict() for k, v in v.items()}
35+
else:
36+
_p_dict[f] = v
37+
return DP(**_p_dict)
2738

2839
@classmethod
29-
def from_pydantic_model(
30-
cls: Type['T'], model: 'BaseModel', ndarray_as_list: bool = False
31-
) -> 'T':
40+
def from_pydantic_model(cls: Type['T'], model: 'BaseModel') -> 'T':
3241
"""Build a Document object from a Pydantic model
3342
3443
:param model: the pydantic data model object that represents a Document
@@ -38,15 +47,23 @@ def from_pydantic_model(
3847
from ... import Document
3948

4049
fields = {}
41-
for (field, value) in model.dict(exclude_none=True).items():
50+
if model.chunks:
51+
fields['chunks'] = [Document.from_pydantic_model(d) for d in model.chunks]
52+
if model.matches:
53+
fields['matches'] = [Document.from_pydantic_model(d) for d in model.matches]
54+
55+
for (field, value) in model.dict(
56+
exclude_none=True, exclude={'chunks', 'matches'}
57+
).items():
4258
f_name = field
43-
if f_name == 'chunks' or f_name == 'matches':
44-
fields[f_name] = [Document.from_pydantic_model(d) for d in value]
45-
elif f_name == 'scores' or f_name == 'evaluations':
46-
fields[f_name] = defaultdict(value)
59+
if f_name == 'scores' or f_name == 'evaluations':
60+
from docarray.score import NamedScore
61+
62+
fields[f_name] = defaultdict(NamedScore)
63+
for k, v in value.items():
64+
fields[f_name][k] = NamedScore(v)
4765
elif f_name == 'embedding' or f_name == 'blob':
48-
if not ndarray_as_list:
49-
fields[f_name] = np.array(value)
66+
fields[f_name] = np.array(value)
5067
else:
5168
fields[f_name] = value
5269
return Document(**fields)

docarray/proto/io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def parse_proto(pb_msg: 'DocumentProto') -> 'Document':
2727
elif f_name == 'location':
2828
fields[f_name] = list(value)
2929
elif f_name == 'scores' or f_name == 'evaluations':
30-
fields[f_name] = defaultdict()
30+
fields[f_name] = defaultdict(NamedScore)
3131
for k, v in value.items():
3232
fields[f_name][k] = NamedScore(
3333
{ff.name: vv for (ff, vv) in v.ListFields()}
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# FastAPI/pydantic Support
2+
3+
Long story short, DocArray supports [pydantic data model](https://pydantic-docs.helpmanual.io/) via {class}`~docarray.document.pydantic_model.PydanticDocument` and {class}`~docarray.document.pydantic_model.PydanticDocumentArray`.
4+
5+
But this is probably too short to make any sense. So let's take a step back and see what does this mean.
6+
7+
When you want to send/receive Document or DocumentArray object via REST API, you can use `.from_json`/`.to_json` that convert the Document/DocumentArray object into JSON. This has been introduced in the {ref}`docarray-serialization` section.
8+
9+
This way, although quite intuitive to many data scientists, is *not* the modern way of building API services. Your engineer friends won't be happy if you give them a service like this. The main problem here is the **data validation**.
10+
11+
Of course, you can include data validation inside your service logic, but it is often brainfuck as you will need to check field by field and repeat things like `isinstance(field, int)`, not even to mention handling nested JSON.
12+
13+
Modern web frameworks validate the data _before_ it enters the core logic. For example, [FastAPI](https://fastapi.tiangolo.com/) leverages [pydantic](https://pydantic-docs.helpmanual.io/) to validate input & output data.
14+
15+
This chapter will introduce how to leverage DocArray's pydantic support in a FastAPI service to build a modern API service.
16+
17+
```{tip}
18+
Features introduced in this chapter require `fastapi` and `pydantic` as dependency, please do `pip install "docarray[full]"` to enable it.
19+
```
20+
21+
## JSON Schema
22+
23+
You can get the [JSON Schema](https://json-schema.org/) (OpenAPI itself is based on JSON Schema) of Document and DocumentArray by {meth}`~docarray.array.mixins.pydantic.PydanticMixin.get_json_schema`.
24+
25+
````{tab} Document
26+
```python
27+
from docarray import Document
28+
Document.get_json_schema()
29+
```
30+
31+
```json
32+
{
33+
"$ref": "#/definitions/PydanticDocument",
34+
"definitions": {
35+
"PydanticDocument": {
36+
"title": "PydanticDocument",
37+
"type": "object",
38+
"properties": {
39+
"id": {
40+
"title": "Id",
41+
"type": "string"
42+
},
43+
```
44+
````
45+
````{tab} DocumentArray
46+
```python
47+
from docarray import DocumentArray
48+
DocumentArray.get_json_schema()
49+
```
50+
51+
```json
52+
{
53+
"title": "DocumentArray Schema",
54+
"type": "array",
55+
"items": {
56+
"$ref": "#/definitions/PydanticDocument"
57+
},
58+
"definitions": {
59+
"PydanticDocument": {
60+
"title": "PydanticDocument",
61+
"type": "object",
62+
"properties": {
63+
"id": {
64+
"title": "Id",
65+
```
66+
````
67+
Hand them over to your engineer friends, they will be happy as now they can understand what data format you are working on. With these schemas, they can easily integrate DocArray into the system.
68+
69+
## FastAPI usage
70+
71+
The fundamentals of FastAPI can be learned from its docs. I won't repeat them here again.
72+
73+
### Validate incoming Document and DocumentArray
74+
75+
You can import {class}`~docarray.document.pydantic_model.PydanticDocument` and {class}`~docarray.document.pydantic_model.PydanticDocumentArray` pydantic data models, and use them to type hint your endpoint. This will enable the data validation.
76+
77+
```python
78+
from docarray.document.pydantic_model import PydanticDocument, PydanticDocumentArray
79+
from fastapi import FastAPI
80+
81+
app = FastAPI()
82+
83+
@app.post('/single')
84+
async def create_item(item: PydanticDocument):
85+
...
86+
87+
@app.post('/multi')
88+
async def create_array(items: PydanticDocumentArray):
89+
...
90+
```
91+
92+
Let's now send some JSON:
93+
94+
```python
95+
from starlette.testclient import TestClient
96+
client = TestClient(app)
97+
98+
response = client.post('/single', {'hello': 'world'})
99+
print(response, response.text)
100+
response = client.post('/single', {'id': [12, 23]})
101+
print(response, response.text)
102+
```
103+
104+
```text
105+
<Response [422]> {"detail":[{"loc":["body"],"msg":"value is not a valid dict","type":"type_error.dict"}]}
106+
<Response [422]> {"detail":[{"loc":["body"],"msg":"value is not a valid dict","type":"type_error.dict"}]}
107+
```
108+
109+
Both got rejected (422 error) as they are not valid.
110+
111+
## Convert between pydantic model and DocArray objects
112+
113+
{class}`~docarray.document.pydantic_model.PydanticDocument` and {class}`~docarray.document.pydantic_model.PydanticDocumentArray` are mainly for data validation. When you want to implement real logics, you need to convert it into Document or DocumentArray. This can be easily achieved via {meth}`~docarray.array.mixins.pydantic.PydanticMixin.from_pydantic_model`. When you are done with processing and want to send back, you can call {meth}`~docarray.array.mixins.pydantic.PydanticMixin.to_pydantic_model`.
114+
115+
In a nutshell, the whole procedure looks like the following:
116+
117+
```{figure} lifetime-pydantic.svg
118+
```
119+
120+
121+
Let's see an example,
122+
123+
```python
124+
from docarray import Document, DocumentArray
125+
126+
@app.post('/single')
127+
async def create_item(item: PydanticDocument):
128+
d = Document.from_pydantic_model(item)
129+
# now `d` is a Document object
130+
... # process `d` how ever you want
131+
return d.to_pydantic_model()
132+
133+
134+
@app.post('/multi')
135+
async def create_array(items: PydanticDocumentArray):
136+
da = DocumentArray.from_pydantic_model(items)
137+
# now `da` is a DocumentArray object
138+
... # process `da` how ever you want
139+
return da.to_pydantic_model()
140+
```
141+
142+
143+
144+
## Limit returned fields by response model
145+
146+
Supporting pydantic data model means much more beyond data validation. One useful pattern is to define a smaller data model and restrict the response to certain fields of the Document.
147+
148+
Imagine we have a DocumentArray with `.embeddings` on the server side. But we do not want to return them to the client for some reasons (1. meaningless to users; 2. too big to transfer). One can simply define the interested fields via
149+
`pydantic.BaseModel` and then use it in `response_model=`.
150+
151+
```python
152+
from pydantic import BaseModel
153+
from docarray import Document
154+
155+
class IdOnly(BaseModel):
156+
id: str
157+
158+
@app.get('/single', response_model=IdOnly)
159+
async def get_item_no_embedding():
160+
d = Document(embedding=[1, 2, 3])
161+
return d.to_pydantic_model()
162+
```
163+
164+
And you get:
165+
166+
```text
167+
<Response [200]> {'id': '065a5548756211ecaa8d1e008a366d49'}
168+
```
169+
170+
## Limit returned results recursively
171+
172+
The same idea applies to DocumentArray as well. Say after [`.match()`](../documentarray/matching.md), you are only interested in `.id` - the parent `.id` and all matches `id`. You can declare a `BaseModel` as follows:
173+
174+
```python
175+
from typing import List, Optional
176+
177+
class IdAndMatch(BaseModel):
178+
id: str
179+
matches: Optional[List['IdMatch']]
180+
```
181+
182+
Bind it to `response_model`:
183+
184+
```python
185+
@app.get('/get_match', response_model=List[IdAndMatch])
186+
async def get_match_id_only():
187+
da = DocumentArray.empty(10)
188+
da.embeddings = np.random.random([len(da), 3])
189+
da.match(da)
190+
return da.to_pydantic_model()
191+
```
192+
193+
Then you get a very nice result of `id`s of matches (potentially unlimited depth).
194+
195+
```text
196+
[{'id': 'ef82e4f4756411ecb2c01e008a366d49',
197+
'matches': [{'id': 'ef82e4f4756411ecb2c01e008a366d49', 'matches': None},
198+
{'id': 'ef82e6d4756411ecb2c01e008a366d49', 'matches': None},
199+
{'id': 'ef82e760756411ecb2c01e008a366d49', 'matches': None},
200+
{'id': 'ef82e7ec756411ecb2c01e008a366d49', 'matches': None},
201+
...
202+
```
203+
204+
If `'matches': None` is annoying to you (they are here because you didn't compute second-degree matches), you can further leverage FastAPI's feature and do:
205+
```python
206+
@app.get('/get_match', response_model=List[IdMatch], response_model_exclude_none=True)
207+
async def get_match_id_only():
208+
...
209+
```
210+
211+
Finally, you get a very clean results with ids and matches only:
212+
213+
```text
214+
[{'id': '3da6383e756511ecb7cb1e008a366d49',
215+
'matches': [{'id': '3da6383e756511ecb7cb1e008a366d49'},
216+
{'id': '3da63a14756511ecb7cb1e008a366d49'},
217+
{'id': '3da6392e756511ecb7cb1e008a366d49'},
218+
{'id': '3da63b72756511ecb7cb1e008a366d49'},
219+
{'id': '3da639ce756511ecb7cb1e008a366d49'},
220+
{'id': '3da63a5a756511ecb7cb1e008a366d49'},
221+
{'id': '3da63ae6756511ecb7cb1e008a366d49'},
222+
{'id': '3da63aa0756511ecb7cb1e008a366d49'},
223+
{'id': '3da63b2c756511ecb7cb1e008a366d49'},
224+
{'id': '3da63988756511ecb7cb1e008a366d49'}]},
225+
{'id': '3da6392e756511ecb7cb1e008a366d49',
226+
'matches': [{'id': '3da6392e756511ecb7cb1e008a366d49'},
227+
{'id': '3da639ce756511ecb7cb1e008a366d49'},
228+
...
229+
```
230+
231+
More tricks and usages of pydantic model can be found in its docs. Same for FastAPI. I strongly recommend interested readers to go through their documentations.

docs/fundamentals/fastapi-support/lifetime-pydantic.svg

Lines changed: 1 addition & 0 deletions
Loading

docs/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ get-started/what-is
9292
9393
fundamentals/document/index
9494
fundamentals/documentarray/index
95-
fundamentals/notebook-support/index
9695
datatypes/index
96+
fundamentals/notebook-support/index
97+
fundamentals/fastapi-support/index
9798
```
9899

99100

tests/unit/document/test_protobuf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def test_from_to_namescore_default_dict(attr, meth):
4242
d = Document()
4343
getattr(d, attr)['relevance'].value = 3.0
4444
assert isinstance(d.scores, defaultdict)
45+
assert isinstance(d.scores['random_score1'], NamedScore)
4546

4647
r_d = getattr(Document, f'from_{meth}')(getattr(d, f'to_{meth}')())
4748
assert isinstance(r_d.scores, defaultdict)
49+
assert isinstance(r_d.scores['random_score2'], NamedScore)

0 commit comments

Comments
 (0)