from __future__ import annotations
import contextlib
import datetime
import json
import logging
import uuid
from dataclasses import dataclass
from importlib.resources import read_text
from typing import Any
from typing import AsyncIterator
import redis
from redis import asyncio as aioredis
from pyncette.errors import PyncetteException
from pyncette.model import ContinuationToken
from pyncette.model import Lease
from pyncette.model import PollResponse
from pyncette.model import QueryResponse
from pyncette.model import ResultType
from pyncette.repository import Repository
from pyncette.task import Task
logger = logging.getLogger(__name__)
_CONTINUATION_TOKEN = ContinuationToken(object())
class _LuaScript:
"""A wrapper for Redis lua scripts that automaticaly reloads it if e.g. SCRIPT FLUSH is invoked"""
_script: str
_sha: str | None
def __init__(self, script_path: str):
self._script = read_text(__name__, script_path)
self._sha = None
async def register(self, client: aioredis.Redis) -> None:
self._sha = await client.script_load(self._script)
async def execute(
self,
client: aioredis.Redis,
keys: list[Any] | None = None,
args: list[Any] | None = None,
) -> Any:
if self._sha is None:
await self.register(client)
keys = keys or []
args = args or []
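        # EVALSHA fails with NOSCRIPT if the server no longer has the script cached
        # (e.g. after SCRIPT FLUSH or a failover), so re-register and retry a few times.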
for _ in range(3):
try:
return await client.evalsha(self._sha, len(keys), *keys, *args)
except redis.exceptions.NoScriptError:
logger.warning("We seem to have lost the LUA script, reloading...")
await self.register(client)
raise PyncetteException("Could not reload the Lua script.")
@dataclass
class _ManageScriptResponse:
result: ResultType
version: int
execute_after: datetime.datetime | None
locked_until: datetime.datetime | None
task_spec: dict[str, Any] | None
locked_by: str | None
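    # The Lua scripts return their result as a flat array in this order:
    # [result, version, execute_after, locked_until, locked_by, task_spec]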
@classmethod
def from_response(cls, response: list[bytes]) -> _ManageScriptResponse:
return cls(
result=ResultType[response[0].decode()],
version=int(response[1] or 0),
execute_after=None if response[2] is None else datetime.datetime.fromisoformat(response[2].decode()),
locked_until=None if response[3] is None else datetime.datetime.fromisoformat(response[3].decode()),
locked_by=None if response[4] is None else response[4].decode(),
task_spec=None if response[5] is None else json.loads(response[5]),
)
def _create_dynamic_task(task: Task, response_data: list[bytes]) -> tuple[Task, Lease]:
task_data = _ManageScriptResponse.from_response(response_data)
assert task_data.task_spec is not None
return (task.instantiate_from_spec(task_data.task_spec), Lease(task_data))
class RedisRepository(Repository):
"""Redis-backed store for Pyncete task execution data"""
_redis_client: aioredis.Redis
_namespace: str
_manage_script: _LuaScript
_poll_dynamic_script: _LuaScript
def __init__(self, redis_client: aioredis.Redis, **kwargs: Any):
self._redis_client = redis_client
self._namespace = kwargs.get("redis_namespace", "")
self._batch_size = kwargs.get("batch_size", 100)
self._poll_dynamic_script = _LuaScript("poll_dynamic.lua")
self._manage_script = _LuaScript("manage.lua")
if self._batch_size < 1:
raise ValueError("Batch size must be greater than 0")
    async def register_scripts(self) -> None:
"""Registers the Lua scripts used by the implementation ahead of time"""
await self._poll_dynamic_script.register(self._redis_client)
await self._manage_script.register(self._redis_client)
    async def poll_dynamic_task(
self,
utc_now: datetime.datetime,
task: Task,
continuation_token: ContinuationToken | None = None,
) -> QueryResponse:
new_locked_until = utc_now + task.lease_duration
response = await self._poll_dynamic_script.execute(
self._redis_client,
keys=[self._get_task_index_key(task)],
args=[
utc_now.isoformat(),
self._batch_size,
new_locked_until.isoformat(),
str(uuid.uuid4()),
],
)
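        # response[0] indicates whether more due tasks remain (b"HAS_MORE");
        # the remaining entries are the locked dynamic task records.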
logger.debug(f"query_lua script returned [{self._batch_size}] {response}")
return QueryResponse(
tasks=[_create_dynamic_task(task, response_data) for response_data in response[1:]],
continuation_token=_CONTINUATION_TOKEN if response[0] == b"HAS_MORE" else None,
)
    async def register_task(self, utc_now: datetime.datetime, task: Task) -> None:
execute_after = task.get_next_execution(utc_now, None)
assert execute_after is not None
await self._manage_record(
task,
"REGISTER",
execute_after.isoformat(),
json.dumps(task.as_spec()),
)
    async def unregister_task(self, utc_now: datetime.datetime, task: Task) -> None:
await self._manage_record(task, "UNREGISTER")
    async def poll_task(self, utc_now: datetime.datetime, task: Task, lease: Lease | None = None) -> PollResponse:
# Nominally, we need at least two round-trips to Redis since the next execute_after is calculated
# in Python code due to extra flexibility. This is why we have optimistic locking below to ensure that
# the next execution time was calculated using a correct base if another process modified it in between.
# In most cases, however, we can assume that the base time has not changed since the last invocation,
# so by caching it, we can poll a task using a single round-trip (if we are wrong, the loop below will still
# ensure correctness as the version will not match).
last_lease = getattr(task, "_last_lease", None)
if isinstance(lease, _ManageScriptResponse):
version, execute_after, locked_by = (
lease.version,
lease.execute_after,
lease.locked_by,
)
elif last_lease is not None:
logger.debug("Using cached values for execute_after")
version, execute_after, locked_by = (
last_lease.version,
last_lease.execute_after,
str(uuid.uuid4()),
)
else:
# By default we assume that the task is brand new
version, execute_after, locked_by = (
0,
None,
str(uuid.uuid4()),
)
new_locked_until = utc_now + task.lease_duration
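        # Bounded optimistic-locking loop: on LEASE_MISMATCH we refresh version and
        # execute_after from the response and recompute the next execution before retrying.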
for _ in range(5):
next_execution = task.get_next_execution(utc_now, execute_after)
response = await self._manage_record(
task,
"POLL",
task.execution_mode.name,
"REGULAR" if task.parent_task is None else "DYNAMIC",
utc_now.isoformat(),
version,
next_execution.isoformat() if next_execution is not None else "",
new_locked_until.isoformat(),
locked_by,
)
task._last_lease = response # type: ignore
if response.result == ResultType.LEASE_MISMATCH:
logger.debug("Lease mismatch, retrying.")
execute_after = response.execute_after
version = response.version
elif response.result == ResultType.MISSING:
raise PyncetteException("Task not found")
else:
return PollResponse(
result=response.result,
scheduled_at=execute_after,
lease=Lease(response),
)
raise PyncetteException("Unable to acquire the lock on the task due to contention")
    async def commit_task(self, utc_now: datetime.datetime, task: Task, lease: Lease) -> None:
assert isinstance(lease, _ManageScriptResponse)
next_execution = task.get_next_execution(utc_now, lease.execute_after)
response = await self._manage_record(
task,
"COMMIT",
lease.version,
lease.locked_by,
next_execution.isoformat() if next_execution is not None else "",
)
task._last_lease = response # type: ignore
if response.result == ResultType.LEASE_MISMATCH:
logger.info("Not commiting, as we have lost the lease")
    async def unlock_task(self, utc_now: datetime.datetime, task: Task, lease: Lease) -> None:
assert isinstance(lease, _ManageScriptResponse)
response = await self._manage_record(task, "UNLOCK", lease.version, lease.locked_by)
task._last_lease = response # type: ignore
if response.result == ResultType.LEASE_MISMATCH:
logger.info("Not unlocking, as we have lost the lease")
    async def extend_lease(self, utc_now: datetime.datetime, task: Task, lease: Lease) -> Lease | None:
assert isinstance(lease, _ManageScriptResponse)
new_locked_until = utc_now + task.lease_duration
response = await self._manage_record(task, "EXTEND", lease.version, lease.locked_by, new_locked_until.isoformat())
task._last_lease = response # type: ignore
if response.result == ResultType.READY:
return Lease(response)
else:
return None
async def _manage_record(self, task: Task, *args: Any) -> _ManageScriptResponse:
response = await self._manage_script.execute(
self._redis_client,
keys=[
self._get_task_record_key(task),
self._get_task_index_key(task.parent_task),
],
args=list(args),
)
logger.debug(f"manage_lua script returned {response}")
return _ManageScriptResponse.from_response(response)
def _get_task_record_key(self, task: Task) -> str:
return f"pyncette:{self._namespace}:task:{task.canonical_name}"
def _get_task_index_key(self, task: Task | None) -> str:
# A prefix-coded index key, so there are no restrictions on task names.
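        # For example, with a hypothetical namespace "prod" and a parent task named "emails",
        # this yields "pyncette:prod:index:emails"; with no parent task, "pyncette:prod:index".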
index_name = f"index:{task.canonical_name}" if task else "index"
return f"pyncette:{self._namespace}:{index_name}"
@contextlib.asynccontextmanager
async def redis_repository(**kwargs: Any) -> AsyncIterator[RedisRepository]:
"""Factory context manager for Redis repository that initializes the connection to Redis"""
if not isinstance(kwargs["redis_url"], str):
raise PyncetteException("Redis URL is required")
async with aioredis.from_url(kwargs["redis_url"]) as redis_pool:
repository = RedisRepository(redis_pool, **kwargs)
await repository.register_scripts()
yield repository
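
# A minimal usage sketch (assumes a Redis server reachable at "redis://localhost"):
#
#     import asyncio
#
#     async def main() -> None:
#         async with redis_repository(redis_url="redis://localhost") as repository:
#             ...  # use `repository`, e.g. hand it to the Pyncette scheduler
#
#     asyncio.run(main())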