Skip to content

Commit b418858

Browse files
authored
Feat add show label (docarray#316)
1 parent 1e83c14 commit b418858

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

docarray/array/mixins/plot.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88
from collections import Counter
99
from math import sqrt, ceil, floor
10-
from typing import Optional
10+
from typing import Optional, Tuple
1111

1212
import numpy as np
1313

@@ -317,6 +317,8 @@ def plot_image_sprites(
317317
image_source: str = 'tensor',
318318
skip_empty: bool = False,
319319
show_progress: bool = False,
320+
show_index: bool = False,
321+
fig_size: Optional[Tuple[int, int]] = None,
320322
) -> None:
321323
"""Generate a sprite image for all image tensors in this DocumentArray-like object.
322324
@@ -351,9 +353,12 @@ def plot_image_sprites(
351353
img_id = 0
352354

353355
from rich.progress import track
356+
from PIL import Image, ImageDraw
354357

355358
try:
356-
for d in track(self, description='Plotting', disable=not show_progress):
359+
for _idx, d in enumerate(
360+
track(self, description='Plotting', disable=not show_progress)
361+
):
357362

358363
if not d.uri and d.tensor is None:
359364
if skip_empty:
@@ -379,6 +384,13 @@ def plot_image_sprites(
379384

380385
row_id = floor(img_id / img_per_row)
381386
col_id = img_id % img_per_row
387+
388+
if show_index:
389+
_img = Image.fromarray(_d.tensor)
390+
draw = ImageDraw.Draw(_img)
391+
draw.text((0, 0), str(_idx), (255, 255, 255))
392+
_d.tensor = np.asarray(_img)
393+
382394
sprite_img[
383395
(row_id * img_size) : ((row_id + 1) * img_size),
384396
(col_id * img_size) : ((col_id + 1) * img_size),
@@ -392,14 +404,14 @@ def plot_image_sprites(
392404
'Bad image tensor. Try different `image_source` or `channel_axis`'
393405
) from ex
394406

395-
from PIL import Image
396-
397407
im = Image.fromarray(sprite_img)
398408

399409
if output:
400410
with open(output, 'wb') as fp:
401411
im.save(fp)
402412
else:
413+
if fig_size:
414+
plt.figure(figsize=fig_size, frameon=False)
403415
plt.gca().set_axis_off()
404416
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
405417
plt.margins(0, 0)

docarray/document/mixins/plot.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def display(self):
100100
else:
101101
self.summary()
102102

103-
plot = deprecate_by(display, removed_at='0.5')
104-
105103
def plot_matches_sprites(
106104
self,
107105
top_k: int = 10,

0 commit comments

Comments
 (0)