212 lines
6.4 KiB
Python
212 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
|
import json
|
|
import logging
|
|
import os
|
|
import subprocess
|
|
import ssl
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import colorlog
|
|
import pika
|
|
import pika.credentials
|
|
import pika.exceptions
|
|
import pydantic
|
|
|
|
|
|
log = logging.getLogger('hostprovision')


# Git repository cloned into the working directory as the configuration
# policy that ansible-playbook is run from.
CONFIGPOLICY = 'https://git.pyrocufflink.net/dustin/configpolicy.git'


# AMQP queue from which host info messages are consumed.
HOST_INFO_QUEUE = os.environ.get('HOST_INFO_QUEUE', 'host-provisioner')
# Seconds to wait for a host info message.  The default is a string, like any
# value read from the environment, so the type does not depend on whether the
# variable is set; it is converted to ``int`` where it is used.
QUEUE_TIMEOUT = os.environ.get('QUEUE_TIMEOUT', '10')
|
|
|
|
|
|
class HostInfo(pydantic.BaseModel):
    """A provisioning request for one host, as decoded from a queue message.

    Fields:
      hostname -- host to provision; used as the host field of the
                  known_hosts entries and as ansible-playbook's ``-l`` limit
      sshkeys  -- SSH host public keys, one per line (may be empty)
      branch   -- branch of the configuration policy to check out
    """

    hostname: str
    sshkeys: str
    branch: str = pydantic.Field(default='master')
|
|
|
|
|
|
def amqp_connect() -> pika.BlockingConnection:
    """Open a blocking connection to the AMQP broker.

    If ``AMQP_URL`` is set, it alone determines the connection parameters.
    Otherwise they are assembled from these optional environment variables:

    * ``AMQP_HOST``, ``AMQP_PORT``, ``AMQP_VIRTUAL_HOST``
    * ``AMQP_USERNAME`` / ``AMQP_PASSWORD`` for PLAIN authentication, or a
      non-empty ``AMQP_EXTERNAL_CREDENTIALS`` for EXTERNAL authentication
      (a username takes precedence)
    * ``AMQP_CA_CERT``, ``AMQP_CLIENT_CERT``, ``AMQP_CLIENT_KEY`` and
      ``AMQP_CLIENT_KEY_PASSWORD`` for TLS; TLS is enabled when any of the
      first three is present

    Raises:
        ValueError: ``AMQP_PORT`` is not an integer.
        OSError: a TLS CA/certificate/key file cannot be loaded
            (``ssl.SSLError`` is an ``OSError`` subclass).
        pika.exceptions.AMQPConnectionError: the broker cannot be reached.
    """
    if 'AMQP_URL' in os.environ:
        params = pika.URLParameters(os.environ['AMQP_URL'])
    else:
        params = pika.ConnectionParameters(**_amqp_params_from_env())
    return pika.BlockingConnection(params)


def _amqp_params_from_env() -> dict[str, object]:
    """Collect ``pika.ConnectionParameters`` keyword arguments from AMQP_* vars."""
    kwargs: dict[str, object] = {}
    if host := os.environ.get('AMQP_HOST'):
        kwargs['host'] = host
    if port := os.environ.get('AMQP_PORT'):
        kwargs['port'] = int(port)
    if vhost := os.environ.get('AMQP_VIRTUAL_HOST'):
        kwargs['virtual_host'] = vhost
    if username := os.environ.get('AMQP_USERNAME'):
        password = os.environ.get('AMQP_PASSWORD', '')
        kwargs['credentials'] = pika.PlainCredentials(username, password)
    elif os.environ.get('AMQP_EXTERNAL_CREDENTIALS'):
        kwargs['credentials'] = pika.credentials.ExternalCredentials()
    if any(
        name in os.environ
        for name in ('AMQP_CA_CERT', 'AMQP_CLIENT_CERT', 'AMQP_CLIENT_KEY')
    ):
        # The broker certificate is checked against server_hostname.  A None
        # name (AMQP_HOST unset) is rejected by the hostname-checking default
        # context, so fall back to the host pika connects to by default.
        server_hostname = kwargs.get(
            'host', pika.ConnectionParameters.DEFAULT_HOST
        )
        kwargs['ssl_options'] = pika.SSLOptions(
            _amqp_ssl_context(), server_hostname
        )
    return kwargs


def _amqp_ssl_context() -> ssl.SSLContext:
    """Build a verifying TLS context from AMQP_CA_CERT and AMQP_CLIENT_* vars.

    Without ``AMQP_CA_CERT`` the system's default trust store is used.  A
    client certificate is loaded only when ``AMQP_CLIENT_CERT`` is set.
    """
    sslctx = ssl.create_default_context(cafile=os.environ.get('AMQP_CA_CERT'))
    if certfile := os.environ.get('AMQP_CLIENT_CERT'):
        keyfile = os.environ.get('AMQP_CLIENT_KEY')
        keypassword = os.environ.get('AMQP_CLIENT_KEY_PASSWORD')
        sslctx.load_cert_chain(certfile, keyfile, keypassword)
    return sslctx
|
|
|
|
|
|
def apply_playbook(*args: str) -> None:
    """Invoke ``ansible-playbook`` with fixed options followed by *args*.

    The fixed options connect as the ``root`` remote user (``-u root``) and
    select ``su`` as the become method, with ``-s /bin/sh`` as its flags.
    stdin is redirected from /dev/null.

    Raises:
        subprocess.CalledProcessError: ansible-playbook exits non-zero.
    """
    fixed_options = (
        '-u', 'root',
        '-e', 'ansible_become_method=su',
        '-e', "ansible_become_flags='-s /bin/sh'",
    )
    command = ['ansible-playbook', *fixed_options, *args]
    log.debug('Running command: %s', command)
    subprocess.run(command, stdin=subprocess.DEVNULL, check=True)
|
|
|
|
|
|
def clone_configpolicy(branch: str) -> None:
    """Shallow-clone the configuration policy into the current directory.

    Clones *branch* of ``CONFIGPOLICY`` with depth 1 into ``.``, then
    initializes its submodules at their remote-tracking branch tips, also
    with depth 1.

    Raises:
        subprocess.CalledProcessError: either git command exits non-zero.
    """
    log.info(
        'Cloning configuration policy from %s (branch %s) into %s',
        CONFIGPOLICY,
        branch,
        os.getcwd(),
    )
    subprocess.run(
        ['git', 'clone', '--depth=1', CONFIGPOLICY, '-b', branch, '.'],
        check=True,
        stdin=subprocess.DEVNULL,
    )

    log.info('Updating Git submodules')
    # Same stdin handling as the clone above, so neither git command (nor
    # anything it spawns) can read from or wait on this process's stdin.
    subprocess.run(
        ['git', 'submodule', 'update', '--remote', '--init', '--depth=1'],
        check=True,
        stdin=subprocess.DEVNULL,
    )
|
|
|
|
|
|
def get_host_info() -> HostInfo | None:
    """Fetch a single host info message from ``HOST_INFO_QUEUE``.

    Waits up to ``QUEUE_TIMEOUT`` seconds (10 if that value is not an
    integer).
    * A message that decodes as JSON and validates as ``HostInfo`` is
      acknowledged and returned.
    * A message that does not is rejected without being requeued, and
      ``None`` is returned.
    * ``None`` is also returned if nothing arrives before the timeout.

    The channel and broker connection are closed before returning.

    Raises:
        OSError, pika.exceptions.AMQPConnectionError: connecting failed.
        pika.exceptions.ChannelClosed: the broker closed the channel (e.g.
            the queue does not exist).
    """
    try:
        timeout = int(QUEUE_TIMEOUT)
    except ValueError:
        log.warning('Invalid QUEUE_TIMEOUT %r; using 10 seconds', QUEUE_TIMEOUT)
        timeout = 10

    log.debug('Connecting to AMQP broker')
    conn = amqp_connect()
    log.info('Successfully connected to AMQP broker')

    # Manage the connection from the moment it exists, so it is closed even
    # if opening the channel or setting QoS fails.
    with conn, conn.channel() as chan:
        # Tell the broker to only send a single message
        chan.basic_qos(prefetch_count=1)

        log.debug(
            'Waiting for host info message (timeout: %d seconds)', timeout
        )
        for method, _properties, body in chan.consume(
            HOST_INFO_QUEUE,
            inactivity_timeout=timeout,
        ):
            if method is None:
                # inactivity_timeout elapsed with no message.
                return None
            log.debug('Received: %r', body)
            try:
                # JSONDecodeError and pydantic's ValidationError are both
                # ValueError subclasses.
                host_info = HostInfo.model_validate(json.loads(body))
            except ValueError as e:
                log.error('Failed to parse host info message: %s', e)
                # Do not requeue: a malformed message will never parse, and
                # requeueing would just redeliver it to the next run.
                chan.basic_reject(method.delivery_tag, requeue=False)
                return None
            chan.basic_ack(method.delivery_tag)
            return host_info
    return None
|
|
|
|
|
|
def write_ssh_keys(hostname: str, ssh_keys: str) -> None:
    """Append SSH host keys for *hostname* to ``~/.ssh/known_hosts``.

    *ssh_keys* holds one key per line, each beginning with its key type
    (e.g. ``ssh-ed25519 AAAA...``).  Blank lines are skipped.  Lines whose
    type does not start with ``ssh-`` or ``ecdsa-`` are logged and skipped.
    ``~/.ssh`` is created with mode 0700 if it does not exist.

    Raises:
        ValueError: *hostname* is empty or contains characters other than
            ASCII letters, digits, ``.``, ``-``, ``_``, ``:``, ``[``, ``]``.
            Because the hostname arrives in a queue message, this keeps it
            from injecting known_hosts syntax -- e.g. ``*`` (match any host),
            a comma (several patterns) or an ``@cert-authority`` marker.
    """
    if not hostname or not all(
        c.isascii() and (c.isalnum() or c in '.-_:[]') for c in hostname
    ):
        raise ValueError(f'Invalid hostname for known_hosts: {hostname!r}')

    known_hosts = Path('~/.ssh/known_hosts').expanduser()
    log.info('Writing SSH host keys for %s to %s', hostname, known_hosts)
    # exist_ok avoids a check-then-create race; 0700 is the conventional
    # mode for ~/.ssh.
    known_hosts.parent.mkdir(mode=0o700, parents=True, exist_ok=True)
    with known_hosts.open('a', encoding='utf-8') as f:
        for line in ssh_keys.splitlines():
            line = line.strip()
            if not line:
                continue
            if not line.startswith(('ssh-', 'ecdsa-')):
                log.warning('Ignoring invalid SSH key: %s', line)
                continue
            f.write(f'{hostname} {line}\n')
|
|
|
|
|
|
def _configure_logging() -> None:
    """Attach a colorized stderr handler to the root logger and set levels.

    This program's logger is set to DEBUG and the root logger to INFO.
    pika is set to WARNING and pika.adapters to CRITICAL.
    """
    log.setLevel(logging.DEBUG)
    logging.getLogger('pika').setLevel(logging.WARNING)
    logging.getLogger('pika.adapters').setLevel(logging.CRITICAL)
    logging.root.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    log_colors = dict(colorlog.default_log_colors)
    log_colors['DEBUG'] = 'blue'
    handler.setFormatter(
        colorlog.ColoredFormatter(
            '%(log_color)s%(levelname)8s%(reset)s'
            ' %(light_black)s%(name)s:%(reset)s %(message)s',
            log_colors=log_colors,
        )
    )
    handler.setLevel(logging.DEBUG)
    logging.root.addHandler(handler)


def main() -> None:
    """Provision the host described by one message from the host info queue.

    Steps:
    1. Fetch a ``HostInfo`` message.
    2. Append its SSH keys (if any) to known_hosts.
    3. Clone the configuration policy into a temporary directory.
    4. Run ``site.yml`` limited to that host.

    Exits with status 1 if no message is obtained or any step fails.
    """
    _configure_logging()

    try:
        host_info = get_host_info()
    except (OSError, pika.exceptions.AMQPConnectionError) as e:
        log.error('Failed to connect to message broker: %s', e)
        raise SystemExit(1)
    except pika.exceptions.ChannelClosed as e:
        log.error(
            'Failed to get host info from message queue: %s (code %s)',
            e.reply_text,
            e.reply_code,
        )
        raise SystemExit(1)
    if not host_info:
        log.error('No host info received from queue')
        raise SystemExit(1)

    log.info('Provisioning host %s', host_info.hostname)
    try:
        if host_info.sshkeys:
            write_ssh_keys(host_info.hostname, host_info.sshkeys)
        previous_cwd = os.getcwd()
        with tempfile.TemporaryDirectory(prefix='host-provision.') as d:
            log.debug('Using working directory %s', d)
            # clone_configpolicy clones into, and apply_playbook runs in,
            # the current directory.
            os.chdir(d)
            try:
                clone_configpolicy(host_info.branch)
                apply_playbook('site.yml', '-l', host_info.hostname)
            finally:
                # Leave the directory before TemporaryDirectory removes it,
                # so the process is never left in a deleted cwd.
                os.chdir(previous_cwd)
    except Exception as e:
        log.error('Provisioning failed: %s', e)
        raise SystemExit(1)
    else:
        log.info('Successfully provisioned host %s', host_info.hostname)
|
|
|
|
|
|
# Run the provisioning workflow only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|