Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: add display from param to mesh and pc display
Signed-off-by: anna-charlotte <[email protected]>
  • Loading branch information
anna-charlotte committed Feb 15, 2023
commit e979c32fbe06090ad5a1205b40749bb3fcb895c0
15 changes: 13 additions & 2 deletions docarray/documents/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,21 @@ def validate(
value = cls(url=value)
return super().validate(value)

def display(self):
"""Plot mesh consisting of vertices and faces."""
def display(self, display_from: str = 'url'):
"""
Plot mesh consisting of vertices and faces.
"""
from IPython.display import display

if display_from not in ['tensor', 'url']:
raise ValueError(f'Expected one of ["tensor", "url"], got "{display_from}"')

if not getattr(self, display_from):
raise ValueError(
f'Can not to display point cloud from {display_from} when the '
f'{display_from} is None.'
)

if self.url:
# mesh from uri
mesh = self.url._load_trimesh_instance()
Expand Down
42 changes: 29 additions & 13 deletions docarray/documents/point_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,25 +113,41 @@ def validate(

return super().validate(value)

def display(self) -> None:
"""Plot interactive point cloud from :attr:`.tensor`"""
def display(self, display_from: str = 'url', samples: int = 10000) -> None:
"""
Plot interactive point cloud from :attr:`.tensor`
"""
import trimesh
from hubble.utils.notebook import is_notebook
from IPython.display import display

colors = (
self.color_tensor
if self.color_tensor
else np.tile(
np.array([0, 0, 0]),
(self.tensor.get_comp_backend().shape(self.tensor)[0], 1),
if display_from not in ['tensor', 'url']:
raise ValueError(f'Expected one of ["tensor", "url"], got "{display_from}"')

if not getattr(self, display_from):
raise ValueError(
f'Can not to display point cloud from {display_from} when the '
f'{display_from} is None.'
)
)

pc = trimesh.points.PointCloud(
vertices=self.tensor,
colors=colors,
)
if display_from == 'url':
tensor = self.url.load(samples=samples)
colors = np.tile(
np.array([0, 0, 0]), (tensor.get_comp_backend().shape(tensor)[0], 1)
)
else:
tensor = self.tensor
comp_be = self.tensor.get_comp_backend()
colors = (
self.color_tensor
if self.color_tensor
else np.tile(
np.array([0, 0, 0]),
(comp_be.shape(tensor)[0], 1),
)
)

pc = trimesh.points.PointCloud(vertices=tensor, colors=colors)

if is_notebook():
s = trimesh.Scene(geometry=pc)
Expand Down
10 changes: 10 additions & 0 deletions tests/integrations/predefined_document/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,13 @@ class MyDoc(BaseDocument):

assert doc.mesh1.url == 'http://hello.ply'
assert doc.mesh2.url == 'http://hello.ply'


def test_display_illegal_param():
mesh = Mesh3D(url='http://myurl.ply')
with pytest.raises(ValueError):
mesh.display(display_from='tensor')

mesh = Mesh3D(vertices=np.zeros((10, 3)), faces=np.ones(10, 3))
with pytest.raises(ValueError):
mesh.display(display_from='url')
34 changes: 22 additions & 12 deletions tests/integrations/predefined_document/test_point_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ def test_point_cloud(file_url):


def test_point_cloud_np():
pc = parse_obj_as(PointCloud3D, np.zeros((10, 10, 3)))
assert (pc.tensor == np.zeros((10, 10, 3))).all()
pc = parse_obj_as(PointCloud3D, np.zeros((10, 3)))
assert (pc.tensor == np.zeros((10, 3))).all()


def test_point_cloud_torch():
pc = parse_obj_as(PointCloud3D, torch.zeros(10, 10, 3))
assert (pc.tensor == torch.zeros(10, 10, 3)).all()
pc = parse_obj_as(PointCloud3D, torch.zeros(10, 3))
assert (pc.tensor == torch.zeros(10, 3)).all()


@pytest.mark.tensorflow
def test_point_cloud_tensorflow():
pc = parse_obj_as(PointCloud3D, tf.zeros((10, 10, 3)))
assert tnp.allclose(pc.tensor.tensor, tf.zeros((10, 10, 3)))
pc = parse_obj_as(PointCloud3D, tf.zeros((10, 3)))
assert tnp.allclose(pc.tensor.tensor, tf.zeros((10, 3)))


def test_point_cloud_shortcut_doc():
Expand All @@ -53,12 +53,12 @@ class MyDoc(BaseDocument):

doc = MyDoc(
pc='http://myurl.ply',
pc2=np.zeros((10, 10, 3)),
pc3=torch.zeros(10, 10, 3),
pc2=np.zeros((10, 3)),
pc3=torch.zeros(10, 3),
)
assert doc.pc.url == 'http://myurl.ply'
assert (doc.pc2.tensor == np.zeros((10, 10, 3))).all()
assert (doc.pc3.tensor == torch.zeros(10, 10, 3)).all()
assert (doc.pc2.tensor == np.zeros((10, 3))).all()
assert (doc.pc3.tensor == torch.zeros(10, 3)).all()


@pytest.mark.tensorflow
Expand All @@ -69,7 +69,17 @@ class MyDoc(BaseDocument):

doc = MyDoc(
pc='http://myurl.ply',
pc2=tf.zeros((10, 10, 3)),
pc2=tf.zeros((10, 3)),
)
assert doc.pc.url == 'http://myurl.ply'
assert tnp.allclose(doc.pc2.tensor.tensor, tf.zeros((10, 10, 3)))
assert tnp.allclose(doc.pc2.tensor.tensor, tf.zeros((10, 3)))


def test_display_illegal_param():
pc = PointCloud3D(url='http://myurl.ply')
with pytest.raises(ValueError):
pc.display(display_from='tensor')

pc = PointCloud3D(tensor=np.zeros((10, 3)))
with pytest.raises(ValueError):
pc.display(display_from='url')