#!/usr/bin/env python3
"""Provision a single host described by a message on an AMQP queue.

The script waits for one JSON message on the ``HOST_INFO_QUEUE`` queue
describing a host (its name, SSH host keys and the configuration-policy
branch to use).  It then appends the host's SSH keys to the current user's
``~/.ssh/known_hosts``, clones the configuration policy repository into a
temporary directory and runs ``ansible-playbook site.yml`` limited to that
host.

Environment variables:

``HOST_INFO_QUEUE``
    Queue to consume from (default ``host-provisioner``).
``QUEUE_TIMEOUT``
    Whole seconds to wait for a message before giving up (default 10).
``AMQP_URL``
    Full AMQP connection URL.  When set, the ``AMQP_*`` variables below
    are ignored.
``AMQP_HOST``, ``AMQP_PORT``, ``AMQP_VIRTUAL_HOST``
    Broker address and virtual host.
``AMQP_USERNAME``, ``AMQP_PASSWORD``
    Username/password credentials.
``AMQP_EXTERNAL_CREDENTIALS``
    When non-empty and no username is given, authenticate with the
    EXTERNAL SASL mechanism instead.
``AMQP_CA_CERT``, ``AMQP_CLIENT_CERT``, ``AMQP_CLIENT_KEY``,
``AMQP_CLIENT_KEY_PASSWORD``
    TLS settings; setting any of the first three enables TLS.
"""

import json
import logging
import os
import ssl
import subprocess
import tempfile
from pathlib import Path

import colorlog
import pika
import pika.credentials
import pika.exceptions
import pydantic


log = logging.getLogger('hostprovision')


CONFIGPOLICY = 'https://git.pyrocufflink.net/dustin/configpolicy.git'
HOST_INFO_QUEUE = os.environ.get('HOST_INFO_QUEUE', 'host-provisioner')
QUEUE_TIMEOUT = os.environ.get('QUEUE_TIMEOUT', 10)

# Plain host names only: dot-separated labels of letters, digits, hyphens
# and underscores (not starting with a hyphen).  The host name comes from
# the queue and is written verbatim into known_hosts and passed to
# ``ansible-playbook --limit``; characters such as ``*``, ``?``, ``,``,
# ``:``, ``!``, ``&``, ``~``, ``@`` or whitespace are pattern syntax in
# one or both of those places and could make the entry match other hosts.
HOSTNAME_PATTERN = r'^[A-Za-z0-9_][A-Za-z0-9_-]*(\.[A-Za-z0-9_-]+)*$'


class HostInfo(pydantic.BaseModel):
    """Host details carried by a message on the host-info queue."""

    # Name of the host to provision; used as the known_hosts host field and
    # as the ansible-playbook limit.
    hostname: str = pydantic.Field(pattern=HOSTNAME_PATTERN)
    # SSH host public keys, one per line (may be empty).
    sshkeys: str
    # Branch of the configuration policy repository to check out.
    branch: str = 'master'


def amqp_connect() -> pika.BlockingConnection:
    """Open a blocking connection to the AMQP broker.

    Connection settings come from ``AMQP_URL`` if it is set, otherwise from
    the individual ``AMQP_*`` environment variables (see module docstring).

    Raises:
        Whatever pika raises when the connection cannot be established
        (e.g. ``pika.exceptions.AMQPConnectionError``), and ``OSError`` /
        ``ssl.SSLError`` if certificate files cannot be loaded.
    """
    if 'AMQP_URL' in os.environ:
        params = pika.URLParameters(os.environ['AMQP_URL'])
    else:
        kwargs = {}
        if host := os.environ.get('AMQP_HOST'):
            kwargs['host'] = host
        if port := os.environ.get('AMQP_PORT'):
            kwargs['port'] = int(port)
        if vhost := os.environ.get('AMQP_VIRTUAL_HOST'):
            kwargs['virtual_host'] = vhost
        if username := os.environ.get('AMQP_USERNAME'):
            password = os.environ.get('AMQP_PASSWORD', '')
            kwargs['credentials'] = pika.PlainCredentials(username, password)
        elif os.environ.get('AMQP_EXTERNAL_CREDENTIALS'):
            kwargs['credentials'] = pika.credentials.ExternalCredentials()
        if (
            'AMQP_CA_CERT' in os.environ
            or 'AMQP_CLIENT_CERT' in os.environ
            or 'AMQP_CLIENT_KEY' in os.environ
        ):
            sslctx = ssl.create_default_context(
                cafile=os.environ.get('AMQP_CA_CERT')
            )
            if certfile := os.environ.get('AMQP_CLIENT_CERT'):
                keyfile = os.environ.get('AMQP_CLIENT_KEY')
                keypassword = os.environ.get('AMQP_CLIENT_KEY_PASSWORD')
                sslctx.load_cert_chain(certfile, keyfile, keypassword)
            # The default context verifies the server's host name, so it
            # needs one even when AMQP_HOST is unset and pika falls back
            # to its default host.
            server_hostname = kwargs.get(
                'host', pika.ConnectionParameters.DEFAULT_HOST
            )
            kwargs['ssl_options'] = pika.SSLOptions(sslctx, server_hostname)
        params = pika.ConnectionParameters(**kwargs)
    return pika.BlockingConnection(params)


def apply_playbook(*args: str) -> None:
    """Run ``ansible-playbook -u root`` with *args* appended.

    Raises:
        subprocess.CalledProcessError: if ansible-playbook exits non-zero.
    """
    cmd = ['ansible-playbook', '-u', 'root', *args]
    log.debug('Running command: %s', cmd)
    subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL)


def clone_configpolicy(branch: str) -> None:
    """Shallow-clone the configuration policy into the current directory.

    Checks out *branch* and then initializes/updates its Git submodules
    (also shallowly, tracking their remote branches).

    Raises:
        subprocess.CalledProcessError: if either git command fails.
    """
    cmd = [
        'git',
        'clone',
        '--depth=1',
        CONFIGPOLICY,
        '-b',
        branch,
        '.',
    ]
    log.info(
        'Cloning configuration policy from %s (branch %s) into %s',
        CONFIGPOLICY,
        branch,
        os.getcwd(),
    )
    subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL)
    cmd = [
        'git',
        'submodule',
        'update',
        '--remote',
        '--init',
        '--depth=1',
    ]
    log.info('Updating Git submodules')
    # Like the clone above, never let git wait for interactive input
    # (e.g. a credential prompt for a submodule URL).
    subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL)


def get_host_info() -> HostInfo | None:
    """Wait for and consume a single host-info message from the queue.

    Returns:
        The parsed :class:`HostInfo`, or ``None`` if no message arrived
        within ``QUEUE_TIMEOUT`` seconds or the message was invalid.
        Invalid messages are rejected *without* requeueing so that they are
        not redelivered to (and fail) every subsequent run.

    Raises:
        Connection and channel errors from pika (``main`` handles
        ``OSError``, ``AMQPConnectionError`` and ``ChannelClosed``).
    """
    try:
        timeout = int(QUEUE_TIMEOUT)
    except ValueError:
        log.warning(
            'Invalid QUEUE_TIMEOUT value %r, using 10 seconds', QUEUE_TIMEOUT
        )
        timeout = 10
    log.debug('Connecting to AMQP broker')
    conn = amqp_connect()
    log.info('Successfully connected to AMQP broker')
    # Nested context managers: the channel is closed before the connection,
    # and the connection is still closed if opening the channel fails.
    with conn, conn.channel() as chan:
        # Tell the broker to only send a single message
        chan.basic_qos(prefetch_count=1)
        log.debug(
            'Waiting for host info message (timeout: %d seconds)', timeout
        )
        for method, _properties, body in chan.consume(
            HOST_INFO_QUEUE,
            inactivity_timeout=timeout,
        ):
            if method is None:
                # inactivity_timeout elapsed without a delivery
                break
            log.debug('Received: %r', body)
            try:
                data = json.loads(body)
                host_info = HostInfo.model_validate(data)
            except ValueError as e:
                # Covers JSON decode errors and pydantic ValidationError.
                log.error('Failed to parse host info message: %s', e)
                # Don't requeue: a malformed message will never parse, and
                # requeueing would put it back at the head of the queue.
                chan.basic_reject(method.delivery_tag, requeue=False)
                return None
            chan.basic_ack(method.delivery_tag)
            return host_info
    return None


def write_ssh_keys(hostname: str, ssh_keys: str) -> None:
    """Append *hostname*'s SSH host keys to ``~/.ssh/known_hosts``.

    *ssh_keys* holds one public key per line.  Blank lines are skipped and
    lines not starting with ``ssh-`` or ``ecdsa-`` are ignored with a
    warning.  Each accepted key is written as ``<hostname> <key>``.
    """
    known_hosts = Path('~/.ssh/known_hosts').expanduser()
    log.info('Writing SSH host keys for %s to %s', hostname, known_hosts)
    # exist_ok avoids a check-then-create race; 0700 keeps ~/.ssh private
    # when this call creates it.
    known_hosts.parent.mkdir(mode=0o700, parents=True, exist_ok=True)
    with known_hosts.open('a', encoding='utf-8') as f:
        for line in ssh_keys.splitlines():
            line = line.strip()
            if not line:
                continue
            if not line.startswith(('ssh-', 'ecdsa-')):
                log.warning('Ignoring invalid SSH key: %s', line)
                continue
            f.write(f'{hostname} {line}\n')


def main() -> None:
    """Receive host info from the queue and provision that host.

    Exits with status 1 if the broker cannot be reached, no valid host-info
    message is received, or provisioning fails.
    """
    log.setLevel(logging.DEBUG)
    logging.getLogger('pika').setLevel(logging.WARNING)
    logging.getLogger('pika.adapters').setLevel(logging.CRITICAL)
    logging.root.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    log_colors = dict(colorlog.default_log_colors)
    log_colors['DEBUG'] = 'blue'
    handler.setFormatter(
        colorlog.ColoredFormatter(
            '%(log_color)s%(levelname)8s%(reset)s'
            ' %(light_black)s%(name)s:%(reset)s %(message)s',
            log_colors=log_colors,
        )
    )
    handler.setLevel(logging.DEBUG)
    logging.root.addHandler(handler)

    try:
        host_info = get_host_info()
    except (OSError, pika.exceptions.AMQPConnectionError) as e:
        log.error('Failed to connect to message broker: %s', e)
        raise SystemExit(1)
    except pika.exceptions.ChannelClosed as e:
        log.error(
            'Failed to get host info from message queue: %s (code %s)',
            e.reply_text,
            e.reply_code,
        )
        raise SystemExit(1)

    if host_info is None:
        log.error('No host info received from queue')
        raise SystemExit(1)

    log.info('Provisioning host %s', host_info.hostname)
    try:
        if host_info.sshkeys:
            write_ssh_keys(host_info.hostname, host_info.sshkeys)
        with tempfile.TemporaryDirectory(prefix='host-provision.') as d:
            log.debug('Using working directory %s', d)
            os.chdir(d)
            clone_configpolicy(host_info.branch)
            apply_playbook('site.yml', '-l', host_info.hostname)
    except Exception as e:
        # Top-level boundary: report any provisioning failure and exit.
        log.error('Provisioning failed: %s', e)
        raise SystemExit(1)
    else:
        log.info('Successfully provisioned host %s', host_info.hostname)


if __name__ == '__main__':
    main()