# dch-webhooks/dch_webhooks.py
# (Residue from the repository web page this file was copied from:
# "1", "0", "Fork 0", "747 lines", "22 KiB", "Python".)

import asyncio
import base64
import datetime
import functools
import importlib.metadata
import json
import logging
import os
import re
import ssl
from pathlib import Path
from types import TracebackType
from typing import Any, Optional, Self, Type
import fastapi
import httpx
import pika
import pika.channel
import pika.credentials
import pydantic
import pyrfc6266
import ruamel.yaml
from fastapi import Form
from pika.adapters.asyncio_connection import AsyncioConnection
__all__ = ['app']

log = logging.getLogger(__name__)

# Global application state; assigned by the FastAPI startup hook (on_start).
context: 'Context'

# Metadata of the installed distribution; its Name/Version feed the HTTP
# User-Agent and the FastAPI app.
DIST = importlib.metadata.metadata(__name__)

# --- Receipt matching ------------------------------------------------------

# Every character except lower-case ASCII letters and space is removed from
# (lower-cased) transaction descriptions.
DESCRIPTION_CLEAN_PATTERN = re.compile('[^a-z ]')
# Filler words dropped from descriptions before searching Paperless-ngx.
EXCLUDE_DESCRIPTION_WORDS = set(
    'a ach an card debit pay payment purchase retail the'.split()
)

# --- Service endpoints -----------------------------------------------------

FIREFLY_URL = os.environ.get('FIREFLY_URL', 'http://firefly-iii')
# Documents larger than this many bytes are not attached (default 50 MiB).
MAX_DOCUMENT_SIZE = int(os.environ.get('MAX_DOCUMENT_SIZE', 50 * 1024 * 1024))
NTFY_URL = os.environ.get('NTFY_URL', 'http://ntfy.ntfy:2586')
PAPERLESS_URL = os.environ.get('PAPERLESS_URL', 'http://paperless-ngx')

# --- Host provisioning -----------------------------------------------------

ANSIBLE_JOB_YAML = Path(os.environ.get('ANSIBLE_JOB_YAML', 'ansible-job.yaml'))
ANSIBLE_JOB_NAMESPACE = os.environ.get('ANSIBLE_JOB_NAMESPACE', 'ansible')
HOST_INFO_QUEUE = os.environ.get('HOST_INFO_QUEUE', 'host-provisioner')
# Sent as the AMQP message ``expiration`` property, which RabbitMQ reads as
# milliseconds (i.e. 10 minutes).
HOST_INFO_TTL = 600000

# --- In-cluster Kubernetes credentials ---------------------------------------

KUBERNETES_TOKEN_PATH = Path(
    os.environ.get(
        'KUBERNETES_TOKEN_PATH',
        '/run/secrets/kubernetes.io/serviceaccount/token',
    )
)
KUBERNETES_CA_CERT = Path(
    os.environ.get(
        'KUBERNETES_CA_CERT',
        '/run/secrets/kubernetes.io/serviceaccount/ca.crt',
    )
)
class FireflyIIITransactionSplit(pydantic.BaseModel):
    """One split of a Firefly III transaction received via webhook.

    Only the fields this service reads are declared; pydantic ignores any
    other keys in the payload by default.
    """

    # Transaction type; title-cased into the ntfy notification title.
    type: str
    date: datetime.datetime
    # Kept as the raw string from Firefly; converted with float() before
    # searching Paperless-ngx.
    amount: str
    # Journal ID that matching receipts are attached to.
    transaction_journal_id: int
    description: str
class FireflyIIITransaction(pydantic.BaseModel):
    """A Firefly III transaction as delivered in a webhook: a list of splits."""

    # May be empty; callers must not assume a first element exists.
    transactions: list[FireflyIIITransactionSplit]
class FireflyIIIWebhook(pydantic.BaseModel):
    """Request body of a Firefly III webhook; only ``content`` is declared."""

    content: FireflyIIITransaction
class PaperlessNgxDocument(pydantic.BaseModel):
    """A document entry in a Paperless-ngx search result (fields used here)."""

    # Used to build the document download URL.
    id: int
    # Passed on as the attachment title in Firefly III.
    title: str
class PaperlessNgxSearchResults(pydantic.BaseModel):
    """One page of results from the Paperless-ngx ``/api/documents/`` search."""

    # Total number of matching documents (across all pages).
    count: int
    # Presumably the URL of the next/previous page (paginated API); ``next``
    # is only checked for truthiness here.
    next: str | None
    previous: str | None
    results: list[PaperlessNgxDocument]
class HttpxClientMixin:
    """Mixin providing a lazily created :class:`httpx.AsyncClient`.

    Subclasses customize the client by overriding :meth:`_get_client`.
    Instances are async context managers: entering enters the client, and
    exiting closes it *and forgets it*. An httpx client cannot be reopened
    once closed, so the next access to :attr:`client` creates a fresh one.
    """

    def __init__(self) -> None:
        super().__init__()
        self._client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self) -> Self:
        await self.client.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        # Drop the reference before closing so a later use of ``client``
        # builds a new, usable client instead of returning a closed one.
        client, self._client = self.client, None
        await client.__aexit__(exc_type, exc_value, tb)

    @property
    def client(self) -> httpx.AsyncClient:
        """The HTTP client, created by :meth:`_get_client` on first access."""
        if self._client is None:
            self._client = self._get_client()
        return self._client

    def _get_client(self, **kwargs) -> httpx.AsyncClient:
        """Create the client; *kwargs* are passed to httpx.AsyncClient.

        A ``User-Agent`` header naming this distribution and its version is
        added to any headers supplied in *kwargs*. The supplied headers are
        copied, not modified in place.
        """
        headers = httpx.Headers(kwargs.get('headers'))
        headers['User-Agent'] = f'{DIST["Name"]}/{DIST["Version"]}'
        kwargs['headers'] = headers
        return httpx.AsyncClient(**kwargs)
class Firefly(HttpxClientMixin):
    """Async client for the Firefly III attachments API."""

    def _get_client(self, **kwargs) -> httpx.AsyncClient:
        """Create the client, adding a Bearer token if one is configured.

        The token is read from the file named by ``FIREFLY_AUTH_TOKEN``. If
        that file cannot be opened, the error is logged and the client is
        returned without an ``Authorization`` header.
        """
        client = super()._get_client(**kwargs)
        token_file = os.environ.get('FIREFLY_AUTH_TOKEN')
        if not token_file:
            return client
        try:
            handle = open(token_file, encoding='utf-8')
        except OSError as e:
            log.error('Could not load Firefly-III access token: %s', e)
            return client
        with handle:
            token = handle.read().strip()
        client.headers['Authorization'] = f'Bearer {token}'
        return client

    async def attach_receipt(
        self,
        xact_id: int,
        doc: bytes,
        filename: str,
        title: str | None = None,
    ) -> None:
        """Attach *doc* to the Firefly III transaction journal *xact_id*.

        This takes two requests: create the attachment record, then upload
        the file content to it.

        Raises:
            httpx.HTTPStatusError: if either request returns an error status.
        """
        log.info('Attaching receipt %r to transaction %d', filename, xact_id)
        fields = {
            'filename': filename,
            'attachable_type': 'TransactionJournal',
            'attachable_id': xact_id,
        }
        if title:
            fields['title'] = title
        created = await self.client.post(
            f'{FIREFLY_URL}/api/v1/attachments', data=fields
        )
        created.raise_for_status()
        attachment_id = created.json()['data']['id']
        uploaded = await self.client.post(
            f'{FIREFLY_URL}/api/v1/attachments/{attachment_id}/upload',
            content=doc,
            headers={'Content-Type': 'application/octet-stream'},
        )
        uploaded.raise_for_status()
class Paperless(HttpxClientMixin):
    """Async client for Paperless-ngx document search and download."""

    def _get_client(self, **kwargs) -> httpx.AsyncClient:
        """Create the client, adding the API token if one is configured.

        The token is read from the file named by ``PAPERLESS_AUTH_TOKEN``. If
        that file cannot be opened, the error is logged and the client is
        returned without an ``Authorization`` header.
        """
        client = super()._get_client(**kwargs)
        if token_file := os.environ.get('PAPERLESS_AUTH_TOKEN'):
            try:
                f = open(token_file, encoding='utf-8')
            except OSError as e:
                log.error(
                    'Could not load Paperless-ngx authentication token: %s', e
                )
            else:
                with f:
                    token = f.read().strip()
                client.headers['Authorization'] = f'Token {token}'
        return client

    async def find_receipts(
        self, search: str, amount: float, date: datetime.date
    ) -> list[tuple[str, str, bytes]]:
        """Find and download Paperless-ngx receipts matching a transaction.

        The full-text query combines *search*, *amount*, the
        ``Invoice/Receipt`` document type and a creation-date window of two
        days on either side of *date*. Only the first page of results is
        used. Documents are skipped if their download fails, if their size
        is unknown, or if they exceed ``MAX_DOCUMENT_SIZE``.

        Returns a list of ``(filename, title, content)`` tuples. HTTP error
        statuses and malformed responses are logged and yield an empty list.
        Transport-level httpx exceptions are not caught.
        """
        docs: list[tuple[str, str, bytes]] = []
        for doc in await self._search_receipts(search, amount, date):
            if (item := await self._download_document(doc)) is not None:
                docs.append(item)
        return docs

    async def _search_receipts(
        self, search: str, amount: float, date: datetime.date
    ) -> list[PaperlessNgxDocument]:
        """Run the receipt search; return matching documents ([] on error)."""
        date_begin = date - datetime.timedelta(days=2)
        date_end = date + datetime.timedelta(days=2)
        query = ' '.join(
            (
                search,
                str(amount),
                'type:Invoice/Receipt',
                f'created:[{date_begin} TO {date_end}]',
            )
        )
        log.info('Searching for receipt in Paperless: %s', query)
        url = f'{PAPERLESS_URL}/api/documents/'
        r = await self.client.get(url, params={'query': query})
        if r.status_code != 200:
            # Only parse the error body when it will actually be logged.
            if log.isEnabledFor(logging.ERROR):
                log.error(
                    'Error searching Paperless: HTTP %d %s: %s',
                    r.status_code,
                    r.reason_phrase,
                    self._error_detail(r),
                )
            return []
        try:
            data = r.json()
        except ValueError as e:
            log.error('Failed to parse HTTP response as JSON: %s', e)
            return []
        try:
            results = PaperlessNgxSearchResults.parse_obj(data)
        except pydantic.ValidationError as e:
            log.error('Could not parse search response: %s', e)
            return []
        log.info('Search returned %d documents', results.count)
        if results.next:
            log.warning(
                'Search returned multiple pages of results; '
                'only the results on the first page are used'
            )
        return results.results

    async def _download_document(
        self, doc: PaperlessNgxDocument
    ) -> tuple[str, str, bytes] | None:
        """Download the original file of *doc*; None if it was skipped."""
        url = f'{PAPERLESS_URL}/api/documents/{doc.id}/download/'
        # Stream the response so the size limit is checked from the headers
        # *before* the body is read into memory.
        async with self.client.stream(
            'GET', url, params={'original': True}
        ) as r:
            if r.status_code != 200:
                log.error(
                    'Failed to download document: HTTP %d %s',
                    r.status_code,
                    r.reason_phrase,
                )
                return None
            try:
                size = int(r.headers['Content-Length'])
            except (KeyError, ValueError) as e:
                log.error(
                    'Skipping document ID %d: Cannot determine file size: %s',
                    doc.id,
                    e,
                )
                return None
            if size > MAX_DOCUMENT_SIZE:
                log.warning(
                    'Skipping document ID %d: Size (%d bytes) is greater than '
                    'the configured maximum document size (%d bytes)',
                    doc.id,
                    size,
                    MAX_DOCUMENT_SIZE,
                )
                return None
            filename = response_filename(r)
            return filename, doc.title, await r.aread()

    @staticmethod
    def _error_detail(r: httpx.Response) -> Any:
        """Extract the ``detail`` field of an error response, if present.

        Falls back to the raw response text when the body is not JSON, and to
        '' when the JSON is not an object or has no ``detail`` key.
        """
        try:
            data = r.json()
        except ValueError as e:
            log.debug('Failed to parse HTTP error response as JSON: %s', e)
            return r.text
        if isinstance(data, dict):
            return data.get('detail', '')
        return ''
class AMQPError(Exception):
    """Error reported by :class:`AMQPContext` for AMQP connection problems."""
class AMQPContext:
    """Lazily opened AMQP connection and channel for publishing messages.

    Wraps pika's :class:`AsyncioConnection` so that callers can ``await``
    :meth:`connect` and :meth:`queue_declare` instead of wiring callbacks.
    The connection and channel are opened on first use. When either closes,
    the close callbacks reset it to ``None``, so a later :meth:`connect`
    reopens it. Operations waiting on a callback are failed with
    :class:`AMQPError` when the connection or channel closes, rather than
    hanging forever.
    """

    def __init__(
        self,
        conn_params: Optional[
            pika.ConnectionParameters | pika.URLParameters
        ] = None,
    ) -> None:
        self._conn_params = conn_params
        self._connection: Optional[AsyncioConnection] = None
        self._channel: Optional[pika.channel.Channel] = None
        # Serializes connect(). Without it, a second caller could see a
        # connection that is still opening and allocate a channel on it.
        self._connect_lock = asyncio.Lock()
        # Futures still waiting on a pika callback.
        self._waiters: set[asyncio.Future[Any]] = set()

    @classmethod
    def from_env(cls) -> Self:
        """Build a context from ``AMQP_*`` environment variables.

        If ``AMQP_URL`` is set, it takes precedence and is parsed as a pika
        URL. Otherwise the parameters are assembled from:

        * ``AMQP_HOST``, ``AMQP_PORT`` and ``AMQP_VIRTUAL_HOST``;
        * ``AMQP_USERNAME``/``AMQP_PASSWORD``, or
          ``AMQP_EXTERNAL_CREDENTIALS``;
        * the TLS settings ``AMQP_CA_CERT``, ``AMQP_CLIENT_CERT``,
          ``AMQP_CLIENT_KEY`` and ``AMQP_CLIENT_KEY_PASSWORD``.
        """
        if 'AMQP_URL' in os.environ:
            return cls(pika.URLParameters(os.environ['AMQP_URL']))
        kwargs: dict[str, Any] = {}
        if host := os.environ.get('AMQP_HOST'):
            kwargs['host'] = host
        if port := os.environ.get('AMQP_PORT'):
            kwargs['port'] = int(port)
        if vhost := os.environ.get('AMQP_VIRTUAL_HOST'):
            kwargs['virtual_host'] = vhost
        if username := os.environ.get('AMQP_USERNAME'):
            password = os.environ.get('AMQP_PASSWORD', '')
            kwargs['credentials'] = pika.PlainCredentials(username, password)
        elif os.environ.get('AMQP_EXTERNAL_CREDENTIALS'):
            kwargs['credentials'] = pika.credentials.ExternalCredentials()
        if any(
            name in os.environ
            for name in ('AMQP_CA_CERT', 'AMQP_CLIENT_CERT', 'AMQP_CLIENT_KEY')
        ):
            sslctx = ssl.create_default_context(
                cafile=os.environ.get('AMQP_CA_CERT')
            )
            if certfile := os.environ.get('AMQP_CLIENT_CERT'):
                sslctx.load_cert_chain(
                    certfile,
                    os.environ.get('AMQP_CLIENT_KEY'),
                    os.environ.get('AMQP_CLIENT_KEY_PASSWORD'),
                )
            # The default context checks host names, which requires a server
            # name even when AMQP_HOST is unset. Use pika's default host,
            # which is the host it will connect to in that case.
            server_hostname = kwargs.get(
                'host', pika.ConnectionParameters.DEFAULT_HOST
            )
            kwargs['ssl_options'] = pika.SSLOptions(sslctx, server_hostname)
        return cls(pika.ConnectionParameters(**kwargs))

    def close(self) -> None:
        """Close the channel and connection, if open, and forget them."""
        if self._channel is not None:
            self._channel.close()
            self._channel = None
        if self._connection is not None:
            self._connection.close()
            self._connection = None

    async def connect(self) -> None:
        """Open the connection and channel, unless they are already open.

        Raises:
            AMQPError: if the connection cannot be opened, or if it closes
                before the channel is ready.
        """
        async with self._connect_lock:
            if self._connection is None:
                fut: asyncio.Future[None] = self._new_waiter()
                self._connection = AsyncioConnection(
                    self._conn_params,
                    on_open_callback=functools.partial(self._on_open, fut),
                    on_open_error_callback=functools.partial(
                        self._on_open_error, fut
                    ),
                    on_close_callback=self._on_close,
                )
                await fut
            if self._connection is None:
                raise AMQPError('AMQP connection closed while opening')
            if self._channel is None:
                chan_fut: asyncio.Future[pika.channel.Channel] = (
                    self._new_waiter()
                )
                self._connection.channel(
                    on_open_callback=functools.partial(
                        self._on_channel_open, chan_fut
                    ),
                )
                self._channel = await chan_fut

    def publish(
        self,
        exchange: str,
        routing_key: str,
        body: bytes,
        properties: Optional[pika.BasicProperties] = None,
    ) -> None:
        """Publish *body* to *exchange* with *routing_key*.

        Raises:
            AMQPError: if no channel is open (call :meth:`connect` first).
        """
        self._require_channel().basic_publish(
            exchange,
            routing_key,
            body,
            properties,
        )

    async def queue_declare(
        self,
        name: str,
        passive: bool = False,
        durable: bool = False,
        exclusive: bool = False,
        auto_delete: bool = False,
        arguments: Optional[dict[str, Any]] = None,
    ) -> None:
        """Declare queue *name* and wait for the broker's confirmation.

        Raises:
            AMQPError: if no channel is open, or if the channel/connection
                closes before the declaration is confirmed (for example,
                when the broker rejects conflicting queue arguments).
        """
        channel = self._require_channel()
        fut: asyncio.Future[None] = self._new_waiter()
        channel.queue_declare(
            name,
            passive,
            durable,
            exclusive,
            auto_delete,
            arguments,
            callback=functools.partial(self._on_queue_declared, fut),
        )
        await fut

    def _require_channel(self) -> pika.channel.Channel:
        """Return the open channel or raise AMQPError."""
        if self._channel is None:
            raise AMQPError('AMQP channel is not open; call connect() first')
        return self._channel

    def _new_waiter(self) -> asyncio.Future[Any]:
        """Create a future that a close callback can fail."""
        fut = asyncio.get_running_loop().create_future()
        self._waiters.add(fut)
        fut.add_done_callback(self._waiters.discard)
        return fut

    def _fail_waiters(self, reason: Any) -> None:
        """Fail every pending waiter with an AMQPError wrapping *reason*."""
        for fut in list(self._waiters):
            if not fut.done():
                fut.set_exception(AMQPError(reason))

    def _on_open(self, fut: asyncio.Future[None], _conn) -> None:
        log.info('AMQP connection open')
        if not fut.done():
            fut.set_result(None)

    def _on_open_error(self, fut: asyncio.Future[None], conn, error) -> None:
        log.error('Failed to open AMQP connection: %s', error)
        if self._connection is conn:
            self._connection = None
        if not fut.done():
            fut.set_exception(AMQPError(error))

    def _on_close(self, conn, reason) -> None:
        """Log the closure, forget the connection, and fail pending waiters."""
        # pika passes an exception here. Only the ConnectionClosed family
        # carries reply_code/reply_text; other errors (e.g. a lost TCP
        # stream) may not, so read them defensively.
        code = getattr(reason, 'reply_code', None)
        level = logging.INFO if code == 200 else logging.WARNING
        if log.isEnabledFor(level):
            log.log(
                level,
                'AMQP connection closed: %s (code %s)',
                getattr(reason, 'reply_text', reason),
                code,
            )
        # Ignore late callbacks from a connection that was already replaced.
        if self._connection is conn:
            self._connection = None
            self._channel = None
        if self._connection is None:
            self._fail_waiters(reason)

    def _on_channel_open(
        self,
        fut: asyncio.Future[pika.channel.Channel],
        chan: pika.channel.Channel,
    ) -> None:
        chan.add_on_close_callback(self._on_channel_close)
        if not fut.done():
            fut.set_result(chan)

    def _on_queue_declared(self, fut: asyncio.Future[None], _frame) -> None:
        if not fut.done():
            fut.set_result(None)

    def _on_channel_close(self, chan, reason) -> None:
        """Log the closure, forget the channel, and fail pending waiters."""
        # Channel.close() defaults to reply code 0 for a client-side close.
        # As with the connection, the reason may lack reply_code/reply_text.
        code = getattr(reason, 'reply_code', None)
        level = logging.INFO if code == 0 else logging.WARNING
        if log.isEnabledFor(level):
            log.log(
                level,
                'AMQP channel closed: %s (code %s)',
                getattr(reason, 'reply_text', reason),
                code,
            )
        if self._channel is chan:
            self._channel = None
        if self._channel is None:
            self._fail_waiters(reason)
class Kubernetes(HttpxClientMixin):
    """HTTP client for the Kubernetes API server.

    In-cluster, the HTTPS service address from the ``KUBERNETES_SERVICE_*``
    environment variables is used. Otherwise it falls back to plain HTTP
    (by default ``127.0.0.1:8001``).
    """

    @functools.cached_property
    def base_url(self) -> str:
        """Base URL of the API server, computed once and logged."""
        https = True
        port = os.environ.get('KUBERNETES_SERVICE_PORT_HTTPS')
        if not port:
            https = False
            port = os.environ.get('KUBERNETES_SERVICE_PORT', 8001)
        host = os.environ.get('KUBERNETES_SERVICE_HOST', '127.0.0.1')
        # IPv6 literals (e.g. in IPv6-only clusters) must be bracketed in a
        # URL authority (RFC 3986).
        if ':' in host and not host.startswith('['):
            host = f'[{host}]'
        url = f'{"https" if https else "http"}://{host}:{port}'
        log.info('Using Kubernetes URL: %s', url)
        return url

    @functools.cached_property
    def token(self) -> str:
        """Service-account token read from ``KUBERNETES_TOKEN_PATH``.

        It is cached after the first successful read. Read errors propagate
        and are not cached.
        """
        return KUBERNETES_TOKEN_PATH.read_text().strip()

    def _get_client(self, **kwargs) -> httpx.AsyncClient:
        """Create the client with the cluster CA and bearer token, if available.

        A missing or unreadable token is logged, and the client is returned
        without an ``Authorization`` header.
        """
        if 'verify' not in kwargs and KUBERNETES_CA_CERT.exists():
            # httpx accepts an SSLContext for ``verify`` in all versions;
            # passing a CA-file path string is deprecated in newer releases.
            kwargs['verify'] = ssl.create_default_context(
                cafile=str(KUBERNETES_CA_CERT)
            )
        client = super()._get_client(**kwargs)
        try:
            client.headers['Authorization'] = f'Bearer {self.token}'
        except (OSError, UnicodeDecodeError) as e:
            log.warning('Failed to read k8s auth token: %s', e)
        return client
class Context:
    """Application-wide state, created by the FastAPI startup hook."""

    def __init__(self):
        # The AMQP connection settings come from the AMQP_* environment.
        amqp_context: AMQPContext = AMQPContext.from_env()
        self.amqp = amqp_context
async def handle_firefly_transaction(xact: FireflyIIITransaction) -> None:
    """Process a Firefly III transaction received via webhook.

    Sends an ntfy notification describing the first split. Then, for every
    split, searches Paperless-ngx for matching receipts and attaches each
    one found to the split's transaction journal in Firefly III. Failures
    are logged per split and do not stop the remaining splits.
    """
    await _notify_new_transaction(xact)
    async with Firefly() as ff, Paperless() as pl:
        for split in xact.transactions:
            await _attach_split_receipts(ff, pl, split)


async def _notify_new_transaction(xact: FireflyIIITransaction) -> None:
    """Send an ntfy notification summarizing the first split of *xact*."""
    if not xact.transactions:
        log.warning('Received empty transaction Firefly?')
        return
    xact0 = xact.transactions[0]
    message = f'${xact0.amount} for {xact0.description} on {xact0.date.date()}'
    title = f'Firefly III: New {xact0.type.title()}'
    try:
        await ntfy(message, 'firefly', title, 'money_with_wings')
    except Exception:
        log.exception('Failed to send notification')


async def _attach_split_receipts(
    ff: Firefly, pl: Paperless, split: FireflyIIITransactionSplit
) -> None:
    """Search Paperless for one split's receipts and attach them in Firefly."""
    search = clean_description(split.description)
    try:
        amount = float(split.amount)
    except ValueError as e:
        log.error('Invalid transaction amount: %s', e)
        return
    try:
        receipts = await pl.find_receipts(search, amount, split.date.date())
    except Exception:
        # A failed lookup (e.g. Paperless unreachable) must not prevent the
        # remaining splits from being processed.
        log.exception(
            'Failed to search for receipts for transaction ID %d',
            split.transaction_journal_id,
        )
        return
    for filename, title, doc in receipts:
        try:
            await ff.attach_receipt(
                split.transaction_journal_id, doc, filename, title
            )
        except Exception as e:
            log.error(
                'Failed to attach receipt to transaction ID %d: %s',
                split.transaction_journal_id,
                e,
            )
def clean_description(text: str) -> str:
    """Turn a transaction description into Paperless search terms.

    The description is lower-cased, and every character other than ``a``-``z``
    and space is removed. Words in ``EXCLUDE_DESCRIPTION_WORDS`` are then
    dropped. The remaining words are de-duplicated and kept in their
    original order, so the same description always yields the same query.
    If the character filter leaves an empty string, *text* is returned
    unchanged.
    """
    cleaned = DESCRIPTION_CLEAN_PATTERN.sub('', text.lower())
    if not cleaned:
        log.warning(
            'Failed to clean transaction description: '
            'text did not match regular expression pattern'
        )
        return text
    # dict.fromkeys de-duplicates while preserving first-seen order. A set
    # would make the term order depend on per-process hash randomization.
    terms = dict.fromkeys(
        word for word in cleaned.split() if word not in EXCLUDE_DESCRIPTION_WORDS
    )
    return ' '.join(terms)
def response_filename(response: httpx.Response) -> str:
    """Choose a file name for the body of *response*.

    Order of preference:

    1. the first ``filename*`` parameter of the ``Content-Disposition``
       header;
    2. otherwise the last non-empty-valued ``filename`` parameter of that
       header, with a surrounding ``b'...'`` removed (this looks like a
       Python bytes repr emitted by the server; NOTE(review): confirm the
       source);
    3. otherwise the last segment of the request URL path, after trailing
       slashes are stripped.
    """
    header = response.headers.get('Content-Disposition')
    if header:
        _disposition, params = pyrfc6266.parse(header)
        plain_name = ''
        for param in params:
            if param.name == 'filename*':
                return param.value
            if param.name == 'filename':
                plain_name = param.value
        if plain_name:
            wrapped = plain_name[:2] == "b'" and plain_name[-1:] == "'"
            return plain_name[2:-1] if wrapped else plain_name
    return response.url.path.rstrip('/').rsplit('/', 1)[-1]
def _ntfy_header(value: str) -> str:
    """Make *value* safe to send as an HTTP header value to ntfy.

    ASCII text is returned unchanged. Anything else is wrapped as an
    RFC 2047 base64 encoded-word, because httpx encodes header values as
    ASCII by default and would raise on non-ASCII text.
    """
    try:
        value.encode('ascii')
    except UnicodeEncodeError:
        return rfc2047_base64encode(value)
    return value


async def ntfy(
    message: Optional[str],
    topic: str,
    title: Optional[str] = None,
    tags: Optional[str] = None,
    attach: Optional[bytes] = None,
    filename: Optional[str] = None,
    cache: Optional[bool] = None,
) -> None:
    """Publish a notification to *topic* on the ntfy server at ``NTFY_URL``.

    Without *attach*, *message* is POSTed as the request body. With
    *attach*, the bytes are PUT as the body, and *message* (if any) and
    *filename* travel in the ``Message``/``Filename`` headers.

    Raises:
        ValueError: if neither *message* nor *attach* is given.
        httpx.HTTPStatusError: if ntfy responds with an error status.
    """
    # Explicit check rather than ``assert``, which ``python -O`` strips.
    if not (message or attach):
        raise ValueError('ntfy() needs a message or an attachment')
    headers: dict[str, str] = {}
    if title:
        headers['Title'] = _ntfy_header(title)
    if tags:
        headers['Tags'] = tags
    if cache is not None:
        headers['Cache'] = 'yes' if cache else 'no'
    url = f'{NTFY_URL}/{topic}'
    # Closing the client on exit releases its connection pool.
    async with httpx.AsyncClient() as client:
        if attach:
            if filename:
                headers['Filename'] = _ntfy_header(filename)
            if message:
                # Header values cannot contain raw newlines. ASCII text is
                # sent with them escaped as the two characters "\n"; other
                # text is RFC 2047-encoded.
                try:
                    message.encode('ascii')
                except UnicodeEncodeError:
                    message = rfc2047_base64encode(message)
                else:
                    message = message.replace('\n', '\\n')
                headers['Message'] = message
            r = await client.put(url, headers=headers, content=attach)
        else:
            r = await client.post(url, headers=headers, content=message)
    r.raise_for_status()
def rfc2047_base64encode(
    message: str,
) -> str:
    """Encode *message* as a single RFC 2047 "B" (base64) encoded-word.

    The text is encoded as UTF-8. No splitting is performed, so long input
    yields an encoded-word longer than the 75 characters RFC 2047 allows.
    """
    payload = base64.b64encode(bytes(message, 'utf-8'))
    return '=?UTF-8?B?' + payload.decode('ascii') + '?='
def load_job_yaml(path: Optional[Path] = None) -> dict[str, Any]:
    """Parse the Kubernetes Job manifest at *path* (default ``ANSIBLE_JOB_YAML``).

    Uses ruamel.yaml's default (round-trip) loader and reads the file as
    UTF-8.
    """
    manifest_path = ANSIBLE_JOB_YAML if path is None else path
    parser = ruamel.yaml.YAML()
    with manifest_path.open(encoding='utf-8') as stream:
        document = parser.load(stream)
    return document
async def start_ansible_job():
    """Create a Kubernetes Job in ``ANSIBLE_JOB_NAMESPACE`` from the manifest.

    Raises:
        Exception: if the API server responds with a status code of 300 or
            above; the exception carries the raw response body.
    """
    async with Kubernetes() as kube:
        jobs_url = '/'.join(
            (
                kube.base_url,
                'apis/batch/v1/namespaces',
                ANSIBLE_JOB_NAMESPACE,
                'jobs',
            )
        )
        manifest = load_job_yaml()
        response = await kube.client.post(jobs_url, json=manifest)
        if response.status_code >= 300:
            raise Exception(response.read())
async def publish_host_info(
    hostname: str, sshkeys: str, branch: Optional[str]
):
    """Queue a host's details for the host provisioner.

    Publishes a JSON object with ``hostname``, ``sshkeys`` and, when given,
    ``branch`` to the durable ``HOST_INFO_QUEUE`` via the default exchange.
    The message expiration is set to ``HOST_INFO_TTL``.
    """
    payload = dict(hostname=hostname, sshkeys=sshkeys)
    if branch:
        payload['branch'] = branch
    await context.amqp.connect()
    await context.amqp.queue_declare(HOST_INFO_QUEUE, durable=True)
    context.amqp.publish(
        exchange='',
        routing_key=HOST_INFO_QUEUE,
        body=json.dumps(payload).encode('utf-8'),
        properties=pika.BasicProperties(expiration=str(HOST_INFO_TTL)),
    )
async def handle_host_online(
    hostname: str, sshkeys: str, branch: Optional[str]
):
    """Publish a new host's details, then start the Ansible provisioning job.

    Each step's failure is logged with its traceback. If publishing fails,
    the Ansible job is not started. Cancellation is always re-raised.
    """
    steps = (
        (
            functools.partial(publish_host_info, hostname, sshkeys, branch),
            'Failed to publish host info:',
        ),
        (start_ansible_job, 'Failed to start Ansible job:'),
    )
    for run_step, failure_message in steps:
        try:
            await run_step()
        except asyncio.CancelledError:
            raise
        except Exception:
            log.exception(failure_message)
            return
# The ASGI application. FastAPI has no ``name`` parameter (unknown keywords
# only end up in ``app.extra``); ``title`` is what names the API in the
# OpenAPI schema and the docs served at ``/api-doc/``.
app = fastapi.FastAPI(
    title=DIST['Name'],
    version=DIST['Version'],
    docs_url='/api-doc/',
)
@app.on_event('startup')
def on_start() -> None:
    """Configure logging and create the global application context.

    Sets this module's logger to DEBUG and attaches a stream handler (which
    writes to stderr). Then builds ``context``, whose AMQP settings are read
    from the environment.
    """
    global context
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    log.setLevel(logging.DEBUG)
    log.addHandler(stream_handler)
    context = Context()
@app.on_event('shutdown')
def on_shutdown() -> None:
    """Close the AMQP channel and connection held by the global context."""
    amqp_context = context.amqp
    amqp_context.close()
@app.get('/')
def status() -> str:
    """Health-check endpoint; always answers ``UP``."""
    state = 'UP'
    return state
@app.post('/hooks/firefly-iii/create')
async def firefly_iii_create(hook: FireflyIIIWebhook) -> None:
    """Receive a Firefly III webhook and process the transaction it carries."""
    transaction = hook.content
    await handle_firefly_transaction(transaction)
# ntfy tag for each Jenkins build result; any other result is shown with
# 'white_circle'.
_JENKINS_RESULT_TAGS = {
    'FAILURE': 'red_circle',
    'SUCCESS': 'green_circle',
    'UNSTABLE': 'yellow_circle',
    'NOT_BUILT': 'large_blue_circle',
    'ABORTED': 'black_circle',
}


@app.post('/hooks/jenkins')
async def jenkins_notify(request: fastapi.Request) -> None:
    """Forward Jenkins build events to the ``jenkins`` ntfy topic.

    Only ``run.started`` and ``run.finalized`` events produce a
    notification. Other event types, or payloads without a ``type``, are
    ignored. Notification delivery errors are logged, not raised.
    """
    body = await request.json()
    # Treat a missing or null ``data`` member as empty.
    data = body.get('data') or {}
    event = body.get('type')
    if event == 'run.started':
        title = 'Build started'
        tag = 'building_construction'
        message = f'Build started: {data["fullDisplayName"]}'
        if build_cause := _first_build_cause(data):
            message = f'{message} ({build_cause})'
    elif event == 'run.finalized':
        result = data['result']
        message = f'{data["fullDisplayName"]} {result}'
        title = 'Build finished'
        # Guard against non-string (possibly unhashable) results.
        tag = (
            _JENKINS_RESULT_TAGS.get(result, 'white_circle')
            if isinstance(result, str)
            else 'white_circle'
        )
    else:
        return
    try:
        await ntfy(message, 'jenkins', title, tag, cache=False)
    except httpx.HTTPError:
        log.exception('Failed to send notification:')


def _first_build_cause(data: dict[str, Any]) -> Optional[str]:
    """Return the first non-empty ``shortDescription`` among the build causes."""
    descriptions = (
        cause.get('shortDescription')
        for action in data.get('actions', [])
        for cause in action.get('causes', [])
    )
    return next((d for d in descriptions if d), None)
# Strong references to the fire-and-forget tasks started by ``host_online``.
# The event loop keeps only weak references to tasks, so an unreferenced
# task can be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task[None]] = set()


@app.post(
    '/host/online',
    status_code=fastapi.status.HTTP_202_ACCEPTED,
    response_class=fastapi.responses.PlainTextResponse,
)
async def host_online(
    hostname: str = Form(),
    sshkeys: str = Form(),
    branch: Optional[str] = Form(None),
) -> None:
    """Accept a host-online report and provision the host in the background.

    Responds immediately with 202 Accepted; publishing the host info and
    starting the Ansible job happen in a background task.

    NOTE(review): this endpoint performs no authentication in this file;
    confirm access is restricted elsewhere (e.g. network policy).
    """
    task = asyncio.create_task(handle_host_online(hostname, sshkeys, branch))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)