Skip to content

Commit 96cd383

Browse files
author
Charlotte Gerhaher
authored
feat(v2): display mesh and pointcloud (#1113)
* feat: display mesh and pointcloud Signed-off-by: anna-charlotte <[email protected]> * chore: update poetry Signed-off-by: anna-charlotte <[email protected]> * fix: mypy Signed-off-by: anna-charlotte <[email protected]> * fix: add display from param to mesh and pc display Signed-off-by: anna-charlotte <[email protected]> * fix: clean up Signed-off-by: anna-charlotte <[email protected]> * fix: mypy Signed-off-by: anna-charlotte <[email protected]> * fix: move display from url to mesh and pc url classes Signed-off-by: anna-charlotte <[email protected]> * chore: remove pyglet dependency Signed-off-by: anna-charlotte <[email protected]> * chore: update pyproject toml Signed-off-by: anna-charlotte <[email protected]> * refactor: copy is notebook function from hubble sdk Signed-off-by: anna-charlotte <[email protected]> * fix: introduce vertices and faces doc Signed-off-by: anna-charlotte <[email protected]> * fix: introduce points and colors class for point cloud Signed-off-by: anna-charlotte <[email protected]> * fix: mypy and tests Signed-off-by: anna-charlotte <[email protected]> * docs: add display example to docs Signed-off-by: anna-charlotte <[email protected]> * fix: apply johannes suggestion from review Signed-off-by: anna-charlotte <[email protected]> * fix: apply samis suggestion Signed-off-by: anna-charlotte <[email protected]> * docs: update docstring Signed-off-by: anna-charlotte <[email protected]> * fix: only display in notebook Signed-off-by: anna-charlotte <[email protected]> * docs: update docstring Signed-off-by: anna-charlotte <[email protected]> * chore: get poetry lock file from feat rewrite v2 Signed-off-by: anna-charlotte <[email protected]> * docs: update docstrings Signed-off-by: anna-charlotte <[email protected]> --------- Signed-off-by: anna-charlotte <[email protected]>
1 parent b085794 commit 96cd383

File tree

14 files changed

+280
-78
lines changed

14 files changed

+280
-78
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from docarray.documents.mesh.mesh_3d import Mesh3D
2+
3+
__all__ = ['Mesh3D']
Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import Any, Optional, Type, TypeVar, Union
22

33
from docarray.base_document import BaseDocument
4-
from docarray.typing import AnyEmbedding, AnyTensor, Mesh3DUrl
4+
from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces
5+
from docarray.typing.tensor.embedding import AnyEmbedding
6+
from docarray.typing.url.url_3d.mesh_url import Mesh3DUrl
57

68
T = TypeVar('T', bound='Mesh3D')
79

@@ -17,9 +19,10 @@ class Mesh3D(BaseDocument):
1719
tensor of shape (n_faces, 3). Each number in that tensor refers to an index of a
1820
vertex in the tensor of vertices.
1921
20-
The Mesh3D Document can contain an Mesh3DUrl (`Mesh3D.url`), an AnyTensor of
21-
vertices (`Mesh3D.vertices`), an AnyTensor of faces (`Mesh3D.faces`) and an
22-
AnyEmbedding (`Mesh3D.embedding`).
22+
The Mesh3D Document can contain an Mesh3DUrl (`Mesh3D.url`), a VerticesAndFaces
23+
object containing an AnyTensor of vertices (`Mesh3D.tensors.vertices) and an
24+
AnyTensor of faces (`Mesh3D.tensors.faces), and an AnyEmbedding
25+
(`Mesh3D.embedding`).
2326
2427
EXAMPLE USAGE:
2528
@@ -31,9 +34,9 @@ class Mesh3D(BaseDocument):
3134
3235
# use it directly
3336
mesh = Mesh3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj')
34-
mesh.vertices, mesh.faces = mesh.url.load()
37+
mesh.tensors = mesh.url.load()
3538
model = MyEmbeddingModel()
36-
mesh.embedding = model(mesh.vertices)
39+
mesh.embedding = model(mesh.tensors.vertices)
3740
3841
You can extend this Document:
3942
@@ -43,13 +46,14 @@ class Mesh3D(BaseDocument):
4346
from docarray.typing import AnyEmbedding
4447
from typing import Optional
4548
49+
4650
# extend it
4751
class MyMesh3D(Mesh3D):
4852
name: Optional[Text]
4953
5054
5155
mesh = MyMesh3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj')
52-
mesh.vertices, mesh.faces = mesh.url.load()
56+
mesh.tensors = mesh.url.load()
5357
model = MyEmbeddingModel()
5458
mesh.embedding = model(mesh.vertices)
5559
mesh.name = 'my first mesh'
@@ -62,6 +66,7 @@ class MyMesh3D(Mesh3D):
6266
from docarray import BaseDocument
6367
from docarray.documents import Mesh3D, Text
6468
69+
6570
# compose it
6671
class MultiModalDoc(BaseDocument):
6772
mesh: Mesh3D
@@ -72,16 +77,32 @@ class MultiModalDoc(BaseDocument):
7277
mesh=Mesh3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj'),
7378
text=Text(text='hello world, how are you doing?'),
7479
)
75-
mmdoc.mesh.vertices, mmdoc.mesh.faces = mmdoc.mesh.url.load()
80+
mmdoc.mesh.tensors = mmdoc.mesh.url.load()
7681
7782
# or
7883
mmdoc.mesh.bytes = mmdoc.mesh.url.load_bytes()
7984
85+
86+
You can display your 3D mesh in a notebook from either its url, or its tensors:
87+
88+
.. code-block:: python
89+
90+
from docarray.documents import Mesh3D
91+
92+
# display from url
93+
mesh = Mesh3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj')
94+
mesh.url.display()
95+
96+
# display from tensors
97+
mesh.tensors = mesh.url.load()
98+
model = MyEmbeddingModel()
99+
mesh.embedding = model(mesh.tensors.vertices)
100+
101+
80102
"""
81103

82104
url: Optional[Mesh3DUrl]
83-
vertices: Optional[AnyTensor]
84-
faces: Optional[AnyTensor]
105+
tensors: Optional[VerticesAndFaces]
85106
embedding: Optional[AnyEmbedding]
86107
bytes: Optional[bytes]
87108

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Any, Type, TypeVar, Union
2+
3+
from docarray.base_document import BaseDocument
4+
from docarray.typing.tensor.tensor import AnyTensor
5+
6+
T = TypeVar('T', bound='VerticesAndFaces')
7+
8+
9+
class VerticesAndFaces(BaseDocument):
10+
"""
11+
Document for handling 3D mesh tensor data.
12+
13+
A VerticesAndFaces Document can contain an AnyTensor containing the vertices
14+
information (`VerticesAndFaces.vertices`), and an AnyTensor containing the faces
15+
information (`VerticesAndFaces.faces`).
16+
"""
17+
18+
vertices: AnyTensor
19+
faces: AnyTensor
20+
21+
@classmethod
22+
def validate(
23+
cls: Type[T],
24+
value: Union[str, Any],
25+
) -> T:
26+
return super().validate(value)
27+
28+
def display(self) -> None:
29+
"""
30+
Plot mesh consisting of vertices and faces.
31+
To use this you need to install trimesh[easy]: `pip install 'trimesh[easy]'`.
32+
"""
33+
import trimesh
34+
from IPython.display import display
35+
36+
if self.vertices is None or self.faces is None:
37+
raise ValueError(
38+
'Can\'t display mesh from tensors when the vertices and/or faces '
39+
'are None.'
40+
)
41+
42+
mesh = trimesh.Trimesh(vertices=self.vertices, faces=self.faces)
43+
display(mesh.show())
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from docarray.documents.point_cloud.point_cloud_3d import PointCloud3D
2+
3+
__all__ = ['PointCloud3D']

docarray/documents/point_cloud.py renamed to docarray/documents/point_cloud/point_cloud_3d.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import numpy as np
44

55
from docarray.base_document import BaseDocument
6-
from docarray.typing import AnyEmbedding, AnyTensor, PointCloud3DUrl
6+
from docarray.documents.point_cloud.points_and_colors import PointsAndColors
7+
from docarray.typing import AnyEmbedding, PointCloud3DUrl
78
from docarray.typing.tensor.abstract_tensor import AbstractTensor
89
from docarray.utils.misc import is_tf_available, is_torch_available
910

@@ -27,8 +28,9 @@ class PointCloud3D(BaseDocument):
2728
representation, the point cloud is a fixed size ndarray (shape=(n_samples, 3)) and
2829
hence easier for deep learning algorithms to handle.
2930
30-
A PointCloud3D Document can contain an PointCloud3DUrl (`PointCloud3D.url`), an
31-
AnyTensor (`PointCloud3D.tensor`), and an AnyEmbedding (`PointCloud3D.embedding`).
31+
A PointCloud3D Document can contain an PointCloud3DUrl (`PointCloud3D.url`),
32+
a PointsAndColors object (`PointCloud3D.tensors`), and an AnyEmbedding
33+
(`PointCloud3D.embedding`).
3234
3335
EXAMPLE USAGE:
3436
@@ -40,9 +42,9 @@ class PointCloud3D(BaseDocument):
4042
4143
# use it directly
4244
pc = PointCloud3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj')
43-
pc.tensor = pc.url.load(samples=100)
45+
pc.tensors = pc.url.load(samples=100)
4446
model = MyEmbeddingModel()
45-
pc.embedding = model(pc.tensor)
47+
pc.embedding = model(pc.tensors.points)
4648
4749
You can extend this Document:
4850
@@ -58,10 +60,10 @@ class MyPointCloud3D(PointCloud3D):
5860
5961
6062
pc = MyPointCloud3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj')
61-
pc.tensor = pc.url.load(samples=100)
63+
pc.tensors = pc.url.load(samples=100)
6264
model = MyEmbeddingModel()
63-
pc.embedding = model(pc.tensor)
64-
pc.second_embedding = model(pc.tensor)
65+
pc.embedding = model(pc.tensors.points)
66+
pc.second_embedding = model(pc.tensors.colors)
6567
6668
6769
You can use this Document for composition:
@@ -83,16 +85,32 @@ class MultiModalDoc(BaseDocument):
8385
),
8486
text=Text(text='hello world, how are you doing?'),
8587
)
86-
mmdoc.point_cloud.tensor = mmdoc.point_cloud.url.load(samples=100)
88+
mmdoc.point_cloud.tensors = mmdoc.point_cloud.url.load(samples=100)
8789
8890
# or
8991
9092
mmdoc.point_cloud.bytes = mmdoc.point_cloud.url.load_bytes()
9193
94+
95+
You can display your point cloud from either its url, or its tensors:
96+
97+
.. code-block:: python
98+
99+
from docarray.documents import PointCloud3D
100+
101+
# display from url
102+
pc = PointCloud3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj')
103+
pc.url.display()
104+
105+
# display from tensors
106+
pc.tensors = pc.url.load(samples=10000)
107+
model = MyEmbeddingModel()
108+
pc.embedding = model(pc.tensors.points)
109+
92110
"""
93111

94112
url: Optional[PointCloud3DUrl]
95-
tensor: Optional[AnyTensor]
113+
tensors: Optional[PointsAndColors]
96114
embedding: Optional[AnyEmbedding]
97115
bytes: Optional[bytes]
98116

@@ -108,6 +126,6 @@ def validate(
108126
and isinstance(value, torch.Tensor)
109127
or (tf_available and isinstance(value, tf.Tensor))
110128
):
111-
value = cls(tensor=value)
129+
value = cls(tensors=PointsAndColors(points=value))
112130

113131
return super().validate(value)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Any, Optional, Type, TypeVar, Union
2+
3+
import numpy as np
4+
5+
from docarray.base_document import BaseDocument
6+
from docarray.typing import AnyTensor
7+
from docarray.typing.tensor.abstract_tensor import AbstractTensor
8+
from docarray.utils.misc import is_tf_available, is_torch_available
9+
10+
torch_available = is_torch_available()
11+
if torch_available:
12+
import torch
13+
14+
tf_available = is_tf_available()
15+
if tf_available:
16+
import tensorflow as tf # type: ignore
17+
18+
T = TypeVar('T', bound='PointsAndColors')
19+
20+
21+
class PointsAndColors(BaseDocument):
22+
"""
23+
Document for handling point clouds tensor data.
24+
25+
A PointsAndColors Document can contain an AnyTensor containing the points in
26+
3D space information (`PointsAndColors.points`), and an AnyTensor containing
27+
the points' color information (`PointsAndColors.colors`).
28+
"""
29+
30+
points: AnyTensor
31+
colors: Optional[AnyTensor]
32+
33+
@classmethod
34+
def validate(
35+
cls: Type[T],
36+
value: Union[str, AbstractTensor, Any],
37+
) -> T:
38+
if isinstance(value, (AbstractTensor, np.ndarray)) or (
39+
torch_available
40+
and isinstance(value, torch.Tensor)
41+
or (tf_available and isinstance(value, tf.Tensor))
42+
):
43+
value = cls(points=value)
44+
45+
return super().validate(value)
46+
47+
def display(self) -> None:
48+
"""
49+
Plot point cloud consisting of points in 3D space and optionally colors.
50+
To use this you need to install trimesh[easy]: `pip install 'trimesh[easy]'`.
51+
"""
52+
import trimesh
53+
from IPython.display import display
54+
55+
colors = (
56+
self.colors
57+
if self.colors is not None
58+
else np.tile(
59+
np.array([0, 0, 0]),
60+
(self.points.get_comp_backend().shape(self.points)[0], 1),
61+
)
62+
)
63+
pc = trimesh.points.PointCloud(vertices=self.points, colors=colors)
64+
65+
s = trimesh.Scene(geometry=pc)
66+
display(s.show())

docarray/typing/tensor/tensor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@
1313
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
1414

1515

16+
AnyTensor = Union[NdArray]
1617
if torch_available and tf_available:
17-
AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor]
18+
AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor] # type: ignore
1819
elif torch_available:
1920
AnyTensor = Union[NdArray, TorchTensor] # type: ignore
2021
elif tf_available:
2122
AnyTensor = Union[NdArray, TensorFlowTensor] # type: ignore
22-
else:
23-
AnyTensor = Union[NdArray] # type: ignore
Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import NamedTuple, TypeVar
1+
from typing import TYPE_CHECKING, TypeVar
22

33
import numpy as np
44
from pydantic import parse_obj_as
@@ -7,12 +7,10 @@
77
from docarray.typing.tensor.ndarray import NdArray
88
from docarray.typing.url.url_3d.url_3d import Url3D
99

10-
T = TypeVar('T', bound='Mesh3DUrl')
11-
10+
if TYPE_CHECKING:
11+
from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces
1212

13-
class Mesh3DLoadResult(NamedTuple):
14-
vertices: NdArray
15-
faces: NdArray
13+
T = TypeVar('T', bound='Mesh3DUrl')
1614

1715

1816
@_register_proto(proto_type_name='mesh_url')
@@ -22,9 +20,9 @@ class Mesh3DUrl(Url3D):
2220
Can be remote (web) URL, or a local file path.
2321
"""
2422

25-
def load(self: T) -> Mesh3DLoadResult:
23+
def load(self: T) -> 'VerticesAndFaces':
2624
"""
27-
Load the data from the url into a named tuple of two NdArrays containing
25+
Load the data from the url into a VerticesAndFaces object containing
2826
vertices and faces information.
2927
3028
EXAMPLE USAGE
@@ -34,7 +32,7 @@ def load(self: T) -> Mesh3DLoadResult:
3432
from docarray import BaseDocument
3533
import numpy as np
3634
37-
from docarray.typing import Mesh3DUrl
35+
from docarray.typing import Mesh3DUrl, NdArray
3836
3937
4038
class MyDoc(BaseDocument):
@@ -43,16 +41,29 @@ class MyDoc(BaseDocument):
4341
4442
doc = MyDoc(mesh_url="toydata/tetrahedron.obj")
4543
46-
vertices, faces = doc.mesh_url.load()
47-
assert isinstance(vertices, np.ndarray)
48-
assert isinstance(faces, np.ndarray)
44+
tensors = doc.mesh_url.load()
45+
assert isinstance(tensors.vertices, NdArray)
46+
assert isinstance(tensors.faces, NdArray)
47+
4948
50-
:return: named tuple of two NdArrays representing the mesh's vertices and faces
49+
:return: VerticesAndFaces object containing vertices and faces information.
5150
"""
51+
from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces
5252

5353
mesh = self._load_trimesh_instance(force='mesh')
5454

5555
vertices = parse_obj_as(NdArray, mesh.vertices.view(np.ndarray))
5656
faces = parse_obj_as(NdArray, mesh.faces.view(np.ndarray))
5757

58-
return Mesh3DLoadResult(vertices=vertices, faces=faces)
58+
return VerticesAndFaces(vertices=vertices, faces=faces)
59+
60+
def display(self) -> None:
61+
"""
62+
Plot mesh from url.
63+
This loads the Trimesh instance of the 3D mesh, and then displays it.
64+
To use this you need to install trimesh[easy]: `pip install 'trimesh[easy]'`.
65+
"""
66+
from IPython.display import display
67+
68+
mesh = self._load_trimesh_instance()
69+
display(mesh.show())

0 commit comments

Comments
 (0)