Skip to content

Commit a33bc53

Browse files
committed
feat: add storage helpers to crawler & context
Closes #100, closes #172
1 parent f48c806 commit a33bc53

15 files changed

Lines changed: 487 additions & 213 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
- Browser rotation with a maximum number of pages opened per browser.
1212
- Add emitting of a persist-state event to the event manager
1313
- Add batched request addition in `RequestQueue`
14+
- Add start requests option to `BasicCrawler`
15+
- Add storage-related helpers `get_data`, `push_data` and `export_to` to `BasicCrawler` and `BasicCrawlingContext`
1416

1517
## [0.0.4](../../releases/tag/v0.0.4) - 2024-05-30
1618

src/crawlee/basic_crawler/basic_crawler.py

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import httpx
1313
from tldextract import TLDExtract
14-
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
14+
from typing_extensions import NotRequired, TypedDict, TypeVar, Unpack, assert_never
1515

1616
from crawlee import Glob
1717
from crawlee._utils.wait import wait_for
@@ -32,10 +32,10 @@
3232
from crawlee.enqueue_strategy import EnqueueStrategy
3333
from crawlee.events.local_event_manager import LocalEventManager
3434
from crawlee.http_clients.httpx_client import HttpxClient
35-
from crawlee.models import BaseRequestData, Request, RequestState
35+
from crawlee.models import BaseRequestData, DatasetItemsListPage, Request, RequestState
3636
from crawlee.sessions import SessionPool
3737
from crawlee.statistics.statistics import Statistics
38-
from crawlee.storages import RequestQueue
38+
from crawlee.storages import Dataset, KeyValueStore, RequestQueue
3939

4040
if TYPE_CHECKING:
4141
import re
@@ -44,6 +44,7 @@
4444
from crawlee.proxy_configuration import ProxyConfiguration, ProxyInfo
4545
from crawlee.sessions.session import Session
4646
from crawlee.statistics.models import FinalStatistics, StatisticsState
47+
from crawlee.storages.dataset import ExportToKwargs, GetDataKwargs, PushDataKwargs
4748
from crawlee.storages.request_provider import RequestProvider
4849

4950
TCrawlingContext = TypeVar('TCrawlingContext', bound=BasicCrawlingContext, default=BasicCrawlingContext)
@@ -86,6 +87,7 @@ class BasicCrawler(Generic[TCrawlingContext]):
8687

8788
def __init__(
8889
self,
90+
start_requests: Sequence[str | BaseRequestData | Request] | None = None,
8991
*,
9092
request_provider: RequestProvider | None = None,
9193
request_handler: Callable[[TCrawlingContext], Awaitable[None]] | None = None,
@@ -106,6 +108,7 @@ def __init__(
106108
"""Initialize the BasicCrawler.
107109
108110
Args:
111+
start_requests: A list of URLs to start crawling from
109112
request_provider: Provides requests to be processed
110113
request_handler: A callable to which request handling is delegated
111114
http_client: HTTP client to be used for `BasicCrawlingContext.send_request` and HTTP-only crawling.
@@ -126,6 +129,7 @@ def __init__(
126129
This parameter is meant to be used by child classes, not when BasicCrawler is instantiated directly.
127130
_additional_context_managers: Additional context managers to be used in the crawler lifecycle.
128131
"""
132+
self._start_requests = start_requests or []
129133
self._router: Router[TCrawlingContext] | None = None
130134

131135
if isinstance(cast(Router, request_handler), Router):
@@ -227,13 +231,39 @@ async def _get_proxy_info(self, request: Request, session: Session | None) -> Pr
227231
proxy_tier=None,
228232
)
229233

230-
async def get_request_provider(self) -> RequestProvider:
234+
async def get_request_provider(
235+
self,
236+
*,
237+
id: str | None = None,
238+
name: str | None = None,
239+
configuration: Configuration | None = None,
240+
) -> RequestProvider:
231241
"""Return the configured request provider. If none is configured, open and return the default request queue."""
232242
if not self._request_provider:
233-
self._request_provider = await RequestQueue.open()
243+
self._request_provider = await RequestQueue.open(id=id, name=name, configuration=configuration)
234244

235245
return self._request_provider
236246

247+
async def get_dataset(
248+
self,
249+
*,
250+
id: str | None = None,
251+
name: str | None = None,
252+
configuration: Configuration | None = None,
253+
) -> Dataset:
254+
"""Return the dataset with the given ID or name. If none is provided, return the default dataset."""
255+
return await Dataset.open(id=id, name=name, configuration=configuration)
256+
257+
async def get_key_value_store(
258+
self,
259+
*,
260+
id: str | None = None,
261+
name: str | None = None,
262+
configuration: Configuration | None = None,
263+
) -> KeyValueStore:
264+
"""Return the key-value store with the given ID or name. If none is provided, return the default KVS."""
265+
return await KeyValueStore.open(id=id, name=name, configuration=configuration)
266+
237267
def error_handler(self, handler: ErrorHandler[TCrawlingContext]) -> ErrorHandler[TCrawlingContext]:
238268
"""Decorator for configuring an error handler (called after a request handler error and before retrying)."""
239269
self._error_handler = handler
@@ -246,7 +276,7 @@ def failed_request_handler(
246276
self._failed_request_handler = handler
247277
return handler
248278

249-
async def run(self, requests: list[str | BaseRequestData] | None = None) -> FinalStatistics:
279+
async def run(self, requests: Sequence[str | BaseRequestData | Request] | None = None) -> FinalStatistics:
250280
"""Run the crawler until all requests are processed."""
251281
if self._running:
252282
raise RuntimeError(
@@ -261,6 +291,8 @@ async def run(self, requests: list[str | BaseRequestData] | None = None) -> Fina
261291
if self._use_session_pool:
262292
await self._session_pool.reset_store()
263293

294+
await self.add_requests(self._start_requests)
295+
264296
if requests is not None:
265297
await self.add_requests(requests)
266298

@@ -286,12 +318,13 @@ async def run(self, requests: list[str | BaseRequestData] | None = None) -> Fina
286318

287319
self._running = False
288320
self._has_finished_before = True
321+
self._start_requests = [] # Clear the start requests to prevent them from being added again
289322

290323
return self._statistics.calculate()
291324

292325
async def add_requests(
293326
self,
294-
requests: Sequence[BaseRequestData | Request | str],
327+
requests: Sequence[str | BaseRequestData | Request],
295328
*,
296329
batch_size: int = 1000,
297330
wait_time_between_batches: timedelta = timedelta(0),
@@ -317,6 +350,73 @@ async def add_requests(
317350
wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout,
318351
)
319352

353+
async def push_data(
354+
self,
355+
dataset_id: str | None = None,
356+
dataset_name: str | None = None,
357+
configuration: Configuration | None = None,
358+
**kwargs: Unpack[PushDataKwargs],
359+
) -> None:
360+
"""Push data to a dataset.
361+
362+
This helper method simplifies the process of pushing data to a dataset. It opens the specified
363+
dataset and then pushes the provided data to it.
364+
365+
Args:
366+
data: The data to push to the dataset.
367+
dataset_id: The ID of the dataset.
368+
dataset_name: The name of the dataset.
369+
configuration: The configuration settings for accessing the dataset.
370+
kwargs: Keyword arguments to be passed to the dataset's `push_data` method.
371+
"""
372+
dataset = await Dataset.open(id=dataset_id, name=dataset_name, configuration=configuration)
373+
await dataset.push_data(**kwargs)
374+
375+
async def get_data(
376+
self,
377+
dataset_id: str | None = None,
378+
dataset_name: str | None = None,
379+
configuration: Configuration | None = None,
380+
**kwargs: Unpack[GetDataKwargs],
381+
) -> DatasetItemsListPage:
382+
"""Retrieve data from a dataset.
383+
384+
This helper method simplifies the process of retrieving data from a dataset. It opens the specified
385+
dataset and then retrieves the data based on the provided parameters.
386+
387+
Args:
388+
dataset_id: The ID of the dataset.
389+
dataset_name: The name of the dataset.
390+
configuration: The configuration settings for accessing the dataset.
391+
kwargs: Keyword arguments to be passed to the dataset's `get_data` method.
392+
393+
Returns:
394+
The retrieved data.
395+
"""
396+
dataset = await Dataset.open(id=dataset_id, name=dataset_name, configuration=configuration)
397+
return await dataset.get_data(**kwargs)
398+
399+
async def export_to(
400+
self,
401+
dataset_id: str | None = None,
402+
dataset_name: str | None = None,
403+
configuration: Configuration | None = None,
404+
**kwargs: Unpack[ExportToKwargs],
405+
) -> None:
406+
"""Export data from a dataset.
407+
408+
This helper method simplifies the process of exporting data from a dataset. It opens the specified
409+
dataset and then exports the data based on the provided parameters.
410+
411+
Args:
412+
dataset_id: The ID of the dataset.
413+
dataset_name: The name of the dataset.
414+
configuration: The configuration settings for accessing the dataset.
415+
kwargs: Keyword arguments to be passed to the dataset's `export_to` method.
416+
"""
417+
dataset = await Dataset.open(id=dataset_id, name=dataset_name, configuration=configuration)
418+
return await dataset.export_to(**kwargs)
419+
320420
def _should_retry_request(self, crawling_context: BasicCrawlingContext, error: Exception) -> bool:
321421
if crawling_context.request.no_retry:
322422
return False
@@ -517,6 +617,9 @@ async def __run_task_function(self) -> None:
517617
proxy_info=proxy_info,
518618
send_request=self._prepare_send_request_function(session, proxy_info),
519619
add_requests=result.add_requests,
620+
get_data=self.get_data,
621+
push_data=self.push_data,
622+
export_to=self.export_to,
520623
)
521624

522625
statistics_id = request.id or request.unique_key

src/crawlee/basic_crawler/types.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010

1111
if TYPE_CHECKING:
1212
from crawlee import Glob
13+
from crawlee.configuration import Configuration
1314
from crawlee.enqueue_strategy import EnqueueStrategy
1415
from crawlee.http_clients.base_http_client import HttpResponse
15-
from crawlee.models import BaseRequestData, Request
16+
from crawlee.models import BaseRequestData, DatasetItemsListPage, Request
1617
from crawlee.proxy_configuration import ProxyInfo
1718
from crawlee.sessions.session import Session
19+
from crawlee.storages.dataset import ExportToKwargs, GetDataKwargs, PushDataKwargs
1820

1921

2022
class AddRequestsFunctionKwargs(TypedDict):
@@ -28,10 +30,64 @@ class AddRequestsFunctionKwargs(TypedDict):
2830

2931

3032
class AddRequestsFunction(Protocol):
31-
"""Type of a function for adding URLs to the request queue with optional filtering."""
33+
"""Type of a function for adding URLs to the request queue with optional filtering.
34+
35+
This helper method simplifies the process of adding requests to the request provider. It opens the specified
36+
request provider and adds the requests to it.
37+
"""
3238

3339
def __call__( # noqa: D102
34-
self, requests: Sequence[str | BaseRequestData], **kwargs: Unpack[AddRequestsFunctionKwargs]
40+
self,
41+
requests: Sequence[str | BaseRequestData | Request],
42+
**kwargs: Unpack[AddRequestsFunctionKwargs],
43+
) -> Coroutine[None, None, None]: ...
44+
45+
46+
class GetDataFunction(Protocol):
47+
"""Type of a function for getting data from the dataset.
48+
49+
This helper method simplifies the process of retrieving data from a dataset. It opens the specified
50+
dataset and then retrieves the data based on the provided parameters.
51+
"""
52+
53+
def __call__( # noqa: D102
54+
self,
55+
dataset_id: str | None = None,
56+
dataset_name: str | None = None,
57+
configuration: Configuration | None = None,
58+
**kwargs: Unpack[GetDataKwargs],
59+
) -> Coroutine[None, None, DatasetItemsListPage]: ...
60+
61+
62+
class PushDataFunction(Protocol):
63+
"""Type of a function for pushing data to the dataset.
64+
65+
This helper method simplifies the process of pushing data to a dataset. It opens the specified
66+
dataset and then pushes the provided data to it.
67+
"""
68+
69+
def __call__( # noqa: D102
70+
self,
71+
dataset_id: str | None = None,
72+
dataset_name: str | None = None,
73+
configuration: Configuration | None = None,
74+
**kwargs: Unpack[PushDataKwargs],
75+
) -> Coroutine[None, None, None]: ...
76+
77+
78+
class ExportToFunction(Protocol):
79+
"""Type of a function for exporting data from a dataset.
80+
81+
This helper method simplifies the process of exporting data from a dataset. It opens the specified
82+
dataset and then exports its content to the key-value store.
83+
"""
84+
85+
def __call__( # noqa: D102
86+
self,
87+
dataset_id: str | None = None,
88+
dataset_name: str | None = None,
89+
configuration: Configuration | None = None,
90+
**kwargs: Unpack[ExportToKwargs],
3591
) -> Coroutine[None, None, None]: ...
3692

3793

@@ -69,12 +125,15 @@ class BasicCrawlingContext:
69125
proxy_info: ProxyInfo | None
70126
send_request: SendRequestFunction
71127
add_requests: AddRequestsFunction
128+
get_data: GetDataFunction
129+
push_data: PushDataFunction
130+
export_to: ExportToFunction
72131

73132

74133
class AddRequestsFunctionCall(AddRequestsFunctionKwargs):
75134
"""Record of a call to `add_requests`."""
76135

77-
requests: Sequence[str | BaseRequestData]
136+
requests: Sequence[str | BaseRequestData | Request]
78137

79138

80139
@dataclass()
@@ -84,7 +143,9 @@ class RequestHandlerRunResult:
84143
add_requests_calls: list[AddRequestsFunctionCall] = field(default_factory=list)
85144

86145
async def add_requests(
87-
self, requests: Sequence[str | BaseRequestData], **kwargs: Unpack[AddRequestsFunctionKwargs]
146+
self,
147+
requests: Sequence[str | BaseRequestData],
148+
**kwargs: Unpack[AddRequestsFunctionKwargs],
88149
) -> None:
89150
"""Track a call to the `add_requests` context helper."""
90151
self.add_requests_calls.append(AddRequestsFunctionCall(requests=requests, **kwargs))

src/crawlee/beautifulsoup_crawler/beautifulsoup_crawler.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,11 @@ async def _make_http_request(self, context: BasicCrawlingContext) -> AsyncGenera
7474
request=context.request,
7575
session=context.session,
7676
proxy_info=context.proxy_info,
77-
send_request=context.send_request,
7877
add_requests=context.add_requests,
78+
send_request=context.send_request,
79+
get_data=context.get_data,
80+
push_data=context.push_data,
81+
export_to=context.export_to,
7982
http_response=result.http_response,
8083
)
8184

@@ -134,9 +137,12 @@ async def enqueue_links(
134137
request=context.request,
135138
session=context.session,
136139
proxy_info=context.proxy_info,
137-
send_request=context.send_request,
138-
add_requests=context.add_requests,
139140
enqueue_links=enqueue_links,
141+
add_requests=context.add_requests,
142+
send_request=context.send_request,
143+
get_data=context.get_data,
144+
push_data=context.push_data,
145+
export_to=context.export_to,
140146
http_response=context.http_response,
141147
soup=soup,
142148
)

0 commit comments

Comments
 (0)