Skip to content

Commit 41869a3

Browse files
authored
Arrow chunk_size as keyword argument (#3084)
* Fix #2998: Arrow chunk_size as keyword argument * Adaptive chunk size logic for get_as_arrow * Run formatter * Fix missing kwarg * Fix chunk sizes for arrow tests * Provide polars users the means to customize chunk_size * test small chunk_size for polars and arrow * Revert to small test chunk sizes * Rework chunk_size defaults and conditions * Fix conditional logic * Cover 0, -1, None and fixed int params for chunk_size in tests
1 parent ae42b3b commit 41869a3

2 files changed

Lines changed: 38 additions & 15 deletions

File tree

src_py/query_result.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,20 @@ def get_as_df(self) -> pd.DataFrame:
123123

124124
return self._query_result.getAsDF()
125125

126-
def get_as_pl(self) -> pl.DataFrame:
126+
def get_as_pl(self, chunk_size: int | None = None) -> pl.DataFrame:
127127
"""
128128
Get the query result as a Polars DataFrame.
129129
130+
Parameters
131+
----------
132+
chunk_size : Number of rows to include in each chunk.
133+
None
134+
The chunk size is adaptive and depends on the number of columns in the query result.
135+
-1 or 0
136+
The entire result is returned as a single chunk.
137+
> 0
138+
The chunk size is the number of elements specified.
139+
130140
See Also
131141
--------
132142
get_as_df : Get the query result as a Pandas DataFrame.
@@ -139,20 +149,23 @@ def get_as_pl(self) -> pl.DataFrame:
139149
"""
140150
import polars as pl
141151

142-
target_n_elems = 10_000_000 # adaptive chunk_size; target 10m elements per chunk
143-
target_chunk_size = max(target_n_elems // len(self.get_column_names()), 10)
144-
return pl.from_arrow( # type: ignore[return-value]
145-
data=self.get_as_arrow(chunk_size=target_chunk_size),
146-
)
152+
self.check_for_query_result_close()
153+
154+
return pl.from_arrow(data=self.get_as_arrow(chunk_size=chunk_size))
147155

148-
def get_as_arrow(self, chunk_size: int) -> pa.Table:
156+
def get_as_arrow(self, chunk_size: int | None = None) -> pa.Table:
149157
"""
150158
Get the query result as a PyArrow Table.
151159
152160
Parameters
153161
----------
154-
chunk_size : int
155-
Number of rows to include in each chunk.
162+
chunk_size : Number of rows to include in each chunk.
163+
None
164+
The chunk size is adaptive and depends on the number of columns in the query result.
165+
-1 or 0
166+
The entire result is returned as a single chunk.
167+
> 0
168+
The chunk size is the number of elements specified.
156169
157170
See Also
158171
--------
@@ -166,7 +179,17 @@ def get_as_arrow(self, chunk_size: int) -> pa.Table:
166179
"""
167180
self.check_for_query_result_close()
168181

169-
return self._query_result.getAsArrow(chunk_size)
182+
if chunk_size is None:
183+
# Adaptive chunk_size; target number of elements per chunk_size
184+
target_chunk_size = max(1_000_000 // len(self.get_column_names()), 10)
185+
elif chunk_size <= 0:
186+
# No chunking: return the entire result as a single chunk
187+
target_chunk_size = self.get_num_tuples()
188+
else:
189+
# Chunk size is the number of elements specified
190+
target_chunk_size = chunk_size
191+
192+
return self._query_result.getAsArrow(target_chunk_size)
170193

171194
def get_column_data_types(self) -> list[str]:
172195
"""

test/test_arrow.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def _test_with_nulls(_conn: kuzu.Connection, return_type: str, chunk_size: int |
459459
_test_utf8_string(conn, "arrow", 3)
460460
_test_utf8_string(conn, "pl")
461461
_test_in_small_chunk_size(conn, "arrow", 4)
462-
_test_in_small_chunk_size(conn, "pl")
462+
_test_in_small_chunk_size(conn, "pl", 4)
463463
_test_with_nulls(conn, "arrow", 12)
464464
_test_with_nulls(conn, "pl")
465465

@@ -470,7 +470,7 @@ def test_to_arrow_complex(conn_db_readonly: ConnDB) -> None:
470470
def _test_node(_conn: kuzu.Connection) -> None:
471471
query = "MATCH (p:person) RETURN p ORDER BY p.ID"
472472
query_result = _conn.execute(query)
473-
arrow_tbl = query_result.get_as_arrow(12)
473+
arrow_tbl = query_result.get_as_arrow()
474474
p_col = arrow_tbl.column(0)
475475

476476
assert p_col.to_pylist() == [
@@ -487,7 +487,7 @@ def _test_node(_conn: kuzu.Connection) -> None:
487487
def _test_node_rel(_conn: kuzu.Connection) -> None:
488488
query = "MATCH (a:person)-[e:workAt]->(b:organisation) RETURN a, e, b;"
489489
query_result = _conn.execute(query)
490-
arrow_tbl = query_result.get_as_arrow(12)
490+
arrow_tbl = query_result.get_as_arrow(0)
491491
assert arrow_tbl.num_columns == 3
492492
a_col = arrow_tbl.column(0)
493493
assert len(a_col) == 3
@@ -528,7 +528,7 @@ def _test_node_rel(_conn: kuzu.Connection) -> None:
528528

529529
def _test_marries_table(_conn: kuzu.Connection) -> None:
530530
query = "MATCH (:person)-[e:marries]->(:person) RETURN e.*"
531-
arrow_tbl = _conn.execute(query).get_as_arrow(8)
531+
arrow_tbl = _conn.execute(query).get_as_arrow(0)
532532
assert arrow_tbl.num_columns == 3
533533

534534
used_addr_col = arrow_tbl.column(0)
@@ -553,5 +553,5 @@ def _test_marries_table(_conn: kuzu.Connection) -> None:
553553
def test_to_arrow1(conn_db_readonly: ConnDB) -> None:
554554
conn, db = conn_db_readonly
555555
query = "MATCH (a:person)-[e:knows]->(:person) RETURN e.summary"
556-
arrow_tbl = conn.execute(query).get_as_arrow(8)
556+
arrow_tbl = conn.execute(query).get_as_arrow(-1)
557557
assert arrow_tbl == []

0 commit comments

Comments
 (0)