Skip to content

Commit b2d3a52

Browse files
authored
fix: pass crawler configuration to storages (#375)
This makes sure that storages opened via `BasicCrawler.get_dataset` and the like will use the same configuration object as the crawler. While this is relevant to #351, it does not fully resolve it (see #152).
1 parent 9d28a3b commit b2d3a52

2 files changed

Lines changed: 22 additions & 5 deletions

File tree

src/crawlee/basic_crawler/basic_crawler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ async def get_request_provider(
288288
) -> RequestProvider:
289289
"""Return the configured request provider. If none is configured, open and return the default request queue."""
290290
if not self._request_provider:
291-
self._request_provider = await RequestQueue.open(id=id, name=name)
291+
self._request_provider = await RequestQueue.open(id=id, name=name, configuration=self._configuration)
292292

293293
return self._request_provider
294294

@@ -299,7 +299,7 @@ async def get_dataset(
299299
name: str | None = None,
300300
) -> Dataset:
301301
"""Return the dataset with the given ID or name. If none is provided, return the default dataset."""
302-
return await Dataset.open(id=id, name=name)
302+
return await Dataset.open(id=id, name=name, configuration=self._configuration)
303303

304304
async def get_key_value_store(
305305
self,
@@ -308,7 +308,7 @@ async def get_key_value_store(
308308
name: str | None = None,
309309
) -> KeyValueStore:
310310
"""Return the key-value store with the given ID or name. If none is provided, return the default KVS."""
311-
return await KeyValueStore.open(id=id, name=name)
311+
return await KeyValueStore.open(id=id, name=name, configuration=self._configuration)
312312

313313
def error_handler(
314314
self, handler: ErrorHandler[TCrawlingContext | BasicCrawlingContext]
@@ -468,7 +468,7 @@ async def export_data(
468468
dataset_id: The ID of the dataset.
469469
dataset_name: The name of the dataset.
470470
"""
471-
dataset = await Dataset.open(id=dataset_id, name=dataset_name)
471+
dataset = await self.get_dataset(id=dataset_id, name=dataset_name)
472472
path = path if isinstance(path, Path) else Path(path)
473473

474474
if content_type is None:
@@ -494,7 +494,7 @@ async def _push_data(
494494
dataset_name: The name of the dataset.
495495
kwargs: Keyword arguments to be passed to the dataset's `push_data` method.
496496
"""
497-
dataset = await Dataset.open(id=dataset_id, name=dataset_name)
497+
dataset = await self.get_dataset(id=dataset_id, name=dataset_name)
498498
await dataset.push_data(data, **kwargs)
499499

500500
def _should_retry_request(self, crawling_context: BasicCrawlingContext, error: Exception) -> bool:

tests/unit/basic_crawler/test_basic_crawler.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from crawlee.basic_crawler import BasicCrawler
1818
from crawlee.basic_crawler.errors import SessionError, UserDefinedErrorHandlerError
1919
from crawlee.basic_crawler.types import AddRequestsKwargs, BasicCrawlingContext
20+
from crawlee.configuration import Configuration
2021
from crawlee.enqueue_strategy import EnqueueStrategy
2122
from crawlee.models import BaseRequestData, Request
2223
from crawlee.storages import Dataset, KeyValueStore, RequestList, RequestQueue
@@ -586,3 +587,19 @@ def test_crawler_log() -> None:
586587
crawler = BasicCrawler()
587588
assert isinstance(crawler.log, logging.Logger)
588589
crawler.log.info('Test log message')
590+
591+
592+
async def test_passes_configuration_to_storages() -> None:
593+
configuration = Configuration(persist_storage=False, purge_on_start=True)
594+
595+
crawler = BasicCrawler(configuration=configuration)
596+
597+
dataset = await crawler.get_dataset()
598+
assert dataset._configuration is configuration
599+
600+
key_value_store = await crawler.get_key_value_store()
601+
assert key_value_store._configuration is configuration
602+
603+
request_provider = await crawler.get_request_provider()
604+
assert isinstance(request_provider, RequestQueue)
605+
assert request_provider._configuration is configuration

0 commit comments

Comments
 (0)