|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from collections.abc import Iterator, Mapping |
| 4 | +from copy import deepcopy |
4 | 5 | from dataclasses import dataclass |
5 | 6 | from enum import Enum |
6 | 7 | from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Protocol, TypeVar, Union, cast, overload |
@@ -402,12 +403,17 @@ def __call__( |
402 | 403 |
|
403 | 404 | class RequestHandlerRunResult: |
404 | 405 | """Record of calls to storage-related context helpers.""" |
| 406 | + CRAWLEE_STATE_KEY = 'CRAWLEE_STATE' |
405 | 407 |
|
def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None:
    """Start with empty call records and no tracked ``use_state`` dict.

    Args:
        key_value_store_getter: Kept on the instance as ``_key_value_store_getter``.
    """
    self._key_value_store_getter = key_value_store_getter

    # Recorded calls and key-value store change records; all start empty.
    self.add_requests_calls: list[AddRequestsKwargs] = []
    self.push_data_calls: list[PushDataFunctionCall] = []
    self.key_value_store_changes: dict[tuple[Optional[str], Optional[str]], KeyValueStoreChangeRecords] = {}

    # Handle to the dict that is handed out to the user. The user may mutate it at any time, and such
    # mutations have to be reflected in the recorded changes.
    self._use_state_user: dict[str, JsonSerializable] | None = None
    # Private snapshot of the last state seen by this object; compared against the user's dict to
    # detect mutations.
    self._last_known_use_state: dict[str, JsonSerializable] | None = None
411 | 417 |
|
412 | 418 | async def add_requests( |
413 | 419 | self, |
@@ -452,5 +458,31 @@ async def get_key_value_store( |
452 | 458 | return self.key_value_store_changes[id, name] |
453 | 459 |
|
454 | 460 |
|
455 | | - async def use_state(self): |
456 | | - # TODO: Somehow make crawlers add to kvs through this. Currently it does it directly |
async def use_state(self, default_value: dict[str, JsonSerializable] | None = None) -> dict[str, JsonSerializable]:
    """Return the state dict stored under ``CRAWLEE_STATE_KEY`` and start tracking it for mutations.

    The returned dict is the live object kept in ``_use_state_user``; a deep copy of it is kept in
    ``_last_known_use_state`` so that `update_mutated_use_state` can detect later changes.

    Args:
        default_value: Dict to store and return when no state is found. When omitted (``None``), a new
            empty dict is used. An explicitly passed dict, including an empty one, is used as is.

    Returns:
        The state dict obtained from the default key-value store change records, or the default.
    """
    # Only substitute a fresh dict when the caller gave nothing; `default_value or {}` would silently
    # replace a caller-supplied empty dict with a different object.
    _default: dict[str, JsonSerializable] = {} if default_value is None else default_value
    default_kvs_changes = await self.get_key_value_store()

    # Find out whether a value is already present under the state key.
    use_state: dict[str, JsonSerializable] = await default_kvs_changes.get_value(self.CRAWLEE_STATE_KEY, _default)

    if use_state is _default:
        # Getting the very same default object back is taken to mean that nothing is stored yet,
        # so the default is recorded as the initial state.
        await default_kvs_changes.set_value(self.CRAWLEE_STATE_KEY, _default)

    # This is the same dict that is available to the user and can be mutated at any point.
    self._use_state_user = use_state
    # This copy is not available to the user and should not be changed.
    self._last_known_use_state = deepcopy(self._use_state_user)

    return use_state
| 478 | + |
async def update_mutated_use_state(self) -> None:
    """Persist the user's state dict if it differs from the last known snapshot.

    Compares ``_use_state_user`` with ``_last_known_use_state``. When they differ, the current dict is
    written under ``CRAWLEE_STATE_KEY`` to the default key-value store change records and the snapshot
    is refreshed. When they are equal (including when both are still ``None``), nothing is written.
    """
    if self._use_state_user == self._last_known_use_state:
        return

    default_kvs_changes = await self.get_key_value_store()
    await default_kvs_changes.set_value(self.CRAWLEE_STATE_KEY, self._use_state_user)
    # Refresh the snapshot so the next comparison is made against what was just written. Without this,
    # a later mutation that returns the dict to its originally observed value would go undetected and
    # the stale intermediate value would stay recorded.
    self._last_known_use_state = deepcopy(self._use_state_user)
| 484 | + |
| 485 | + |
| 486 | + |
| 487 | + |
| 488 | + |
0 commit comments