"""Source code for pyncette.postgres."""

import contextlib
import datetime
import json
import logging
import re
import uuid
from contextlib import asynccontextmanager
from typing import Any
from typing import AsyncIterator
from typing import Optional

import asyncpg

from pyncette.errors import PyncetteException
from pyncette.model import ContinuationToken
from pyncette.model import ExecutionMode
from pyncette.model import Lease
from pyncette.model import PollResponse
from pyncette.model import QueryResponse
from pyncette.model import ResultType
from pyncette.repository import Repository
from pyncette.task import Task

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)


# Opaque sentinel continuation token. poll_dynamic_task carries no per-page
# state between calls, so a single shared marker is sufficient to signal
# "a full batch was returned; there may be more rows to fetch".
_CONTINUATION_TOKEN = ContinuationToken(object())


class PostgresRepository(Repository):
    """Pyncette task repository backed by a PostgreSQL table via asyncpg."""

    _pool: asyncpg.pool.Pool
    _batch_size: int
    _table_name: str

    def __init__(
        self,
        pool: asyncpg.pool.Pool,
        **kwargs: Any,
    ):
        table_name = kwargs.get("postgres_table_name", "pyncette_tasks")
        batch_size = kwargs.get("batch_size", 100)

        if batch_size < 1:
            raise ValueError("Batch size must be greater than 0")
        # The table name is interpolated directly into SQL statements, so it
        # is restricted to a safe character set to rule out SQL injection.
        if re.match(r"^[a-z_]+$", table_name) is None:
            raise ValueError("Table name can only contain lower-case letters and underscores")

        self._pool = pool
        self._table_name = table_name
        self._batch_size = batch_size
[docs] async def initialize(self) -> None: async with self._transaction() as connection: await connection.execute( f""" CREATE TABLE IF NOT EXISTS {self._table_name} ( name text PRIMARY KEY, parent_name text, locked_until timestamptz, locked_by uuid, execute_after timestamptz, task_spec json ); CREATE INDEX IF NOT EXISTS due_tasks_{self._table_name} ON {self._table_name} (parent_name, GREATEST(locked_until, execute_after)); """ )
[docs] async def poll_dynamic_task( self, utc_now: datetime.datetime, task: Task, continuation_token: Optional[ContinuationToken] = None, ) -> QueryResponse: async with self._transaction() as connection: locked_by = uuid.uuid4() locked_until = utc_now + task.lease_duration ready_tasks = await connection.fetch( f""" UPDATE {self._table_name} a SET locked_until = $4, locked_by = $5 FROM ( SELECT name FROM {self._table_name} WHERE parent_name = $1 AND GREATEST(locked_until, execute_after) <= $2 ORDER BY GREATEST(locked_until, execute_after) ASC LIMIT $3 FOR UPDATE SKIP LOCKED ) b WHERE a.name = b.name RETURNING * """, task.canonical_name, utc_now, self._batch_size, locked_until, locked_by, ) logger.debug(f"poll_dynamic_task returned {ready_tasks}") return QueryResponse( tasks=[ ( task.instantiate_from_spec(json.loads(record["task_spec"])), Lease(locked_by), ) for record in ready_tasks ], # May result in an extra round-trip if there were exactly # batch_size tasks available, but we deem this an acceptable # tradeoff. continuation_token=_CONTINUATION_TOKEN if len(ready_tasks) == self._batch_size else None, )
[docs] async def register_task(self, utc_now: datetime.datetime, task: Task) -> None: assert task.parent_task is not None async with self._transaction() as connection: result = await connection.execute( f""" INSERT INTO {self._table_name} (name, parent_name, task_spec, execute_after) VALUES ($1, $2, $3, $4) ON CONFLICT (name) DO UPDATE SET task_spec = $3, execute_after = $4, locked_by = NULL, locked_until = NULL """, task.canonical_name, task.parent_task.canonical_name, json.dumps(task.as_spec()), task.get_next_execution(utc_now, None), ) logger.debug(f"register_task returned {result}")
[docs] async def unregister_task(self, utc_now: datetime.datetime, task: Task) -> None: async with self._transaction() as connection: await connection.execute(f"DELETE FROM {self._table_name} WHERE name = $1", task.canonical_name)
    async def poll_task(self, utc_now: datetime.datetime, task: Task, lease: Optional[Lease] = None) -> PollResponse:
        """Determine whether a single task is due and, if so, lock it.

        Row is read FOR UPDATE so the decide-then-write sequence below is
        atomic with respect to concurrent pollers. The outcome is one of
        LOCKED (someone else holds a live lease), READY (caller may execute),
        or PENDING (not due yet).
        """
        async with self._transaction() as connection:
            record = await connection.fetchrow(
                f"SELECT * FROM {self._table_name} WHERE name = $1 FOR UPDATE",
                task.canonical_name,
            )
            logger.debug(f"poll_task returned {record}")

            update = False
            if record is None:
                # Regular (non-dynamic) tasks will be implicitly created on first poll,
                # but dynamic task instances must be explicitly created to prevent spurious
                # poll from re-creating them after being deleted.
                if task.parent_task is not None:
                    raise PyncetteException("Task not found")
                execute_after = task.get_next_execution(utc_now, None)
                locked_until = None
                locked_by = None
                update = True
            else:
                execute_after = record["execute_after"]
                locked_until = record["locked_until"]
                locked_by = record["locked_by"]

            assert execute_after is not None
            scheduled_at = execute_after

            # A live lock held by someone else (lease mismatch) wins over due-ness.
            if locked_until is not None and locked_until > utc_now and (lease != locked_by):
                result = ResultType.LOCKED
            elif execute_after <= utc_now and task.execution_mode == ExecutionMode.AT_MOST_ONCE:
                # AT_MOST_ONCE: advance the schedule up front; no lease handed out.
                execute_after = task.get_next_execution(utc_now, execute_after)
                result = ResultType.READY
                locked_until = None
                locked_by = None
                update = True
            elif execute_after <= utc_now and task.execution_mode == ExecutionMode.AT_LEAST_ONCE:
                # AT_LEAST_ONCE: hand out a fresh lease; schedule advances on commit.
                locked_until = utc_now + task.lease_duration
                locked_by = uuid.uuid4()
                result = ResultType.READY
                update = True
            else:
                result = ResultType.PENDING

            if update:
                await self._update_record(
                    connection,
                    task,
                    locked_until,
                    locked_by,
                    execute_after,
                )

            return PollResponse(result=result, scheduled_at=scheduled_at, lease=locked_by)
[docs] async def commit_task(self, utc_now: datetime.datetime, task: Task, lease: Lease) -> None: async with self._transaction() as connection: record = await connection.fetchrow( f"SELECT * FROM {self._table_name} WHERE name = $1 FOR UPDATE", task.canonical_name, ) logger.debug(f"commit_task returned {record}") if not record: logger.warning(f"Task {task} not found, skipping.") return if record["locked_by"] != lease: logger.warning(f"Lease lost on task {task}, skipping.") return await self._update_record( connection, task, None, None, task.get_next_execution(utc_now, record["execute_after"]), )
[docs] async def extend_lease(self, utc_now: datetime.datetime, task: Task, lease: Lease) -> Optional[Lease]: async with self._transaction() as connection: locked_until = utc_now + task.lease_duration result = await connection.execute( f""" UPDATE {self._table_name} SET locked_until = $1 WHERE name = $2 AND locked_by = $3 """, locked_until, task.canonical_name, lease, ) logger.debug(f"extend_lease returned {result}") if result == "UPDATE 1": return lease else: return None
[docs] async def unlock_task(self, utc_now: datetime.datetime, task: Task, lease: Lease) -> None: async with self._transaction() as connection: result = await connection.execute( f""" UPDATE {self._table_name} SET locked_by = NULL, locked_until = NULL WHERE name = $1 AND locked_by = $2 """, task.canonical_name, lease, ) logger.debug(f"unlock_task returned {result}")
@asynccontextmanager async def _transaction(self) -> AsyncIterator[asyncpg.Connection]: async with self._pool.acquire() as connection: async with connection.transaction(): yield connection async def _update_record( self, connection: asyncpg.Connection, task: Task, locked_until: Optional[datetime.datetime], locked_by: Optional[uuid.UUID], execute_after: Optional[datetime.datetime], ) -> None: if execute_after is None: result = await connection.execute(f"DELETE FROM {self._table_name} WHERE name = $1", task.canonical_name) else: result = await connection.execute( f""" INSERT INTO {self._table_name} (name, locked_until, locked_by, execute_after) VALUES ($1, $2, $3, $4) ON CONFLICT (name) DO UPDATE SET locked_until = $2, locked_by = $3, execute_after = $4 """, task.canonical_name, locked_until, locked_by, execute_after, ) logger.debug(f"update_record returned {result}")
@contextlib.asynccontextmanager
async def postgres_repository(**kwargs: Any) -> AsyncIterator[PostgresRepository]:
    """Factory context manager for repository that initializes the connection to Postgres"""
    postgres_pool = await asyncpg.create_pool(kwargs["postgres_url"])
    try:
        repository = PostgresRepository(postgres_pool, **kwargs)
        # Table creation can be skipped when the schema is managed externally.
        if not kwargs.get("postgres_skip_table_create", False):
            await repository.initialize()
        yield repository
    finally:
        # Always close the pool, even if initialization or the caller fails.
        await postgres_pool.close()