Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 48 additions & 37 deletions src/crawlee/autoscaling/autoscaled_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
if TYPE_CHECKING:
from crawlee.autoscaling import SystemStatus

__all__ = ['ConcurrencySettings', 'AutoscaledPool']

logger = getLogger(__name__)


Expand Down Expand Up @@ -63,6 +65,16 @@ def __init__(
self.max_tasks_per_minute = max_tasks_per_minute


class _AutoscaledPoolRun:
    """Mutable state shared by the components of a single `AutoscaledPool.run` invocation."""

    def __init__(self) -> None:
        # Worker tasks that are currently in progress for this run.
        self.worker_tasks: list[asyncio.Task] = []

        # Set whenever the set of worker tasks changes, to wake the orchestrator.
        self.worker_tasks_updated = asyncio.Event()

        # Set once post-run cleanup has completed.
        self.cleanup_done = asyncio.Event()

        # Resolves when the run finishes; setting an exception aborts the run.
        self.result: asyncio.Future = asyncio.Future()


class AutoscaledPool:
"""Manages a pool of asynchronous resource-intensive tasks that are executed in parallel.

Expand Down Expand Up @@ -131,13 +143,6 @@ def __init__(

self._autoscale_task = RecurringTask(self._autoscale, autoscale_interval)

self._worker_tasks = list[asyncio.Task]()
"""A list of worker tasks currently in progress"""

self._worker_tasks_updated = asyncio.Event()
self._cleanup_done = asyncio.Event()
self._run_result: asyncio.Future = asyncio.Future()

if desired_concurrency_ratio < 0 or desired_concurrency_ratio > 1:
raise ValueError('desired_concurrency_ratio must be between 0 and 1 (non-inclusive)')

Expand All @@ -154,32 +159,33 @@ def __init__(

self._max_tasks_per_minute = concurrency_settings.max_tasks_per_minute
self._is_paused = False
self._is_running = False
self._current_run: _AutoscaledPoolRun | None = None

async def run(self) -> None:
"""Start the autoscaled pool and return when all tasks are completed and `is_finished_function` returns True.

If there is an exception in one of the tasks, it will be re-raised.
"""
if self._is_running:
if self._current_run is not None:
raise RuntimeError('The pool is already running')

self._is_running = True
self._cleanup_done.clear()
run = _AutoscaledPoolRun()
self._current_run = run

logger.debug('Starting the pool')

self._autoscale_task.start()
self._log_system_status_task.start()

orchestrator = asyncio.create_task(
self._worker_task_orchestrator(), name='autoscaled pool worker task orchestrator'
self._worker_task_orchestrator(run), name='autoscaled pool worker task orchestrator'
)

try:
await self._run_result
await run.result
except AbortError:
orchestrator.cancel()
for task in self._worker_tasks:
for task in run.worker_tasks:
if not task.done():
task.cancel()
finally:
Expand All @@ -195,21 +201,23 @@ async def run(self) -> None:

logger.info('Waiting for remaining tasks to finish')

for task in self._worker_tasks:
for task in run.worker_tasks:
if not task.done():
with suppress(BaseException):
await task

self._run_result = asyncio.Future()
self._cleanup_done.set()
self._is_running = False
run.cleanup_done.set()
self._current_run = None

logger.debug('Pool cleanup finished')

async def abort(self) -> None:
"""Interrupt the autoscaled pool and all the tasks in progress."""
self._run_result.set_exception(AbortError())
await self._cleanup_done.wait()
if not self._current_run:
raise RuntimeError('The pool is not running')

self._current_run.result.set_exception(AbortError())
await self._current_run.cleanup_done.wait()

def pause(self) -> None:
"""Pause the autoscaled pool so that it does not start new tasks."""
Expand All @@ -227,7 +235,10 @@ def desired_concurrency(self) -> int:
@property
def current_concurrency(self) -> int:
"""The number of concurrent tasks in progress."""
return len(self._worker_tasks)
if self._current_run is None:
return 0

return len(self._current_run.worker_tasks)

def _autoscale(self) -> None:
"""Inspect system load status and adjust desired concurrency if necessary. Do not call directly."""
Expand Down Expand Up @@ -258,16 +269,16 @@ def _log_system_status(self) -> None:
f'{system_status!s}'
)

async def _worker_task_orchestrator(self) -> None:
async def _worker_task_orchestrator(self, run: _AutoscaledPoolRun) -> None:
"""Launches worker tasks whenever there is free capacity and a task is ready.

Exits when `is_finished_function` returns True.
"""
finished = False

try:
while not (finished := await self._is_finished_function()) and not self._run_result.done():
self._worker_tasks_updated.clear()
while not (finished := await self._is_finished_function()) and not run.result.done():
run.worker_tasks_updated.clear()

current_status = self._system_status.get_current_system_info()
if not current_status.is_system_idle:
Expand All @@ -281,44 +292,44 @@ async def _worker_task_orchestrator(self) -> None:
else:
logger.debug('Scheduling a new task')
worker_task = asyncio.create_task(self._worker_task(), name='autoscaled pool worker task')
worker_task.add_done_callback(lambda task: self._reap_worker_task(task))
self._worker_tasks.append(worker_task)
worker_task.add_done_callback(lambda task: self._reap_worker_task(task, run))
run.worker_tasks.append(worker_task)

if math.isfinite(self._max_tasks_per_minute):
await asyncio.sleep(60 / self._max_tasks_per_minute)

continue

with suppress(asyncio.TimeoutError):
await asyncio.wait_for(self._worker_tasks_updated.wait(), timeout=0.5)
await asyncio.wait_for(run.worker_tasks_updated.wait(), timeout=0.5)
finally:
if finished:
logger.debug('`is_finished_function` reports that we are finished')
elif self._run_result.done() and self._run_result.exception() is not None:
elif run.result.done() and run.result.exception() is not None:
logger.debug('Unhandled exception in `run_task_function`')

if self._worker_tasks:
if run.worker_tasks:
logger.debug('Terminating - waiting for tasks to complete')
await asyncio.wait(self._worker_tasks, return_when=asyncio.ALL_COMPLETED)
await asyncio.wait(run.worker_tasks, return_when=asyncio.ALL_COMPLETED)
logger.debug('Worker tasks finished')
else:
logger.debug('Terminating - no running tasks to wait for')

if not self._run_result.done():
self._run_result.set_result(object())
if not run.result.done():
run.result.set_result(object())

def _reap_worker_task(self, task: asyncio.Task) -> None:
def _reap_worker_task(self, task: asyncio.Task, run: _AutoscaledPoolRun) -> None:
"""A callback for finished worker tasks.

- It interrupts the run in case of an exception,
- keeps track of tasks in progress,
- notifies the orchestrator
"""
self._worker_tasks_updated.set()
self._worker_tasks.remove(task)
run.worker_tasks_updated.set()
run.worker_tasks.remove(task)

if not task.cancelled() and (exception := task.exception()) and not self._run_result.done():
self._run_result.set_exception(exception)
if not task.cancelled() and (exception := task.exception()) and not run.result.done():
run.result.set_exception(exception)

async def _worker_task(self) -> None:
try:
Expand Down