Add missing base worker
parent
8633696da0
commit
394dae90f5
|
@ -0,0 +1,79 @@
|
|||
import asyncio
|
||||
import signal
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.database import AsyncSession
|
||||
from app.database import async_session
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Worker(Generic[T]):
|
||||
def __init__(self, workers_count: int) -> None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._in_flight: set[int] = set()
|
||||
self._queue: asyncio.Queue[T] = asyncio.Queue(maxsize=1)
|
||||
self._stop_event = asyncio.Event()
|
||||
self._workers_count = workers_count
|
||||
|
||||
async def _consumer(self, db_session: AsyncSession) -> None:
|
||||
while not self._stop_event.is_set():
|
||||
message = await self._queue.get()
|
||||
try:
|
||||
await self.process_message(db_session, message)
|
||||
finally:
|
||||
self._in_flight.remove(message.id) # type: ignore
|
||||
self._queue.task_done()
|
||||
|
||||
async def _producer(self, db_session: AsyncSession) -> None:
|
||||
while not self._stop_event.is_set():
|
||||
next_message = await self.get_next_message(db_session)
|
||||
if next_message:
|
||||
self._in_flight.add(next_message.id) # type: ignore
|
||||
await self._queue.put(next_message)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def process_message(self, db_session: AsyncSession, message: T) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_next_message(self, db_session: AsyncSession) -> T | None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def startup(self, db_session: AsyncSession) -> None:
|
||||
return None
|
||||
|
||||
def in_flight_ids(self) -> set[int]:
|
||||
return self._in_flight
|
||||
|
||||
async def run_forever(self) -> None:
|
||||
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
||||
for s in signals:
|
||||
self._loop.add_signal_handler(
|
||||
s,
|
||||
lambda s=s: asyncio.create_task(self._shutdown(s)),
|
||||
)
|
||||
|
||||
async with async_session() as db_session:
|
||||
await self.startup(db_session)
|
||||
self._loop.create_task(self._producer(db_session))
|
||||
for _ in range(self._workers_count):
|
||||
self._loop.create_task(self._consumer(db_session))
|
||||
|
||||
await self._stop_event.wait()
|
||||
logger.info("Waiting for tasks to finish")
|
||||
await self._queue.join()
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
logger.info(f"Cancelling {len(tasks)} tasks")
|
||||
[task.cancel() for task in tasks]
|
||||
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
logger.info("stopping loop")
|
||||
|
||||
async def _shutdown(self, sig: signal.Signals) -> None:
|
||||
logger.info(f"Caught {signal=}")
|
||||
self._stop_event.set()
|
Loading…
Reference in New Issue