Skip to content

Commit d345259

Browse files
committed
use_state through RequestHandlerRunResult
1 parent 2408d85 commit d345259

8 files changed

Lines changed: 45 additions & 96 deletions

File tree

docs/examples/code/beautifulsoup_crawler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ async def main() -> None:
2525
@crawler.router.default_handler
2626
async def request_handler(context: BeautifulSoupCrawlingContext) -> None:
2727
context.log.info(f'Processing {context.request.url} ...')
28-
await context.use_state({"asd":"sad"})
28+
await context.use_state({'asd':'sad'})
2929
# Extract data from the page.
3030
data = {
3131
'url': context.request.url,

src/crawlee/_types.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Iterator, Mapping
4+
from copy import deepcopy
45
from dataclasses import dataclass
56
from enum import Enum
67
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Protocol, TypeVar, Union, cast, overload
@@ -402,12 +403,17 @@ def __call__(
402403

403404
class RequestHandlerRunResult:
404405
"""Record of calls to storage-related context helpers."""
406+
CRAWLEE_STATE_KEY = 'CRAWLEE_STATE'
405407

406408
def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None:
407409
self._key_value_store_getter = key_value_store_getter
408410
self.add_requests_calls = list[AddRequestsKwargs]()
409411
self.push_data_calls = list[PushDataFunctionCall]()
410412
self.key_value_store_changes = dict[tuple[Optional[str], Optional[str]], KeyValueStoreChangeRecords]()
413+
# This is a handle to the dict available to the user. If it gets mutated, the mutation needs to be reflected in the change records.
414+
self._use_state_user: None | dict[str, JsonSerializable] = None
415+
# Last state known by RequestHandlerRunResult. Used to detect mutations made by the user.
416+
self._last_known_use_state: None | dict[str, JsonSerializable] = None
411417

412418
async def add_requests(
413419
self,
@@ -452,5 +458,31 @@ async def get_key_value_store(
452458
return self.key_value_store_changes[id, name]
453459

454460

455-
async def use_state(self):
456-
# TODO: Somehow make crawlers add to kvs through this. Currently it does it directly
461+
async def use_state(self, default_value: dict[str, JsonSerializable] | None = None) -> dict[str, JsonSerializable]:
462+
# Find out whether the value is already present in the change records or the underlying key-value store.
463+
_default: dict[str, JsonSerializable] = default_value or {}
464+
default_kvs_changes = await self.get_key_value_store()
465+
466+
use_state: dict[str, JsonSerializable] = await default_kvs_changes.get_value(self.CRAWLEE_STATE_KEY, _default)
467+
468+
if use_state is _default:
469+
# Set default value if there is no value in change records or actual kvs.
470+
await default_kvs_changes.set_value(self.CRAWLEE_STATE_KEY, _default)
471+
472+
# This will be the same dict that is available to the user and can be mutated at any point.
473+
self._use_state_user = use_state
474+
# This copy is not available to the user and must not be changed.
475+
self._last_known_use_state = deepcopy(self._use_state_user)
476+
477+
return use_state
478+
479+
async def update_mutated_use_state(self) -> None:
480+
"""Update use_state if it was mutated by the user."""
481+
if self._use_state_user != self._last_known_use_state:
482+
default_kvs_changes = await self.get_key_value_store()
483+
await default_kvs_changes.set_value(self.CRAWLEE_STATE_KEY, self._use_state_user)
484+
485+
486+
487+
488+

src/crawlee/crawlers/_adaptive_playwright/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from crawlee.crawlers._adaptive_playwright._adaptive_playwright_crawler import AdaptivePlaywrightCrawler
2-
from crawlee.crawlers._adaptive_playwright._adaptive_playwright_crawling_context import \
3-
AdaptivePlaywrightCrawlingContext
2+
from crawlee.crawlers._adaptive_playwright._adaptive_playwright_crawling_context import (
3+
AdaptivePlaywrightCrawlingContext,
4+
)
45

56
__all__ = [
67
'AdaptivePlaywrightCrawler',

src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
BeautifulSoupCrawler,
1717
BeautifulSoupCrawlingContext,
1818
BeautifulSoupParserType,
19-
ContextPipeline,
2019
PlaywrightCrawler,
2120
PlaywrightCrawlingContext,
2221
PlaywrightPreNavCrawlingContext,
@@ -115,9 +114,6 @@ def __init__(self,
115114
playwright_crawler_args: PlaywrightCrawler only kwargs that are passed to the sub crawler.
116115
kwargs: Additional keyword arguments to pass to the underlying `BasicCrawler`.
117116
"""
118-
119-
120-
121117
# Some sub crawler kwargs are internally modified. Prepare copies.
122118
bs_kwargs = deepcopy(kwargs)
123119
pw_kwargs = deepcopy(kwargs)
@@ -193,7 +189,6 @@ async def run(
193189
purge_request_queue: If this is `True` and the crawler is not being run for the first time, the default
194190
request queue will be purged.
195191
"""
196-
197192
# TODO: Create something more robust that does not leak implementation so much
198193
async with (self.beautifulsoup_crawler.statistics, self.playwright_crawler.statistics,
199194
self.playwright_crawler._additional_context_managers[0]):
@@ -249,6 +244,8 @@ async def _run_subcrawler(crawler: BeautifulSoupCrawler | PlaywrightCrawler,
249244

250245
context.log.debug(f'Running browser request handler for {context.request.url}')
251246

247+
248+
# This might not be needed if KVS access is properly routed through results and the PW result is committed at the end of the function.
252249
kvs = await context.get_key_value_store()
253250
default_value =dict[str, JsonSerializable]()
254251
old_state: dict[str, JsonSerializable] = await kvs.get_value(BasicCrawler.CRAWLEE_STATE_KEY, default_value)

src/crawlee/crawlers/_adaptive_playwright/_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ async def main() ->None:
1818

1919
crawler = AdaptivePlaywrightCrawler(max_requests_per_crawl=10,
2020
_logger=top_logger,
21-
playwright_crawler_args={"headless":False})
21+
playwright_crawler_args={'headless':False})
2222

2323
@crawler.router.default_handler
2424
async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:
@@ -27,7 +27,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:
2727
context.log.info(f'Processing with Top adaptive_crawler: {context.request.url} ...')
2828
await context.enqueue_links()
2929
await context.push_data({'Top crwaler Url': context.request.url})
30-
await context.use_state({"bla":i})
30+
await context.use_state({'bla':i})
3131

3232
@crawler.pre_navigation_hook_bs
3333
async def bs_hook(context: BasicCrawlingContext) -> None:

src/crawlee/crawlers/_basic/_basic_crawler.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -574,12 +574,6 @@ async def add_requests(
574574
wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout,
575575
)
576576

577-
async def _use_state(
578-
self, default_value: dict[str, JsonSerializable] | None = None
579-
) -> dict[str, JsonSerializable]:
580-
store = await self.get_key_value_store()
581-
return await store.get_auto_saved_value(BasicCrawler.CRAWLEE_STATE_KEY, default_value)
582-
583577
async def _save_crawler_state(self) -> None:
584578
store = await self.get_key_value_store()
585579
await store.persist_autosaved_values()
@@ -951,6 +945,7 @@ async def _commit_request_handler_result(
951945

952946

953947
async def _commit_key_value_store_changes(self, result: RequestHandlerRunResult) -> None:
948+
await result.update_mutated_use_state()
954949
for (id, name), changes in result.key_value_store_changes.items():
955950
store = await self.get_key_value_store(id=id, name=name)
956951
for key, value in changes.updates.items():
@@ -1011,7 +1006,7 @@ async def __run_task_function(self) -> None:
10111006
add_requests=result.add_requests,
10121007
push_data=result.push_data,
10131008
get_key_value_store=result.get_key_value_store,
1014-
use_state=self._use_state,
1009+
use_state=result.use_state,
10151010
log=self._logger,
10161011
)
10171012

src/crawlee/storages/_key_value_store.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -182,37 +182,6 @@ async def get_public_url(self, key: str) -> str:
182182
"""
183183
return await self._resource_client.get_public_url(key)
184184

185-
async def get_auto_saved_value(
186-
self,
187-
key: str,
188-
default_value: dict[str, JsonSerializable] | None = None,
189-
) -> dict[str, JsonSerializable]:
190-
"""Gets a value from KVS that will be automatically saved on changes.
191-
192-
Args:
193-
key: Key of the record, to store the value.
194-
default_value: Value to be used if the record does not exist yet. Should be a dictionary.
195-
196-
Returns:
197-
Returns the value of the key.
198-
"""
199-
default_value = {} if default_value is None else default_value
200-
201-
if key in self._cache:
202-
return self._cache[key]
203-
204-
value = await self.get_value(key, default_value)
205-
206-
if not isinstance(value, dict):
207-
raise TypeError(
208-
f'Expected dictionary for persist state value at key "{key}, but got {type(value).__name__}'
209-
)
210-
211-
self._cache[key] = value
212-
213-
self._ensure_persist_event()
214-
215-
return value
216185

217186
@property
218187
def _cache(self) -> dict[str, dict[str, JsonSerializable]]:

tests/unit/storages/test_key_value_store.py

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from datetime import datetime, timedelta, timezone
4+
from datetime import timedelta
55
from typing import TYPE_CHECKING
66
from unittest.mock import patch
77
from urllib.parse import urlparse
@@ -14,7 +14,6 @@
1414
if TYPE_CHECKING:
1515
from collections.abc import AsyncGenerator
1616

17-
from crawlee._types import JsonSerializable
1817

1918

2019
@pytest.fixture
@@ -134,47 +133,3 @@ async def test_get_public_url(key_value_store: KeyValueStore) -> None:
134133
with open(path) as f: # noqa: ASYNC230
135134
content = await asyncio.to_thread(f.read)
136135
assert content == 'static'
137-
138-
139-
async def test_get_auto_saved_value_default_value(key_value_store: KeyValueStore) -> None:
140-
default_value: dict[str, JsonSerializable] = {'hello': 'world'}
141-
value = await key_value_store.get_auto_saved_value('state', default_value)
142-
assert value == default_value
143-
144-
145-
async def test_get_auto_saved_value_cache_value(key_value_store: KeyValueStore) -> None:
146-
default_value: dict[str, JsonSerializable] = {'hello': 'world'}
147-
key_name = 'state'
148-
149-
value = await key_value_store.get_auto_saved_value(key_name, default_value)
150-
value['hello'] = 'new_world'
151-
value_one = await key_value_store.get_auto_saved_value(key_name)
152-
assert value_one == {'hello': 'new_world'}
153-
154-
value_one['hello'] = ['new_world']
155-
value_two = await key_value_store.get_auto_saved_value(key_name)
156-
assert value_two == {'hello': ['new_world']}
157-
158-
159-
async def test_get_auto_saved_value_auto_save(key_value_store: KeyValueStore, mock_event_manager: EventManager) -> None: # noqa: ARG001
160-
# This is not a realtime system and timing constrains can be hard to enforce.
161-
# For the test to avoid flakiness it needs some time tolerance.
162-
autosave_deadline_time = 1
163-
autosave_check_period = 0.01
164-
165-
async def autosaved_within_deadline(key: str, expected_value: dict[str, str]) -> bool:
166-
"""Check if the `key_value_store` of `key` has expected value within `autosave_deadline_time` seconds."""
167-
deadline = datetime.now(tz=timezone.utc) + timedelta(seconds=autosave_deadline_time)
168-
while datetime.now(tz=timezone.utc) < deadline:
169-
await asyncio.sleep(autosave_check_period)
170-
if await key_value_store.get_value(key) == expected_value:
171-
return True
172-
return False
173-
174-
default_value: dict[str, JsonSerializable] = {'hello': 'world'}
175-
key_name = 'state'
176-
value = await key_value_store.get_auto_saved_value(key_name, default_value)
177-
assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'world'})
178-
179-
value['hello'] = 'new_world'
180-
assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'new_world'})

0 commit comments

Comments
 (0)