host-provisioner/host_provisioner.py

212 lines
6.4 KiB
Python

#!/usr/bin/env python3
import json
import logging
import os
import subprocess
import ssl
import tempfile
from pathlib import Path
import colorlog
import pika
import pika.credentials
import pika.exceptions
import pydantic
# Module logger; main() sets its level and attaches the output handler.
log = logging.getLogger('hostprovision')

# Git repository holding the configuration policy that main() clones and
# runs site.yml from.
CONFIGPOLICY = 'https://git.pyrocufflink.net/dustin/configpolicy.git'

# AMQP queue from which the host info message is consumed.
HOST_INFO_QUEUE = os.getenv('HOST_INFO_QUEUE', 'host-provisioner')

# Seconds to wait for a host info message.  This is a str when taken from
# the environment and the int 10 otherwise; get_host_info() coerces it with
# int() and falls back to 10 if that fails.
QUEUE_TIMEOUT = os.getenv('QUEUE_TIMEOUT', 10)
class HostInfo(pydantic.BaseModel):
    """Host provisioning request carried in a queue message's JSON body.

    Built by ``get_host_info()`` via ``HostInfo.model_validate``.
    """

    # Name of the host to provision; written to known_hosts and passed to
    # ansible-playbook as the ``-l`` limit.
    hostname: str
    # SSH host public keys, one per line, appended to ~/.ssh/known_hosts.
    # May be empty, in which case main() writes no keys.
    sshkeys: str
    # Branch of the configuration policy repository to clone.
    branch: str = 'master'
def amqp_connect() -> pika.BlockingConnection:
    """Open a blocking connection to the AMQP broker.

    If ``AMQP_URL`` is set it is used as-is (via ``pika.URLParameters``).
    Otherwise the connection parameters are assembled from these optional
    environment variables:

    * ``AMQP_HOST``, ``AMQP_PORT``, ``AMQP_VIRTUAL_HOST``
    * ``AMQP_USERNAME`` / ``AMQP_PASSWORD`` for PLAIN credentials, or, if
      no username is given, ``AMQP_EXTERNAL_CREDENTIALS`` (any non-empty
      value) for EXTERNAL credentials
    * ``AMQP_CA_CERT``, ``AMQP_CLIENT_CERT``, ``AMQP_CLIENT_KEY``,
      ``AMQP_CLIENT_KEY_PASSWORD``; setting any of the first three
      enables TLS

    Raises:
        ValueError: ``AMQP_PORT`` is not an integer.
        OSError: a CA/certificate/key file cannot be loaded
            (``ssl.SSLError`` is a subclass).
        pika.exceptions.AMQPConnectionError: the connection could not be
            established.
    """
    if 'AMQP_URL' in os.environ:
        # A full URL takes precedence over the individual AMQP_* settings.
        params = pika.URLParameters(os.environ['AMQP_URL'])
    else:
        kwargs = {}
        if host := os.environ.get('AMQP_HOST'):
            kwargs['host'] = host
        if port := os.environ.get('AMQP_PORT'):
            kwargs['port'] = int(port)
        if vhost := os.environ.get('AMQP_VIRTUAL_HOST'):
            kwargs['virtual_host'] = vhost
        if username := os.environ.get('AMQP_USERNAME'):
            password = os.environ.get('AMQP_PASSWORD', '')
            kwargs['credentials'] = pika.PlainCredentials(username, password)
        elif os.environ.get('AMQP_EXTERNAL_CREDENTIALS'):
            kwargs['credentials'] = pika.credentials.ExternalCredentials()
        if (
            'AMQP_CA_CERT' in os.environ
            or 'AMQP_CLIENT_CERT' in os.environ
            or 'AMQP_CLIENT_KEY' in os.environ
        ):
            # With cafile=None the context loads the default trust store.
            sslctx = ssl.create_default_context(
                cafile=os.environ.get('AMQP_CA_CERT')
            )
            if certfile := os.environ.get('AMQP_CLIENT_CERT'):
                # keyfile=None means the private key is read from certfile.
                keyfile = os.environ.get('AMQP_CLIENT_KEY')
                keypassword = os.environ.get('AMQP_CLIENT_KEY_PASSWORD')
                sslctx.load_cert_chain(certfile, keyfile, keypassword)
            # create_default_context() enables hostname checking, which
            # refuses to wrap a socket without a server name.  When
            # AMQP_HOST is unset, verify against 'localhost' (the host pika
            # connects to by default) rather than passing None.
            server_hostname = kwargs.get('host', 'localhost')
            kwargs['ssl_options'] = pika.SSLOptions(sslctx, server_hostname)
        params = pika.ConnectionParameters(**kwargs)
    return pika.BlockingConnection(params)
def apply_playbook(*args: str) -> None:
    """Run ``ansible-playbook`` with ``args`` appended to the fixed options.

    The fixed options set the remote user to ``root`` and configure
    privilege escalation through ``su`` with ``/bin/sh`` as the shell.
    The command's stdin is ``/dev/null``.

    Raises:
        subprocess.CalledProcessError: ansible-playbook exited non-zero.
    """
    fixed_options = (
        '-u', 'root',
        '-e', 'ansible_become_method=su',
        '-e', "ansible_become_flags='-s /bin/sh'",
    )
    command = ['ansible-playbook', *fixed_options, *args]
    log.debug('Running command: %s', command)
    subprocess.run(command, check=True, stdin=subprocess.DEVNULL)
def clone_configpolicy(branch: str) -> None:
    """Shallow-clone the configuration policy into the current directory.

    Clones ``branch`` of ``CONFIGPOLICY`` with depth 1 into ``.``, then
    initializes the submodules (also depth 1), updating each to its remote
    tracking branch (``--remote``).

    Neither git command can read the caller's stdin.

    Raises:
        subprocess.CalledProcessError: either git command exited non-zero.
    """
    clone_cmd = [
        'git',
        'clone',
        '--depth=1',
        CONFIGPOLICY,
        '-b',
        branch,
        '.',
    ]
    log.info(
        'Cloning configuration policy from %s (branch %s) into %s',
        CONFIGPOLICY,
        branch,
        os.getcwd(),
    )
    subprocess.run(clone_cmd, check=True, stdin=subprocess.DEVNULL)
    submodule_cmd = [
        'git',
        'submodule',
        'update',
        '--remote',
        '--init',
        '--depth=1',
    ]
    log.info('Updating Git submodules')
    # Like the clone above, keep git from reading (and waiting on) stdin.
    subprocess.run(submodule_cmd, check=True, stdin=subprocess.DEVNULL)
def get_host_info() -> HostInfo | None:
    """Consume a single host info message from ``HOST_INFO_QUEUE``.

    Connects to the broker and waits up to ``QUEUE_TIMEOUT`` seconds (10 if
    that value is not an integer) for a message. The message's JSON body is
    parsed into a ``HostInfo``.

    A message that parses is acknowledged. A message that does not parse
    is rejected *without* requeueing. If it were requeued it would go back
    into the queue, and every later run would receive it and fail the same
    way. The broker dead-letters it if the queue has a dead-letter exchange;
    otherwise it is dropped.

    Returns:
        The parsed host info, or ``None`` if no message arrived before the
        timeout or the message could not be parsed.

    Raises:
        OSError, pika.exceptions.AMQPConnectionError: connecting failed.
        pika.exceptions.ChannelClosed: the broker closed the channel (for
            example, if the queue does not exist).
    """
    log.debug('Connecting to AMQP broker')
    conn = amqp_connect()
    log.info('Successfully connected to AMQP broker')
    # Enter the connection's context before anything else can fail, so the
    # connection is closed even if opening the channel or basic_qos raises.
    with conn:
        chan = conn.channel()
        with chan:
            # Tell the broker to only send a single message
            chan.basic_qos(prefetch_count=1)
            try:
                timeout = int(QUEUE_TIMEOUT)
            except ValueError:
                timeout = 10
            log.debug(
                'Waiting for host info message (timeout: %d seconds)',
                timeout,
            )
            for method, _properties, body in chan.consume(
                HOST_INFO_QUEUE,
                inactivity_timeout=timeout,
            ):
                if method is None:
                    # inactivity_timeout elapsed without a message.
                    break
                log.debug('Received: %r', body)
                try:
                    data = json.loads(body)
                    host_info = HostInfo.model_validate(data)
                except ValueError as e:
                    # Covers JSON decode errors and pydantic validation
                    # errors (both are ValueError subclasses).
                    log.error('Failed to parse host info message: %s', e)
                    chan.basic_reject(method.delivery_tag, requeue=False)
                    return None
                else:
                    chan.basic_ack(method.delivery_tag)
                    return host_info
    return None
def write_ssh_keys(hostname: str, ssh_keys: str) -> None:
    """Append SSH host keys for ``hostname`` to ``~/.ssh/known_hosts``.

    ``ssh_keys`` holds one key per line. Blank lines are skipped. Lines
    that do not start with ``ssh-`` or ``ecdsa-`` are logged and skipped.
    Each accepted line is written as ``<hostname> <line>``.

    ``~/.ssh`` (and any missing parents) is created if needed; the
    ``.ssh`` directory itself is created with mode 0700.

    Raises:
        ValueError: ``hostname`` is empty, contains whitespace, or contains
            characters that known_hosts treats specially. These are ``,``
            (name list), ``*``, ``?`` and ``!`` (patterns), and a leading
            ``@`` or ``|`` (markers / hashed names). Because the name comes
            from a queue message, such a value could make these keys match
            other hosts or corrupt the file.
        OSError: the directory or file cannot be created or written.
    """
    if (
        not hostname
        or hostname.startswith(('@', '|'))
        or any(ch.isspace() or ch in ',*?!' for ch in hostname)
    ):
        raise ValueError(f'Invalid host name for known_hosts: {hostname!r}')
    known_hosts = Path('~/.ssh/known_hosts').expanduser()
    log.info('Writing SSH host keys for %s to %s', hostname, known_hosts)
    # exist_ok avoids the check-then-create race. 0700 is the conventional
    # mode for ~/.ssh (it only applies when the directory is created here).
    known_hosts.parent.mkdir(mode=0o700, parents=True, exist_ok=True)
    with known_hosts.open('a', encoding='utf-8') as f:
        for line in ssh_keys.splitlines():
            line = line.strip()
            if not line:
                continue
            if not line.startswith(('ssh-', 'ecdsa-')):
                log.warning('Ignoring invalid SSH key: %s', line)
                continue
            f.write(f'{hostname} {line}\n')
def _configure_logging() -> None:
    """Set logger levels and attach a colorized handler to the root logger."""
    log.setLevel(logging.DEBUG)
    # Limit pika's output: warnings and above in general, and only critical
    # messages from its adapters.
    logging.getLogger('pika').setLevel(logging.WARNING)
    logging.getLogger('pika.adapters').setLevel(logging.CRITICAL)
    logging.root.setLevel(logging.INFO)

    colors = {**colorlog.default_log_colors, 'DEBUG': 'blue'}
    formatter = colorlog.ColoredFormatter(
        '%(log_color)s%(levelname)8s%(reset)s'
        ' %(light_black)s%(name)s:%(reset)s %(message)s',
        log_colors=colors,
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    stream_handler.setLevel(logging.DEBUG)
    logging.root.addHandler(stream_handler)


def main():
    """Provision the host described by one message from the AMQP queue.

    Exits with status 1 (via ``SystemExit``) in any of these cases:
    the broker cannot be reached, the channel is closed, no usable message
    is received, or any provisioning step raises.
    """
    _configure_logging()

    try:
        host_info = get_host_info()
    except (OSError, pika.exceptions.AMQPConnectionError) as e:
        log.error('Failed to connect to message broker: %s', e)
        raise SystemExit(1)
    except pika.exceptions.ChannelClosed as e:
        log.error(
            'Failed to get host info from message queue: %s (code %s)',
            e.reply_text,
            e.reply_code,
        )
        raise SystemExit(1)

    if not host_info:
        log.error('No host info received from queue')
        raise SystemExit(1)

    hostname = host_info.hostname
    log.info('Provisioning host %s', hostname)
    try:
        if host_info.sshkeys:
            write_ssh_keys(hostname, host_info.sshkeys)
        with tempfile.TemporaryDirectory(prefix='host-provision.') as workdir:
            log.debug('Using working directory %s', workdir)
            # clone_configpolicy() clones into '.', so move into the
            # scratch directory first.
            os.chdir(workdir)
            clone_configpolicy(host_info.branch)
            apply_playbook('site.yml', '-l', hostname)
    except Exception as exc:
        log.error('Provisioning failed: %s', exc)
        raise SystemExit(1)

    log.info('Successfully provisioned host %s', hostname)
if __name__ == '__main__':
    # Entry point when the file is executed as a script.
    main()