diff --git a/dch_webhooks.py b/dch_webhooks.py index 64488ce..d3ea6b8 100644 --- a/dch_webhooks.py +++ b/dch_webhooks.py @@ -1,6 +1,9 @@ +import asyncio import base64 import datetime +import functools import importlib.metadata +import json import logging import os import re @@ -9,8 +12,12 @@ from typing import Optional, Self, Type import fastapi import httpx +import pika +import pika.channel import pydantic import pyrfc6266 +from fastapi import Form +from pika.adapters.asyncio_connection import AsyncioConnection __all__ = [ @@ -19,6 +26,8 @@ __all__ = [ log = logging.getLogger(__name__) +context: 'Context' + DIST = importlib.metadata.metadata(__name__) DESCRIPTION_CLEAN_PATTERN = re.compile('[^a-z ]') @@ -257,6 +266,120 @@ class Paperless(HttpxClientMixin): return docs +class AMQPError(Exception): ... + + +class AMQPContext: + def __init__(self) -> None: + self._conn_params = pika.ConnectionParameters() + self._connection: Optional[AsyncioConnection] = None + self._channel: Optional[pika.channel.Channel] = None + + def close(self): + if self._channel: + self._channel.close() + self._channel = None + if self._connection: + self._connection.close() + self._connection = None + + async def connect(self): + loop = asyncio.get_running_loop() + if self._connection is None: + fut = loop.create_future() + self._connection = AsyncioConnection( + self._conn_params, + on_open_callback=functools.partial(self._on_open, fut), + on_open_error_callback=functools.partial( + self._on_open_error, fut + ), + on_close_callback=self._on_close, + ) + await fut + if self._channel is None: + fut = loop.create_future() + self._connection.channel( + on_open_callback=functools.partial(self._on_channel_open, fut), + ) + self._channel = await fut + + def publish(self, exchange: str, routing_key: str, body: bytes): + assert self._channel + self._channel.basic_publish( + exchange, + routing_key, + body, + ) + + async def queue_declare( + self, + name: str, + passive: bool = False, + durable: bool = False, + exclusive: bool = False, + auto_delete: bool = False, + arguments: Optional[dict[str, str]] = None, + ) -> None: + loop = asyncio.get_event_loop() + fut = loop.create_future() + assert self._channel + self._channel.queue_declare( + name, + passive, + durable, + exclusive, + auto_delete, + arguments, + callback=lambda _: fut.set_result(None), + ) + await fut + + def _on_open(self, fut: asyncio.Future[None], _conn): + log.info('AMQP connection open') + fut.set_result(None) + + def _on_open_error(self, fut: asyncio.Future[None], _conn, error): + log.error('Failed to open AMQP connection: %s', error) + self._connection = None + fut.set_exception(AMQPError(error)) + + def _on_close(self, _conn, reason): + level = logging.INFO if reason.reply_code == 200 else logging.WARNING + if log.isEnabledFor(level): + log.log( + level, + 'AMQP connection closed: %s (code %d)', + reason.reply_text, + reason.reply_code, + ) + self._connection = None + + def _on_channel_open( + self, + fut: asyncio.Future[pika.channel.Channel], + chan: pika.channel.Channel, + ) -> None: + chan.add_on_close_callback(self._on_channel_close) + fut.set_result(chan) + + def _on_channel_close(self, _chan, reason): + level = logging.INFO if reason.reply_code == 0 else logging.WARNING + if log.isEnabledFor(level): + log.log( + level, + 'AMQP channel closed: %s (code %d)', + reason.reply_text, + reason.reply_code, + ) + self._channel = None + + +class Context: + + def __init__(self): + self.amqp = AMQPContext() + + async def handle_firefly_transaction(xact: FireflyIIITransaction) -> None: try: xact0 = xact.transactions[0] @@ -377,6 +500,34 @@ def rfc2047_base64encode( return f"=?UTF-8?B?{encoded}?=" +async def start_ansible_job(): ... + + +async def publish_host_info(hostname: str, sshkeys: str): + await context.amqp.connect() + await context.amqp.queue_declare('host-provision', durable=True) + context.amqp.publish( + exchange='', + routing_key='host-provision', + body=json.dumps( + { + 'hostname': hostname, + 'sshkeys': sshkeys, + }, + ).encode('utf-8'), + ) + + +async def handle_host_online(hostname: str, sshkeys: str): + try: + await publish_host_info(hostname, sshkeys) + except asyncio.CancelledError: + raise + except Exception: + log.exception('Failed to publish host info:') + return + + app = fastapi.FastAPI( name=DIST['Name'], version=DIST['Version'], @@ -391,6 +542,14 @@ def on_start() -> None: h.setLevel(logging.DEBUG) log.addHandler(h) + global context + context = Context() + + +@app.on_event('shutdown') +def on_shutdown() -> None: + context.amqp.close() + @app.get('/') def status() -> str: @@ -443,3 +602,12 @@ async def jenkins_notify(request: fastapi.Request) -> None: await ntfy(message, 'jenkins', title, tag, cache=False) except httpx.HTTPError: log.exception('Failed to send notification:') + + +@app.post( + '/host/online', + status_code=fastapi.status.HTTP_202_ACCEPTED, + response_class=fastapi.responses.PlainTextResponse, +) +async def host_online(hostname: str = Form(), sshkeys: str = Form()) -> None: + asyncio.create_task(handle_host_online(hostname, sshkeys)) diff --git a/pyproject.toml b/pyproject.toml index db1f609..7fe3701 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,9 @@ classifiers = [ dependencies = [ "fastapi~=0.97.0", "httpx~=0.24.1", + "pika>=1.3.2", "pyrfc6266~=1.0.2", + "python-multipart>=0.0.20", ] dynamic = ["version"] @@ -57,3 +59,8 @@ check_untyped_defs = true disallow_untyped_decorators = true no_implicit_optional = true warn_return_any = true + +[dependency-groups] +dev = [ + "uvicorn>=0.34.0", +]