Skip to content

Commit c9a9b35

Browse files
authored
feat(hubble): add public parameter to da.push (docarray#318)
1 parent e9bc78c commit c9a9b35

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

docarray/array/mixins/io/pushpull.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class PushPullMixin:
5757

5858
_max_bytes = 4 * 1024 * 1024 * 1024
5959

60-
def push(self, name: str, show_progress: bool = False) -> Dict:
60+
def push(self, name: str, show_progress: bool = False, public: bool = True) -> Dict:
6161
"""Push this DocumentArray object to Jina Cloud which can be later retrieved via :meth:`.push`
6262
6363
.. note::
@@ -69,6 +69,7 @@ def push(self, name: str, show_progress: bool = False) -> Dict:
6969
7070
:param name: a name that later can be used for retrieve this :class:`DocumentArray`.
7171
:param show_progress: if to show a progress bar on pulling
72+
:param public: If True, the DocumentArray will be shared publicly. Otherwise, it will be private.
7273
"""
7374
import requests
7475

@@ -82,6 +83,7 @@ def push(self, name: str, show_progress: bool = False) -> Dict:
8283
),
8384
'name': name,
8485
'type': 'documentArray',
86+
'public': public,
8587
}
8688
)
8789

tests/unit/array/mixins/test_pushpull.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import cgi
12
import json
23
import os
34
import pytest
45
import requests
6+
from io import BytesIO
57

68
from docarray import DocumentArray
79
from docarray.array.mixins.io.pushpull import JINA_CLOUD_CONFIG
@@ -82,6 +84,27 @@ def test_push(mocker, monkeypatch):
8284
assert mock.call_count == 1
8385

8486

87+
@pytest.mark.parametrize('public', [True, False])
88+
def test_push_with_public(mocker, monkeypatch, public):
89+
mock = mocker.Mock()
90+
_mock_post(mock, monkeypatch)
91+
92+
docs = random_docs(2)
93+
docs.push(name='test_name', public=public)
94+
95+
_, mock_kwargs = mock.call_args_list[0]
96+
97+
c_type, c_data = cgi.parse_header(mock_kwargs['headers']['Content-Type'])
98+
assert c_type == 'multipart/form-data'
99+
100+
form_data = cgi.parse_multipart(
101+
BytesIO(b''.join(mock_kwargs['data'])),
102+
{'boundary': c_data['boundary'].encode()},
103+
)
104+
105+
assert form_data['public'] == [str(public)]
106+
107+
85108
def test_pull(mocker, monkeypatch):
86109
mock = mocker.Mock()
87110
_mock_get(mock, monkeypatch)
@@ -104,7 +127,7 @@ def test_push_fail(mocker, monkeypatch):
104127
_mock_post(mock, monkeypatch, status_code=requests.codes.forbidden)
105128

106129
docs = random_docs(2)
107-
with pytest.raises(Exception) as exc_info:
130+
with pytest.raises(Exception):
108131
docs.push('test_name')
109132

110133
assert mock.call_count == 1

0 commit comments

Comments
 (0)