74 lines
2.4 KiB
Python
74 lines
2.4 KiB
Python
import asyncio
|
|
import signal
|
|
from typing import Generic
|
|
from typing import TypeVar
|
|
|
|
from loguru import logger
|
|
|
|
from app.database import AsyncSession
|
|
from app.database import async_session
|
|
|
|
# Type variable standing for the message type a ``Worker`` subclass consumes;
# it parameterizes ``Worker`` via ``Generic[T]``.
T = TypeVar("T")
|
|
|
|
|
|
class Worker(Generic[T]):
    """Base class for a long-running, polling background worker.

    Subclasses implement :meth:`get_next_message` and :meth:`process_message`
    and may override :meth:`startup`. :meth:`run_forever` opens a single
    database session via ``async_session()``, repeatedly polls for messages
    with that session, and shuts down when SIGHUP, SIGTERM or SIGINT is
    received (or when the polling loop exits on its own, e.g. by raising).
    """

    # Seconds to sleep after a message was processed, before polling again.
    busy_poll_interval: float = 0.5
    # Seconds to sleep when get_next_message() returned no message.
    idle_poll_interval: float = 2.0
    # Seconds to wait for cancelled tasks to finish during shutdown.
    shutdown_timeout: float = 15.0

    def __init__(self) -> None:
        # The event loop is captured in run_forever() via get_running_loop().
        # Calling asyncio.get_event_loop() here, when the worker is built
        # before asyncio.run(), can return (or create) a different loop than
        # the one that later runs the worker; signal handlers and tasks would
        # then be attached to a loop that never runs.
        self._loop: asyncio.AbstractEventLoop | None = None
        self._stop_event = asyncio.Event()

    async def process_message(self, db_session: AsyncSession, message: T) -> None:
        """Handle one message returned by :meth:`get_next_message`.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    async def get_next_message(self, db_session: AsyncSession) -> T | None:
        """Return the next message to process, or ``None`` if none is waiting.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    async def startup(self, db_session: AsyncSession) -> None:
        """Hook run once, with the worker's session, before polling starts.

        The base implementation does nothing.
        """
        return None

    async def _main_loop(self, db_session: AsyncSession) -> None:
        """Poll for and process messages until the stop event is set.

        Exceptions from :meth:`get_next_message` or :meth:`process_message`
        propagate and end the loop.
        """
        while not self._stop_event.is_set():
            next_message = await self.get_next_message(db_session)
            # Compare against None rather than truthiness: a falsy message
            # (0, "", an empty container/model, ...) is still a message.
            if next_message is not None:
                await self.process_message(db_session, next_message)
                await asyncio.sleep(self.busy_poll_interval)
            else:
                await asyncio.sleep(self.idle_poll_interval)

    async def _until_stopped(self) -> None:
        """Block until a shutdown has been requested via the stop event."""
        await self._stop_event.wait()

    async def run_forever(self) -> None:
        """Run the worker until a shutdown signal arrives or the main loop exits.

        Installs SIGHUP/SIGTERM/SIGINT handlers on the running loop, opens a
        database session, calls :meth:`startup`, then runs :meth:`_main_loop`.
        As soon as either a signal is received or the main loop finishes, every
        other task on the event loop is cancelled and given up to
        ``shutdown_timeout`` seconds to finish, while the session is still
        open. The signal handlers are removed before returning.
        """
        loop = asyncio.get_running_loop()
        self._loop = loop
        signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
        for s in signals:
            # Bind `s` as a default so each handler reports its own signal.
            loop.add_signal_handler(
                s,
                lambda s=s: asyncio.create_task(self._shutdown(s)),
            )

        try:
            async with async_session() as db_session:
                await self.startup(db_session)
                task = loop.create_task(self._main_loop(db_session))
                stop_task = loop.create_task(self._until_stopped())

                done, pending = await asyncio.wait(
                    {task, stop_task}, return_when=asyncio.FIRST_COMPLETED
                )
                logger.info(f"Waiting for tasks to finish {done=}/{pending=}")
                self._log_main_loop_failure(task)
                # Cancel while db_session is still open: the main-loop task
                # may be mid-way through using it.
                await self._cancel_other_tasks()
        finally:
            # Don't leave handlers behind that would keep scheduling
            # _shutdown() after this worker has returned.
            for s in signals:
                loop.remove_signal_handler(s)

        logger.info("stopping loop")

    @staticmethod
    def _log_main_loop_failure(task: asyncio.Task) -> None:
        """Log the exception *task* died with, if it finished by raising.

        Retrieving the exception also prevents asyncio's
        "Task exception was never retrieved" warning; the done task is not
        part of the later all_tasks() sweep, so nothing else would read it.
        """
        if task.done() and not task.cancelled():
            exc = task.exception()
            if exc is not None:
                logger.opt(exception=exc).error("Worker main loop crashed")

    async def _cancel_other_tasks(self) -> None:
        """Cancel every other task on the loop and wait for them to finish.

        Note this targets all tasks on the running loop, not only the ones
        this worker created. Waits at most ``shutdown_timeout`` seconds.
        """
        current = asyncio.current_task()
        tasks = [t for t in asyncio.all_tasks() if t is not current]
        logger.info(f"Cancelling {len(tasks)} tasks")
        for t in tasks:
            t.cancel()

        try:
            await asyncio.wait_for(
                asyncio.gather(*tasks, return_exceptions=True),
                timeout=self.shutdown_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning("Tasks failed to cancel")

    async def _shutdown(self, sig: signal.Signals) -> None:
        """Log the received signal *sig* and request the worker to stop."""
        logger.info(f"Caught {sig=}")
        self._stop_event.set()