use_state through RequestHandlerRunResult
Pijukatel committed Jan 3, 2025
1 parent 2408d85 commit d345259
Showing 8 changed files with 45 additions and 96 deletions.
2 changes: 1 addition & 1 deletion docs/examples/code/beautifulsoup_crawler.py
@@ -25,7 +25,7 @@ async def main() -> None:
@crawler.router.default_handler
async def request_handler(context: BeautifulSoupCrawlingContext) -> None:
context.log.info(f'Processing {context.request.url} ...')
-        await context.use_state({"asd":"sad"})
+        await context.use_state({'asd':'sad'})
# Extract data from the page.
data = {
'url': context.request.url,
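For orientation, a minimal end-to-end sketch of the handler-side pattern this commit routes through RequestHandlerRunResult. The URL, the 'pages_seen' key, and the crawler options are illustrative assumptions, not part of the commit:

```python
import asyncio

from crawlee.crawlers import BeautifulSoupCrawler, BeautifulSoupCrawlingContext


async def main() -> None:
    crawler = BeautifulSoupCrawler(max_requests_per_crawl=5)

    @crawler.router.default_handler
    async def request_handler(context: BeautifulSoupCrawlingContext) -> None:
        # use_state returns a mutable dict handle; the default applies only
        # when no state has been stored yet under the CRAWLEE_STATE key.
        state = await context.use_state({'pages_seen': 0})
        state['pages_seen'] += 1  # in-place mutation, persisted after the handler returns

    await crawler.run(['https://crawlee.dev'])


if __name__ == '__main__':
    asyncio.run(main())
```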
36 changes: 34 additions & 2 deletions src/crawlee/_types.py
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Iterator, Mapping
+from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Protocol, TypeVar, Union, cast, overload
@@ -402,12 +403,17 @@ def __call__(

class RequestHandlerRunResult:
"""Record of calls to storage-related context helpers."""
+    CRAWLEE_STATE_KEY = 'CRAWLEE_STATE'

def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None:
self._key_value_store_getter = key_value_store_getter
self.add_requests_calls = list[AddRequestsKwargs]()
self.push_data_calls = list[PushDataFunctionCall]()
self.key_value_store_changes = dict[tuple[Optional[str], Optional[str]], KeyValueStoreChangeRecords]()
+        # This is a handle to the dict that is available to the user. If it gets mutated, the mutation needs to be reflected in the change records.
+        self._use_state_user: None | dict[str, JsonSerializable] = None
+        # The last state known to this RequestHandlerRunResult. Used to detect mutations made by the user.
+        self._last_known_use_state: None | dict[str, JsonSerializable] = None

async def add_requests(
self,
@@ -452,5 +458,31 @@ async def get_key_value_store(
return self.key_value_store_changes[id, name]


-    async def use_state(self):
+    # TODO: Somehow make crawlers add to the KVS through this. Currently they do it directly.
+    async def use_state(self, default_value: dict[str, JsonSerializable] | None = None) -> dict[str, JsonSerializable]:
+        # Find out whether the value is already present in the change records or the underlying key-value store.
+        _default: dict[str, JsonSerializable] = default_value or {}
+        default_kvs_changes = await self.get_key_value_store()
+
+        use_state: dict[str, JsonSerializable] = await default_kvs_changes.get_value(self.CRAWLEE_STATE_KEY, _default)
+
+        if use_state is _default:
+            # Set the default value if there is no value in the change records or the actual KVS.
+            await default_kvs_changes.set_value(self.CRAWLEE_STATE_KEY, _default)
+
+        # This is the same dict that is available to the user and can be mutated at any point.
+        self._use_state_user = use_state
+        # This copy is not available to the user and must not be changed.
+        self._last_known_use_state = deepcopy(self._use_state_user)
+
+        return use_state
+
+    async def update_mutated_use_state(self) -> None:
+        """Update use_state if it was mutated by the user."""
+        if self._use_state_user != self._last_known_use_state:
+            default_kvs_changes = await self.get_key_value_store()
+            await default_kvs_changes.set_value(self.CRAWLEE_STATE_KEY, self._use_state_user)
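The snapshot-and-compare contract above is easy to demonstrate in isolation. The following is a standalone toy sketch, not crawlee's API: it hands out a dict, keeps a private deepcopy, and persists on commit only if the user mutated the handle. The deep copy is what lets nested in-place mutations be detected; a shallow copy would share the nested objects and miss them.

```python
from copy import deepcopy


class StateRecorder:
    """Toy illustration of the snapshot-and-compare pattern used above."""

    def __init__(self) -> None:
        self._user_handle: dict | None = None
        self._snapshot: dict | None = None
        self.persisted: dict | None = None  # stands in for the key-value store

    def use_state(self, default: dict | None = None) -> dict:
        if self._user_handle is None:
            self._user_handle = default if default is not None else {}
            # A shallow copy would share nested objects and miss their mutation.
            self._snapshot = deepcopy(self._user_handle)
        return self._user_handle

    def commit(self) -> None:
        # Persist only if the handle diverged from the snapshot.
        if self._user_handle != self._snapshot:
            self.persisted = deepcopy(self._user_handle)


recorder = StateRecorder()
state = recorder.use_state({'urls': []})
state['urls'].append('https://example.com')  # nested in-place mutation
recorder.commit()
assert recorder.persisted == {'urls': ['https://example.com']}
```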
5 changes: 3 additions & 2 deletions src/crawlee/crawlers/_adaptive_playwright/__init__.py
@@ -1,6 +1,7 @@
from crawlee.crawlers._adaptive_playwright._adaptive_playwright_crawler import AdaptivePlaywrightCrawler
-from crawlee.crawlers._adaptive_playwright._adaptive_playwright_crawling_context import \
-    AdaptivePlaywrightCrawlingContext
+from crawlee.crawlers._adaptive_playwright._adaptive_playwright_crawling_context import (
+    AdaptivePlaywrightCrawlingContext,
+)

__all__ = [
'AdaptivePlaywrightCrawler',
@@ -16,7 +16,6 @@
BeautifulSoupCrawler,
BeautifulSoupCrawlingContext,
BeautifulSoupParserType,
-    ContextPipeline,
PlaywrightCrawler,
PlaywrightCrawlingContext,
PlaywrightPreNavCrawlingContext,
@@ -115,9 +114,6 @@ def __init__(self,
playwright_crawler_args: PlaywrightCrawler only kwargs that are passed to the sub crawler.
kwargs: Additional keyword arguments to pass to the underlying `BasicCrawler`.
"""
-
-
-
# Some sub crawler kwargs are internally modified. Prepare copies.
bs_kwargs = deepcopy(kwargs)
pw_kwargs = deepcopy(kwargs)
@@ -193,7 +189,6 @@ async def run(
purge_request_queue: If this is `True` and the crawler is not being run for the first time, the default
request queue will be purged.
"""
-
# TODO: Create something more robust that does not leak implementation so much
async with (self.beautifulsoup_crawler.statistics, self.playwright_crawler.statistics,
self.playwright_crawler._additional_context_managers[0]):
@@ -249,6 +244,8 @@ async def _run_subcrawler(crawler: BeautifulSoupCrawler | PlaywrightCrawler,

context.log.debug(f'Running browser request handler for {context.request.url}')

+
+    # This might not be needed if KVS access is properly routed through results and we commit the PW result at the end of the function.
kvs = await context.get_key_value_store()
default_value = dict[str, JsonSerializable]()
old_state: dict[str, JsonSerializable] = await kvs.get_value(BasicCrawler.CRAWLEE_STATE_KEY, default_value)
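The hunk is truncated here, but the surrounding comment suggests the previously persisted state is read back and carried into the sub-crawler's run. A hedged sketch of what such a carry-forward amounts to, with plain dicts standing in for the real context and store API:

```python
# Hypothetical carry-forward of persisted state into a fresh use_state dict.
old_state = {'pages_seen': 3}  # value read back from the key-value store
new_state: dict = {}           # dict a fresh RequestHandlerRunResult would hand out
new_state.update(old_state)    # make prior state visible to the sub-crawler
new_state['pages_seen'] = 4    # the sub-crawler handler mutates it as usual
```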
4 changes: 2 additions & 2 deletions src/crawlee/crawlers/_adaptive_playwright/_example.py
@@ -18,7 +18,7 @@ async def main() ->None:

crawler = AdaptivePlaywrightCrawler(max_requests_per_crawl=10,
_logger=top_logger,
-    playwright_crawler_args={"headless":False})
+    playwright_crawler_args={'headless':False})

@crawler.router.default_handler
async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:
@@ -27,7 +27,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:
context.log.info(f'Processing with Top adaptive_crawler: {context.request.url} ...')
await context.enqueue_links()
await context.push_data({'Top crawler Url': context.request.url})
-        await context.use_state({"bla":i})
+        await context.use_state({'bla':i})

@crawler.pre_navigation_hook_bs
async def bs_hook(context: BasicCrawlingContext) -> None:
9 changes: 2 additions & 7 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -574,12 +574,6 @@ async def add_requests(
wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout,
)

-    async def _use_state(
-        self, default_value: dict[str, JsonSerializable] | None = None
-    ) -> dict[str, JsonSerializable]:
-        store = await self.get_key_value_store()
-        return await store.get_auto_saved_value(BasicCrawler.CRAWLEE_STATE_KEY, default_value)
-
async def _save_crawler_state(self) -> None:
store = await self.get_key_value_store()
await store.persist_autosaved_values()
@@ -951,6 +945,7 @@ async def _commit_request_handler_result(


async def _commit_key_value_store_changes(self, result: RequestHandlerRunResult) -> None:
+        await result.update_mutated_use_state()
for (id, name), changes in result.key_value_store_changes.items():
store = await self.get_key_value_store(id=id, name=name)
for key, value in changes.updates.items():
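Ordering matters in this helper: folding the mutated state into the change records must happen before the flush loop above, otherwise an in-place mutation would never reach the store. A hedged sketch of the fold-then-flush sequence, with plain dicts standing in for crawlee's change records and stores:

```python
from copy import deepcopy

changes: dict = {}          # stands in for result.key_value_store_changes
state = {'count': 0}        # the handle returned by result.use_state
snapshot = deepcopy(state)  # the result's private last-known copy

state['count'] += 1         # the handler mutated the state in place

# Step 1: update_mutated_use_state() folds the mutation into the change records.
if state != snapshot:
    changes['CRAWLEE_STATE'] = state

# Step 2: the flush loop writes every recorded change to the real store.
store: dict = {}
for key, value in changes.items():
    store[key] = value

assert store == {'CRAWLEE_STATE': {'count': 1}}
```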
@@ -1011,7 +1006,7 @@ async def __run_task_function(self) -> None:
add_requests=result.add_requests,
push_data=result.push_data,
get_key_value_store=result.get_key_value_store,
-            use_state=self._use_state,
+            use_state=result.use_state,
log=self._logger,
)

31 changes: 0 additions & 31 deletions src/crawlee/storages/_key_value_store.py
@@ -182,37 +182,6 @@ async def get_public_url(self, key: str) -> str:
"""
return await self._resource_client.get_public_url(key)

-    async def get_auto_saved_value(
-        self,
-        key: str,
-        default_value: dict[str, JsonSerializable] | None = None,
-    ) -> dict[str, JsonSerializable]:
-        """Gets a value from KVS that will be automatically saved on changes.
-
-        Args:
-            key: Key of the record, to store the value.
-            default_value: Value to be used if the record does not exist yet. Should be a dictionary.
-
-        Returns:
-            Returns the value of the key.
-        """
-        default_value = {} if default_value is None else default_value
-
-        if key in self._cache:
-            return self._cache[key]
-
-        value = await self.get_value(key, default_value)
-
-        if not isinstance(value, dict):
-            raise TypeError(
-                f'Expected dictionary for persist state value at key "{key}", but got {type(value).__name__}'
-            )
-
-        self._cache[key] = value
-
-        self._ensure_persist_event()
-
-        return value
-
@property
def _cache(self) -> dict[str, dict[str, JsonSerializable]]:
47 changes: 1 addition & 46 deletions tests/unit/storages/test_key_value_store.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
-from datetime import datetime, timedelta, timezone
+from datetime import timedelta
from typing import TYPE_CHECKING
from unittest.mock import patch
from urllib.parse import urlparse
@@ -14,7 +14,6 @@
if TYPE_CHECKING:
from collections.abc import AsyncGenerator

-    from crawlee._types import JsonSerializable


@pytest.fixture
@@ -134,47 +133,3 @@ async def test_get_public_url(key_value_store: KeyValueStore) -> None:
with open(path) as f: # noqa: ASYNC230
content = await asyncio.to_thread(f.read)
assert content == 'static'
-
-
-async def test_get_auto_saved_value_default_value(key_value_store: KeyValueStore) -> None:
-    default_value: dict[str, JsonSerializable] = {'hello': 'world'}
-    value = await key_value_store.get_auto_saved_value('state', default_value)
-    assert value == default_value
-
-
-async def test_get_auto_saved_value_cache_value(key_value_store: KeyValueStore) -> None:
-    default_value: dict[str, JsonSerializable] = {'hello': 'world'}
-    key_name = 'state'
-
-    value = await key_value_store.get_auto_saved_value(key_name, default_value)
-    value['hello'] = 'new_world'
-    value_one = await key_value_store.get_auto_saved_value(key_name)
-    assert value_one == {'hello': 'new_world'}
-
-    value_one['hello'] = ['new_world']
-    value_two = await key_value_store.get_auto_saved_value(key_name)
-    assert value_two == {'hello': ['new_world']}
-
-
-async def test_get_auto_saved_value_auto_save(key_value_store: KeyValueStore, mock_event_manager: EventManager) -> None:  # noqa: ARG001
-    # This is not a realtime system and timing constraints can be hard to enforce.
-    # For the test to avoid flakiness it needs some time tolerance.
-    autosave_deadline_time = 1
-    autosave_check_period = 0.01
-
-    async def autosaved_within_deadline(key: str, expected_value: dict[str, str]) -> bool:
-        """Check if the `key_value_store` of `key` has expected value within `autosave_deadline_time` seconds."""
-        deadline = datetime.now(tz=timezone.utc) + timedelta(seconds=autosave_deadline_time)
-        while datetime.now(tz=timezone.utc) < deadline:
-            await asyncio.sleep(autosave_check_period)
-            if await key_value_store.get_value(key) == expected_value:
-                return True
-        return False
-
-    default_value: dict[str, JsonSerializable] = {'hello': 'world'}
-    key_name = 'state'
-    value = await key_value_store.get_auto_saved_value(key_name, default_value)
-    assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'world'})
-
-    value['hello'] = 'new_world'
-    assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'new_world'})
