from __future__ import annotations

import contextlib
import copy
import datetime
import json
import logging
import uuid
from dataclasses import dataclass
from typing import Any
from typing import AsyncIterator
from typing import cast

import aioboto3
from boto3.dynamodb.conditions import Attr
from boto3.dynamodb.conditions import Key
from botocore.exceptions import ClientError

from pyncette.errors import PyncetteException
from pyncette.model import ContinuationToken
from pyncette.model import ExecutionMode
from pyncette.model import Lease
from pyncette.model import PollResponse
from pyncette.model import QueryResponse
from pyncette.model import ResultType
from pyncette.repository import Repository
from pyncette.task import Task

logger = logging.getLogger(__name__)


@dataclass
class _TaskRecord:
    """In-memory view of a single task row stored in the DynamoDB table.

    Instances double as the lease objects handed out by the repository, so
    they carry both the schedule (``execute_after``) and the lock state
    (``locked_until`` / ``locked_by``) together with the optimistic-locking
    ``version`` counter.
    """

    # Next scheduled execution; the repository deletes the row when this is None.
    execute_after: datetime.datetime | None
    # Expiry of the current lock, or None when the task is not locked.
    locked_until: datetime.datetime | None
    # Opaque identifier of the lock holder, or None when the task is not locked.
    locked_by: str | None
    # Optimistic-locking counter; 0 for a row that has never been updated.
    version: int

    @staticmethod
    def _parse_date(s: str | None) -> datetime.datetime | None:
        """Parse an ISO-8601 string; ``None`` and ``""`` both map to ``None``."""
        return datetime.datetime.fromisoformat(s) if s else None

    @classmethod
    def from_dynamo_response(cls, response: Any) -> _TaskRecord | None:
        """Build a record from a ``get_item`` response.

        Returns ``None`` when the response carries no ``"Item"`` key, i.e. the
        row does not exist.
        """
        if "Item" not in response:
            return None

        return cls.from_dynamo_item(response["Item"])

    @classmethod
    def from_dynamo_item(cls, record: dict[str, Any]) -> _TaskRecord:
        """Build a record from a raw DynamoDB item dictionary.

        The attribute names mirror the ones the repository writes
        (``execute_after``, ``locked_until``, ``locked_by``, ``version``).
        """
        return cls(
            execute_after=cls._parse_date(record.get("execute_after")),
            locked_until=cls._parse_date(record.get("locked_until")),
            locked_by=record.get("locked_by"),
            # NOTE(review): boto3 hands numeric attributes back as Decimal, so
            # normalize to int; a missing attribute is treated as version 0,
            # matching the "not_exists or eq(0)" write condition.
            version=int(record.get("version", 0)),
        )

class DynamoDBRepository(Repository):
    """DynamoDB-backed store for Pyncette task execution data.

    Every task is a single item keyed by ``partition_id`` / ``task_id``.
    Writes are guarded by a ``version`` attribute (optimistic locking): an
    update only succeeds if the stored version still matches the one the
    caller last saw, otherwise the current state is reloaded and the
    operation retried.
    """

    _dynamo_resource: Any
    _table: Any
    _table_name: str
    _batch_size: int
    _skip_table_create: bool
    _partition_prefix: str

    def __init__(
        self,
        dynamo_resource: Any,
        skip_table_create: bool,
        partition_prefix: str,
        **kwargs: Any,
    ):
        """Store configuration; no I/O happens until :meth:`initialize`.

        Recognized ``kwargs``: ``dynamodb_table_name`` (default ``""``) and
        ``batch_size`` (default ``100``).

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        self._dynamo_resource = dynamo_resource
        self._table_name = kwargs.get("dynamodb_table_name", "")
        self._batch_size = kwargs.get("batch_size", 100)
        self._skip_table_create = skip_table_create
        self._partition_prefix = partition_prefix

        if self._batch_size < 1:
            raise ValueError("Batch size must be greater than 0")

    async def poll_dynamic_task(
        self,
        utc_now: datetime.datetime,
        task: Task,
        continuation_token: ContinuationToken | None = None,
    ) -> QueryResponse:
        """Return up to ``batch_size`` instances of a dynamic task that are ready.

        Queries the ``ready_at`` index of the task's partition. The returned
        ``continuation_token`` is DynamoDB's ``LastEvaluatedKey`` (or ``None``)
        and can be passed back in to fetch the next page.
        """
        pagination: dict[str, Any] = {}
        if continuation_token is not None:
            pagination["ExclusiveStartKey"] = continuation_token

        # ready_at values are written as "<iso timestamp>_<task name>". "`" is
        # the character right after "_" in ASCII, so "<now>`" is an upper bound
        # that still includes every entry whose timestamp part equals <now>.
        ready_before = f"{utc_now.isoformat()}`"

        response = await self._table.query(
            IndexName="ready_at",
            Select="ALL_ATTRIBUTES",
            Limit=self._batch_size,
            KeyConditionExpression=Key("partition_id").eq(self._get_partition_id(task)) & Key("ready_at").lt(ready_before),
            **pagination,
        )

        return QueryResponse(
            tasks=[
                (
                    task.instantiate_from_spec(json.loads(record["task_spec"])),
                    Lease(_TaskRecord.from_dynamo_item(record)),
                )
                for record in response["Items"]
            ],
            continuation_token=response.get("LastEvaluatedKey", None),
        )

    async def register_task(self, utc_now: datetime.datetime, task: Task) -> None:
        """Create (or overwrite) the item for a task instance, unlocked, at version 0."""
        execute_after = task.get_next_execution(utc_now, None)
        assert execute_after is not None

        await self._table.put_item(
            Item={
                **self._task_key(task),
                "task_spec": json.dumps(task.as_spec()),
                "version": 0,
                "locked_by": None,
                "locked_until": None,
                "execute_after": execute_after.isoformat(),
                "ready_at": f"{execute_after.isoformat()}_{task.canonical_name}",
            },
        )

    async def unregister_task(self, utc_now: datetime.datetime, task: Task) -> None:
        """Unconditionally delete the item for a task instance."""
        await self._table.delete_item(Key=self._task_key(task))

    def _get_partition_id(self, task: Task) -> str:
        """Compute the partition key for a task.

        Instances of a dynamic task share the partition of their parent task;
        any ":" in the configured prefix is doubled before it is joined to the
        task name with ":".
        """
        prefix = (self._partition_prefix or "").replace(":", "::")
        owner = task.parent_task if task.parent_task else task
        return f"{prefix}:{owner.canonical_name}"

    def _task_key(self, task: Task) -> dict[str, str]:
        """Primary key (partition + sort key) identifying the task's item."""
        return {
            "partition_id": self._get_partition_id(task),
            "task_id": task.canonical_name,
        }

    @staticmethod
    def _version_condition(current_version: int) -> Any:
        """Condition expression asserting the stored version is ``current_version``.

        Version 0 also matches an item without a ``version`` attribute, which
        covers the case where the item does not exist yet.
        """
        if current_version == 0:
            return Attr("version").not_exists() | Attr("version").eq(0)
        return Attr("version").eq(current_version)

    async def _retreive_item(self, task: Task, consistent_read: bool = False) -> _TaskRecord | None:
        """Load the task's current record, or ``None`` if the item does not exist."""
        response = await self._table.get_item(
            Key=self._task_key(task),
            ConsistentRead=consistent_read,
        )
        return _TaskRecord.from_dynamo_response(response)

    async def _update_item(self, task: Task, record: _TaskRecord) -> bool:
        """Persist ``record`` if the stored version still equals ``record.version``.

        A record with ``execute_after`` of ``None`` is deleted, anything else
        is written with the version bumped by one. On success ``record.version``
        is updated in place and the record is cached on ``task._last_lease``.

        Returns:
            ``True`` on success, ``False`` if the conditional check failed
            (somebody else modified the item in the meantime).

        Raises:
            ClientError: for any DynamoDB error other than a failed condition.
        """
        current_version = record.version

        try:
            if record.execute_after is None:
                await self._table.delete_item(
                    Key=self._task_key(task),
                    ConditionExpression=self._version_condition(current_version),
                )
                task._last_lease = None  # type: ignore
            else:
                # While locked, the task must not show up in the ready_at index
                # before the lock expires.
                if record.locked_until is not None:
                    ready_at = max(record.execute_after, record.locked_until)
                else:
                    ready_at = record.execute_after

                await self._table.update_item(
                    Key=self._task_key(task),
                    UpdateExpression=(
                        "set execute_after=:execute_after, locked_until=:locked_until, "
                        "locked_by=:locked_by, ready_at=:ready_at, version=:version"
                    ),
                    ExpressionAttributeValues={
                        ":execute_after": record.execute_after.isoformat(),
                        ":locked_until": record.locked_until.isoformat() if record.locked_until is not None else "",
                        ":locked_by": record.locked_by,
                        # Add a suffix to ready_at to guarantee uniqueness of the secondary index
                        ":ready_at": f"{ready_at.isoformat()}_{task.canonical_name}",
                        ":version": current_version + 1,
                    },
                    ConditionExpression=self._version_condition(current_version),
                )
                record.version = current_version + 1
                task._last_lease = record  # type: ignore
        except ClientError as e:
            if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
                return False
            raise

        return True

    async def poll_task(self, utc_now: datetime.datetime, task: Task, lease: Lease | None = None) -> PollResponse:
        """Determine whether ``task`` is ready, pending or locked, locking it if needed.

        Raises:
            PyncetteException: if a dynamic task instance has no stored item,
                or if the update keeps conflicting after all retries.
        """
        # NOTE(review): MAX_OPTIMISTIC_RETRY_COUNT is expected to be a
        # module-level constant; its definition is not visible in this part of
        # the file -- confirm it exists.
        last_lease = getattr(task, "_last_lease", None)

        # Similar logic as in Redis repository. If we have previously processed this
        # task in any manner, we try to reuse the latest state of the task we have at hand
        # from cache (or lease) to avoid two roundtrips to DynamoDB in the optimistic case.
        # If we are wrong, we will get a version mismatch, whereby we will load the current
        # state from DB.
        # However, in case the task would be PENDING or LOCKED, this will result in no requests
        # being made to the DB at all. For pending this is OK, but for locked, we want to revalidate
        # in order to be able to execute the task as soon as it is unlocked.
        record: _TaskRecord | None
        potentially_stale = False

        if lease is not None:
            assert isinstance(lease, _TaskRecord)
            record = cast(_TaskRecord, copy.copy(lease))
        elif last_lease is not None:
            logger.debug("Using cached values for last lease")
            assert isinstance(last_lease, _TaskRecord)
            record = cast(_TaskRecord, copy.copy(last_lease))
            potentially_stale = True
        else:
            record = await self._retreive_item(task)

        for _ in range(MAX_OPTIMISTIC_RETRY_COUNT):
            must_revalidate = False
            update = False

            if record is None:
                # Only top-level tasks are created implicitly on first poll.
                if task.parent_task is not None:
                    raise PyncetteException("Task not found")

                record = _TaskRecord(
                    execute_after=task.get_next_execution(utc_now, None),
                    locked_until=None,
                    locked_by=None,
                    version=0,
                )
                update = True

            assert record.execute_after is not None
            scheduled_at = record.execute_after
            is_due = record.execute_after <= utc_now

            if record.locked_until is not None and record.locked_until > utc_now and (lease is None or lease.locked_by != record.locked_by):
                result = ResultType.LOCKED
                if potentially_stale:
                    must_revalidate = True
            elif is_due and task.execution_mode == ExecutionMode.AT_MOST_ONCE:
                # Advance the schedule up front; no lock is taken.
                result = ResultType.READY
                record.execute_after = task.get_next_execution(utc_now, record.execute_after)
                record.locked_until = None
                record.locked_by = None
                update = True
            elif is_due and task.execution_mode == ExecutionMode.AT_LEAST_ONCE:
                # Take a lock; the schedule is advanced on commit.
                result = ResultType.READY
                record.locked_until = utc_now + task.lease_duration
                record.locked_by = str(uuid.uuid4())
                update = True
            else:
                result = ResultType.PENDING

            if must_revalidate or (update and not await self._update_item(task, record)):
                logger.debug("Cached state is stale or update was rejected; reloading task from DynamoDB")
                record = await self._retreive_item(task, consistent_read=True)
                potentially_stale = False
                continue

            return PollResponse(result=result, scheduled_at=scheduled_at, lease=Lease(record))

        raise PyncetteException("Unable to acquire the lock on the task due to contention")

    async def commit_task(self, utc_now: datetime.datetime, task: Task, lease: Lease) -> None:
        """Release the lock held by ``lease`` and advance the task's schedule.

        Silently returns if the item is gone or the lock is now held by
        somebody else.

        Raises:
            PyncetteException: if the update keeps conflicting after all retries.
        """
        assert isinstance(lease, _TaskRecord)
        record: _TaskRecord | None = cast(_TaskRecord, copy.copy(lease))

        for _ in range(MAX_OPTIMISTIC_RETRY_COUNT):
            if not record:
                logger.warning("Task %s no longer exists, skipping.", task)
                return

            if record.locked_by != lease.locked_by:
                logger.warning("Lease lost on task %s, skipping.", task)
                return

            record.execute_after = task.get_next_execution(utc_now, lease.execute_after)
            record.locked_by = None
            record.locked_until = None

            if await self._update_item(task, record):
                return

            # If the update fails due to version mismatch
            record = await self._retreive_item(task, consistent_read=True)

        raise PyncetteException("Unable to acquire the lock on the task due to contention")

    async def unlock_task(self, utc_now: datetime.datetime, task: Task, lease: Lease) -> None:
        """Release the lock held by ``lease`` without changing the schedule.

        Silently returns if the item is gone or the lock is now held by
        somebody else.

        Raises:
            PyncetteException: if the update keeps conflicting after all retries.
        """
        assert isinstance(lease, _TaskRecord)
        record: _TaskRecord | None = cast(_TaskRecord, copy.copy(lease))

        for _ in range(MAX_OPTIMISTIC_RETRY_COUNT):
            if not record:
                logger.warning("Task %s no longer exists, skipping.", task)
                return

            if record.locked_by != lease.locked_by:
                logger.warning("Lease lost on task %s, skipping.", task)
                return

            record.locked_by = None
            record.locked_until = None

            if await self._update_item(task, record):
                return

            # If the update fails due to version mismatch
            record = await self._retreive_item(task, consistent_read=True)

        raise PyncetteException("Unable to acquire the lock on the task due to contention")

    async def extend_lease(self, utc_now: datetime.datetime, task: Task, lease: Lease) -> Lease | None:
        """Push the lock expiry to ``utc_now + task.lease_duration``.

        Returns:
            The refreshed lease, or ``None`` if the item is gone or the lock is
            now held by somebody else.

        Raises:
            PyncetteException: if the update keeps conflicting after all retries.
        """
        assert isinstance(lease, _TaskRecord)
        record: _TaskRecord | None = cast(_TaskRecord, copy.copy(lease))

        for _ in range(MAX_OPTIMISTIC_RETRY_COUNT):
            if not record:
                logger.warning("Task %s no longer exists, skipping.", task)
                return None

            if record.locked_by != lease.locked_by:
                logger.warning("Lease lost on task %s, skipping.", task)
                return None

            record.locked_until = utc_now + task.lease_duration

            if await self._update_item(task, record):
                return Lease(record)

            # If the update fails due to version mismatch
            record = await self._retreive_item(task, consistent_read=True)

        raise PyncetteException("Unable to acquire the lock on the task due to contention")

    async def initialize(self) -> None:
        """Bind to the table, creating it first unless table creation is disabled.

        Raises:
            ClientError: if loading the table fails for any reason other than
                the table not existing.
        """
        self._table = await self._dynamo_resource.Table(self._table_name)

        if self._skip_table_create:
            return

        try:
            await self._table.load()
            logger.info("Table already exists...")
        except ClientError as e:
            if e.response["Error"]["Code"] != "ResourceNotFoundException":
                raise

            self._table = await self._dynamo_resource.create_table(
                TableName=self._table_name,
                KeySchema=[
                    {"AttributeName": "partition_id", "KeyType": "HASH"},
                    {"AttributeName": "task_id", "KeyType": "RANGE"},
                ],
                AttributeDefinitions=[
                    {"AttributeName": "partition_id", "AttributeType": "S"},
                    {"AttributeName": "task_id", "AttributeType": "S"},
                    {"AttributeName": "ready_at", "AttributeType": "S"},
                ],
                LocalSecondaryIndexes=[
                    {
                        "IndexName": "ready_at",
                        "KeySchema": [
                            {
                                "AttributeName": "partition_id",
                                "KeyType": "HASH",
                            },
                            {"AttributeName": "ready_at", "KeyType": "RANGE"},
                        ],
                        "Projection": {"ProjectionType": "ALL"},
                    }
                ],
                BillingMode="PAY_PER_REQUEST",
            )
            logger.info("Waiting until table is created...")
            await self._table.wait_until_exists()
@contextlib.asynccontextmanager
async def dynamodb_repository(
    *,
    dynamodb_endpoint: str | None = None,
    dynamodb_region_name: str | None = None,
    dynamodb_skip_table_create: bool = False,
    dynamodb_partition_prefix: str = "",
    **kwargs: Any,
) -> AsyncIterator[DynamoDBRepository]:
    """Factory context manager for the DynamoDB repository.

    Opens an aioboto3 DynamoDB resource, builds a :class:`DynamoDBRepository`
    on top of it, initializes it and yields it. The resource is closed when
    the context exits.

    Args:
        dynamodb_endpoint: passed to aioboto3 as ``endpoint_url``.
        dynamodb_region_name: passed to aioboto3 as ``region_name``.
        dynamodb_skip_table_create: forwarded as ``skip_table_create``.
        dynamodb_partition_prefix: forwarded as ``partition_prefix``.
        **kwargs: forwarded unchanged to :class:`DynamoDBRepository`.
    """
    session = aioboto3.Session()

    async with session.resource(
        "dynamodb",
        region_name=dynamodb_region_name,
        endpoint_url=dynamodb_endpoint,
    ) as dynamo_resource:
        repository = DynamoDBRepository(
            dynamo_resource,
            skip_table_create=dynamodb_skip_table_create,
            partition_prefix=dynamodb_partition_prefix,
            **kwargs,
        )
        await repository.initialize()
        yield repository