Skip to content

Commit d50249e

Browse files
authored
fix: Make result of RequestList.is_empty independent of fetch_next_request calls (#876)
The old version of `RequestList` only updated the `_is_empty` flag on `fetch_next_request` calls, which was not enough. This PR updates it so that the `is_empty` method also tries to dequeue a request from the iterator before returning a result.
1 parent beac9fa commit d50249e

File tree

2 files changed

+83
-18
lines changed

2 files changed

+83
-18
lines changed

src/crawlee/request_loaders/_request_list.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,12 @@ def __init__(
3939
self._assumed_total_count = 0
4040

4141
self._in_progress = set[str]()
42-
self._is_empty = False
42+
self._next: Request | None = None
4343

4444
if isinstance(requests, AsyncIterable):
4545
self._requests = requests.__aiter__()
4646
elif requests is None:
4747
self._requests = self._iterate_in_threadpool([])
48-
self._is_empty = True
4948
else:
5049
self._requests = self._iterate_in_threadpool(requests)
5150

@@ -61,30 +60,27 @@ async def get_total_count(self) -> int:
6160

6261
@override
6362
async def is_empty(self) -> bool:
64-
return self._is_empty
63+
await self._ensure_next_request()
64+
return self._next is None
6565

6666
@override
6767
async def is_finished(self) -> bool:
68-
return self._is_empty and len(self._in_progress) == 0
68+
return len(self._in_progress) == 0 and await self.is_empty()
6969

7070
@override
7171
async def fetch_next_request(self) -> Request | None:
72-
if self._is_empty:
72+
await self._ensure_next_request()
73+
74+
if self._next is None:
7375
return None
7476

75-
if self._requests_lock is None:
76-
self._requests_lock = asyncio.Lock()
77+
self._in_progress.add(self._next.id)
78+
self._assumed_total_count += 1
7779

78-
try:
79-
async with self._requests_lock:
80-
request = self._transform_request(await self._requests.__anext__())
81-
except StopAsyncIteration:
82-
self._is_empty = True
83-
return None
84-
else:
85-
self._in_progress.add(request.id)
86-
self._assumed_total_count += 1
87-
return request
80+
next_request = self._next
81+
self._next = None
82+
83+
return next_request
8884

8985
@override
9086
async def mark_request_as_handled(self, request: Request) -> None:
@@ -95,6 +91,17 @@ async def mark_request_as_handled(self, request: Request) -> None:
9591
async def get_handled_count(self) -> int:
9692
return self._handled_count
9793

94+
async def _ensure_next_request(self) -> None:
95+
if self._requests_lock is None:
96+
self._requests_lock = asyncio.Lock()
97+
98+
try:
99+
async with self._requests_lock:
100+
if self._next is None:
101+
self._next = self._transform_request(await self._requests.__anext__())
102+
except StopAsyncIteration:
103+
self._next = None
104+
98105
async def _iterate_in_threadpool(self, iterable: Iterable[str | Request]) -> AsyncIterator[str | Request]:
99106
"""Inspired by a function of the same name from encode/starlette."""
100107
iterator = iter(iterable)
@@ -115,4 +122,4 @@ def _next() -> str | Request:
115122
while True:
116123
yield await asyncio.to_thread(_next)
117124
except _StopIteration:
118-
raise StopAsyncIteration # noqa: B904
125+
return
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from collections.abc import AsyncGenerator
2+
3+
from crawlee.request_loaders._request_list import RequestList
4+
5+
6+
async def test_sync_traversal() -> None:
7+
request_list = RequestList(['https://a.com', 'https://b.com', 'https://c.com'])
8+
9+
while not await request_list.is_finished():
10+
item = await request_list.fetch_next_request()
11+
assert item is not None
12+
13+
await request_list.mark_request_as_handled(item)
14+
15+
assert await request_list.is_empty()
16+
17+
18+
async def test_async_traversal() -> None:
19+
async def generator() -> AsyncGenerator[str]:
20+
yield 'https://a.com'
21+
yield 'https://b.com'
22+
yield 'https://c.com'
23+
24+
request_list = RequestList(generator())
25+
26+
while not await request_list.is_finished():
27+
item = await request_list.fetch_next_request()
28+
assert item is not None
29+
30+
await request_list.mark_request_as_handled(item)
31+
32+
assert await request_list.is_empty()
33+
34+
35+
async def test_is_empty_does_not_depend_on_fetch_next_request() -> None:
36+
request_list = RequestList(['https://a.com', 'https://b.com', 'https://c.com'])
37+
38+
item_1 = await request_list.fetch_next_request()
39+
assert item_1 is not None
40+
assert not await request_list.is_finished()
41+
42+
item_2 = await request_list.fetch_next_request()
43+
assert item_2 is not None
44+
assert not await request_list.is_finished()
45+
46+
item_3 = await request_list.fetch_next_request()
47+
assert item_3 is not None
48+
assert not await request_list.is_finished()
49+
50+
assert await request_list.is_empty()
51+
assert not await request_list.is_finished()
52+
53+
await request_list.mark_request_as_handled(item_1)
54+
await request_list.mark_request_as_handled(item_2)
55+
await request_list.mark_request_as_handled(item_3)
56+
57+
assert await request_list.is_empty()
58+
assert await request_list.is_finished()

0 commit comments

Comments
 (0)