Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
test: test config subclassing
Signed-off-by: Johannes Messner <[email protected]>
  • Loading branch information
JohannesMessner committed Sep 5, 2023
commit d587a6a1d270cf913c35d14bdc88c6d7f548ab91
31 changes: 26 additions & 5 deletions tests/units/document/test_any_document.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Dict, List

import numpy as np
import pytest
from typing import Dict, List
from orjson import orjson

from docarray import DocList
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.base_doc.io.json import orjson_dumps_and_decode
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor


def test_any_doc():
Expand Down Expand Up @@ -36,7 +40,7 @@ class InnerDoc(BaseDoc):
class DocTest(BaseDoc):
text: str
tags: Dict[str, int]
l: List[int]
l_: List[int]
d: InnerDoc
ld: DocList[InnerDoc]

Expand All @@ -46,14 +50,14 @@ class DocTest(BaseDoc):
DocTest(
text='type1',
tags={'type': 1},
l=[1, 2],
l_=[1, 2],
d=inner_doc,
ld=DocList[InnerDoc]([inner_doc]),
),
DocTest(
text='type2',
tags={'type': 2},
l=[1, 2],
l_=[1, 2],
d=inner_doc,
ld=DocList[InnerDoc]([inner_doc]),
),
Expand All @@ -71,7 +75,7 @@ class DocTest(BaseDoc):
for i, d in enumerate(aux):
assert d.tags['type'] == i + 1
assert d.text == f'type{i + 1}'
assert d.l == [1, 2]
assert d.l_ == [1, 2]
if protocol == 'proto':
assert isinstance(d.d, AnyDoc)
assert d.d.text == 'I am inner' # inner Document is a Dict
Expand All @@ -89,3 +93,20 @@ class DocTest(BaseDoc):
assert isinstance(d.ld[0], dict)
assert d.ld[0]['text'] == 'I am inner'
assert d.ld[0]['t'] == {'a': 'b'}


def test_subclass_config():
class MyDoc(BaseDoc):
x: str

class Config(BaseDoc.Config):
arbitrary_types_allowed = True # just an example setting

assert MyDoc.Config.json_loads == orjson.loads
assert MyDoc.Config.json_dumps == orjson_dumps_and_decode
assert (
MyDoc.Config.json_encoders[AbstractTensor](3) == 3
) # dirty check that it is identity
assert MyDoc.Config.validate_assignment
assert not MyDoc.Config._load_extra_fields_from_protobuf
assert MyDoc.Config.arbitrary_types_allowed