#!/usr/bin/env python3
import argparse
import collections
import itertools
import logging
import os
import random
import re
import shutil
import string
import subprocess
import sys
import tempfile


log = logging.getLogger('stage3img')

# Partition spec for the LVM physical volume; ``type`` is rendered as the
# sgdisk GPT type code "8E00" by Image.partition().
LVM_PV_PART = {
    'type': 0x8e00,
}

# Size-suffix multipliers: bare letters and "*ib" suffixes are binary
# (powers of 1024), "*b" suffixes are decimal (powers of 1000).
UNITS = {
    'k': 2 ** 10,
    'kb': 10 ** 3,
    'kib': 2 ** 10,
    'm': 2 ** 20,
    'mb': 10 ** 6,
    'mib': 2 ** 20,
    'g': 2 ** 30,
    'gb': 10 ** 9,
    'gib': 2 ** 30,
    't': 2 ** 40,
    'tb': 10 ** 12,
    'tib': 2 ** 40,
    'p': 2 ** 50,
    'pb': 10 ** 15,
    'pib': 2 ** 50,
    'e': 2 ** 60,
    'eb': 10 ** 18,
    'eib': 2 ** 60,
    'z': 2 ** 70,
    'zb': 10 ** 21,
    'zib': 2 ** 70,
    'y': 2 ** 80,
    'yb': 10 ** 24,
    'yib': 2 ** 80,
}

# Matches a size such as "512", "1.5 G" or "20GiB": a decimal number,
# optional whitespace, then an optional unit word.
# NOTE(review): the group names were missing from the recovered text
# ("(?P[0-9]..." does not compile).  "value" and "unit" are placeholders --
# confirm they match whatever code reads SIZE_RE.
SIZE_RE = re.compile(
    r'^(?P<value>[0-9]+(?:\.[0-9]+)?)\s*'
    r'(?P<unit>\w+)?\s*$'
)

# Public keys tried, in order, when --inject-ssh-key is given without a
# filename; only the first one that exists is used.
TRY_SSH_KEYS = [
    os.path.expanduser('~/.ssh/id_ed25519.pub'),
    os.path.expanduser('~/.ssh/id_rsa.pub'),
    os.path.expanduser('~/.ssh/id_ecdsa.pub'),
]


class CommandError(Exception):
    # NOTE(review): not raised by any code remaining in this file; presumably
    # used by run_cmd, whose definition is missing (see parse_volumes).
    pass


class MaxLevelFilter(logging.Filter):
    '''Filter log records above a particular level

    :param max_level: The maximum level of log message to include
    '''

    def __init__(self, max_level=logging.WARNING, *args, **kwargs):
        logging.Filter.__init__(self, *args, **kwargs)
        self.max_level = max_level

    def filter(self, record):
        return record.levelno <= self.max_level


class Image(object):
    '''A disk image file being populated from a stage tarball

    All work happens on ``<filename>.tmp``.  When used as a context manager,
    leaving the ``with`` block unmounts, deactivates the volume group and
    detaches the loop device; the temporary file is then renamed to
    *filename* on success, or both files are removed on error.
    '''

    # (path inside the image, mount source, extra mount(8) arguments...)
    EXTRA_MOUNTS = [
        ('/tmp', 'tmpfs', '-t', 'tmpfs'),
        ('/run', 'tmpfs', '-t', 'tmpfs'),
        ('/dev', '/dev', '--rbind'),
        ('/proc', 'proc', '-t', 'proc'),
        ('/sys', '/sys', '--rbind'),
    ]

    # Environment passed to scripts run by run_script().
    SCRIPT_ENV = {
        'PATH': os.pathsep.join((
            '/usr/local/sbin',
            '/usr/sbin',
            '/sbin',
            '/usr/local/bin',
            '/usr/bin',
            '/bin',
        )),
    }

    # Used for any partition/volume spec without an explicit ``fstype``.
    default_fstype = 'ext4'

    def __init__(self, filename, fd=None):
        self.fd = fd
        self.filename = filename
        self.tempname = self.filename + '.tmp'
        self.loopdev = None
        self.partitions = None
        self.vgname = None
        self.volumes = None
        self.mountpoint = None
        self._extra_mounted = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Tear down in reverse order of setup: filesystems first, then the
        # volume group living on the loop device, then the loop device.
        self.unmount()
        self.deactivate_vg()
        self.detach_loopback()
        if self.fd is not None:
            os.close(self.fd)
            self.fd = None
        if exc_type:
            log.error('Error occurred, removing image')
            try:
                os.unlink(self.tempname)
            except OSError:
                log.exception('Failed to remove temporary image file:')
            try:
                os.unlink(self.filename)
            except OSError:
                log.exception('Failed to remove image file:')
        else:
            os.rename(self.tempname, self.filename)
        # Returning None lets any exception propagate to the caller.

    @property
    def filesystems(self):
        '''Yield ``(mountpoint, device, fstype, fsopts)`` for each mountable
        partition and then each mountable logical volume

        Entries without a ``mountpoint`` are skipped.  *fstype* is None when
        the spec does not name one; *fsopts* defaults to ``'defaults'``.
        '''
        # Partition numbers follow list position, exactly as partition()
        # assigns them to sgdisk -- including entries that are not mounted.
        for partnum, part in enumerate(self.partitions or (), 1):
            if 'mountpoint' not in part:
                continue
            yield (
                part['mountpoint'],
                '{}p{}'.format(self.loopdev, partnum),
                part.get('fstype'),
                part.get('fsopts', 'defaults'),
            )
        for vol in self.volumes or ():
            if 'mountpoint' not in vol:
                continue
            yield (
                vol['mountpoint'],
                '/dev/{}/{}'.format(self.vgname, vol['name']),
                vol.get('fstype'),
                vol.get('fsopts', 'defaults'),
            )

    @classmethod
    def create(cls, filename):
        '''Exclusively create *filename* and return an Image for it

        :raises OSError: if *filename* already exists or cannot be created
        '''
        fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        return cls(filename, fd)

    def resize(self, size):
        '''Create/truncate the temporary image file to *size* bytes'''
        with open(self.tempname, 'wb') as f:
            f.truncate(size)

    def attach_loopback(self):
        self.loopdev = run_cmd('losetup', '-f', '--show', self.tempname)
        log.info('Connected {} to {}'.format(self.tempname, self.loopdev))

    def detach_loopback(self):
        if not self.loopdev:
            return
        log.info('Detaching {}'.format(self.loopdev))
        run_cmd('losetup', '-d', self.loopdev)
        self.loopdev = None

    def partition(self, partitions):
        '''Write a fresh GPT to the loop device with one partition per spec

        Each spec may carry ``size`` (sgdisk end value, empty = rest of
        disk), an integer GPT ``type`` and a ``name``.
        '''
        if not self.loopdev:
            self.attach_loopback()
        # Materialize so the specs can be iterated again by filesystems and
        # counted by setup_lvm (callers may pass a dict view or generator).
        partitions = list(partitions)
        log.info('Partitioning {} with GPT'.format(self.loopdev))
        cmd = ['sgdisk', '-a', '4096', '-Z', '-g']
        for partnum, part in enumerate(partitions, 1):
            cmd += ('-n', '{}::{}'.format(partnum, part.get('size', '')))
            if 'type' in part:
                cmd += ('-t', '{}:{:X}'.format(partnum, part['type']))
            if 'name' in part:
                cmd += ('-c', '{}:{}'.format(partnum, part['name']))
        cmd.append(self.loopdev)
        run_cmd_logged(*cmd)
        self.partitions = partitions

    def setup_lvm(self, vgname, volumes):
        '''Create volume group *vgname* and one logical volume per spec

        The physical volume is the last partition when the image has been
        partitioned, otherwise the whole loop device.
        '''
        if not self.loopdev:
            self.attach_loopback()
        pvscan(self.loopdev)
        if self.partitions:
            pv = '{}p{}'.format(self.loopdev, len(self.partitions))
        else:
            pv = self.loopdev
        pvcreate(pv)
        vgcreate(vgname, pv)
        self.vgname = vgname
        for vol in volumes:
            lvcreate(vgname, vol['name'], vol.get('size'))
        self.volumes = volumes

    def deactivate_vg(self):
        if not self.vgname:
            return
        run_cmd_logged('vgchange', '-an', self.vgname)
        self.vgname = None

    def make_filesystems(self):
        for mountpoint, dev, fstype, fsopts in self.filesystems:
            if not fstype:
                fstype = self.default_fstype
            log.info('Creating {} filesystem on {}'.format(fstype, dev))
            run_cmd('mkfs.{}'.format(fstype), dev)

    def mount(self):
        '''Mount every filesystem under a fresh temporary directory

        Sorting by mountpoint guarantees a parent ("/") is mounted before
        its children ("/boot").
        '''
        self.mountpoint = tempfile.mkdtemp()
        for mountpoint, dev, fstype, fsopts in sorted(self.filesystems):
            path = os.path.join(self.mountpoint, mountpoint[1:])
            if not os.path.isdir(path):
                os.makedirs(path)
            log.info('Mounting {} on {}'.format(dev, path))
            run_cmd_logged('mount', dev, path)

    def mount_extra(self):
        '''Mount EXTRA_MOUNTS inside the image (once per mount())'''
        if self._extra_mounted:
            return
        if not self.mountpoint:
            self.mount()
        for item in self.EXTRA_MOUNTS:
            mountpoint, dev, args = item[0], item[1], item[2:]
            path = os.path.join(self.mountpoint, mountpoint[1:])
            if not os.path.isdir(path):
                os.makedirs(path)
            run_cmd_logged('mount', dev, path, *args)
        # Without this, every run_script() call would stack another set of
        # tmpfs/proc/bind mounts on top of the previous ones.
        self._extra_mounted = True

    def unmount(self):
        if not self.mountpoint:
            return
        log.info('Unmounting {}'.format(self.mountpoint))
        # -R also removes the extra mounts nested under the mountpoint.
        run_cmd_logged('umount', '-R', self.mountpoint)
        os.rmdir(self.mountpoint)
        self.mountpoint = None
        self._extra_mounted = False

    def extract(self, filename):
        '''Unpack tarball *filename* onto the image, mounting it if needed'''
        if not self.mountpoint:
            self.mount()
        log.info('Extracting {} to {}'.format(filename, self.mountpoint))
        run_cmd_logged('tar', '-xha', '--numeric-owner', '-f', filename,
                       '-C', self.mountpoint)

    def run_script(self, script, chroot=True):
        '''Copy *script* into the image's /tmp and execute it

        With *chroot* the script runs chrooted into the image; otherwise it
        runs on the host with cwd at the image root and ``IMAGE_ROOT`` set.
        The copy is removed afterwards whether or not the script succeeds.
        '''
        if not self._extra_mounted:
            self.mount_extra()
        name = os.path.basename(script)
        dest = os.path.join(self.mountpoint, 'tmp', name)
        log.debug('Copying {} to {}'.format(script, dest))
        shutil.copy(script, dest)
        os.chmod(dest, 0o0755)
        kwargs = {
            'env': self.SCRIPT_ENV.copy(),
        }
        if chroot:
            cmd = ('chroot', self.mountpoint, os.path.join('/tmp', name))
        else:
            cmd = (dest,)
            kwargs['env']['IMAGE_ROOT'] = self.mountpoint
            kwargs['cwd'] = self.mountpoint
        log.info('Running script: {}'.format(name))
        try:
            run_cmd_logged(*cmd, **kwargs)
        finally:
            os.unlink(dest)

    def write_fstab(self, tmpfstmp=True):
        '''Write /etc/fstab inside the image

        Root gets fsck pass 1, other filesystems pass 2.  With *tmpfstmp* a
        tmpfs entry for /tmp is appended.
        '''
        tmpl = '{dev}\t{mountpoint}\t{fstype}\t{fsopts}\t0 {passno}\n'
        log.info('Writing fstab')
        with open(os.path.join(self.mountpoint, 'etc/fstab'), 'w') as fstab:
            for mountpoint, dev, fstype, fsopts in sorted(self.filesystems):
                line = tmpl.format(
                    dev=dev,
                    mountpoint=mountpoint,
                    fstype=fstype or self.default_fstype,
                    fsopts=fsopts,
                    passno=1 if mountpoint == '/' else 2,
                )
                log.debug(line.rstrip('\n'))
                fstab.write(line)
            if tmpfstmp:
                line = tmpl.format(
                    dev='tmpfs',
                    mountpoint='/tmp',
                    fstype='tmpfs',
                    fsopts='defaults',
                    passno=0,
                )
                log.debug(line.rstrip('\n'))
                fstab.write(line)


def gen_vgname(length=8):
    '''Return a random volume group name of *length* characters

    The first and last characters are lowercase letters; the middle ones
    are ASCII letters or digits.
    '''
    name = random.choice(string.ascii_lowercase)
    while len(name) < length - 1:
        name += random.choice(string.ascii_letters + string.digits)
    name += random.choice(string.ascii_lowercase)
    return name


def inject_ssh_keys(root, ssh_keys):
    '''Append public keys to ``<root>/root/.ssh/authorized_keys``

    :param root: Image root directory
    :param ssh_keys: Paths to public key files.  A ``None`` entry means
        "use the first existing file from TRY_SSH_KEYS".  Duplicates are
        written once, in the order given.
    '''
    ssh_dir = os.path.join(root, 'root', '.ssh')
    if not os.path.isdir(ssh_dir):
        os.makedirs(ssh_dir)

    keyfiles = []
    want_default = False
    for keyfile in ssh_keys:
        if keyfile is None:
            want_default = True
        elif keyfile not in keyfiles:
            keyfiles.append(keyfile)
    if want_default:
        for path in TRY_SSH_KEYS:
            if os.path.exists(path):
                if path not in keyfiles:
                    keyfiles.append(path)
                break
        else:
            log.warning('No default SSH public key found to inject')

    with open(os.path.join(ssh_dir, 'authorized_keys'), 'a') as dest:
        for keyfile in keyfiles:
            with open(keyfile) as src:
                pubkey = src.read().strip()
            log.info('Injecting SSH public key: {}'.format(keyfile))
            dest.write(pubkey + '\n')


def lvcreate(vg_name, name, size=None):
    '''Create logical volume *name* in *vg_name*

    A *size* containing "%" is passed as an extent count (``-l``), any other
    size as an absolute size (``-L``); no size uses all free space.
    '''
    cmd = ['lvcreate', '-n', name]
    if size is not None:
        if '%' in size:
            cmd += ('-l', size)
        else:
            cmd += ('-L', size)
    else:
        cmd += ('-l', '100%FREE')
    cmd.append(vg_name)
    run_cmd_logged(*cmd)


def parse_volumes(volstrings):
    # NOTE(review): this file was recovered from a copy in which text between
    # '<' and '>' appears to have been stripped.  The span from the middle of
    # the regex below to the middle of what looks like an argparse
    # ``metavar`` string is gone, taking with it the rest of this function
    # and the head of parse_args().  Functions this file calls but no longer
    # defines -- run_cmd, run_cmd_logged, pvscan, pvcreate, vgcreate,
    # setup_logging, parse_args -- were presumably in that span.  The
    # statements below are the surviving tail of parse_args(), kept verbatim
    # (``parser`` is undefined here).  Restore this region from the upstream
    # source rather than reconstructing it by guesswork.
    pat = re.compile(r'(?[=[;[;[;' '[;]]]]]', )
    parser.add_argument(
        '--default-fstype',
        metavar='FSTYPE',
        help='Default filesystem type to use for volumes that do not specify '
        'one of their own',
    )
    parser.add_argument(
        '--no-tmpfs-tmp',
        dest='tmpfstmp',
        action='store_false',
        default=True,
        help='Do not write an fstab entry for /tmp on tmpfs',
    )
    parser.add_argument(
        '--overlay', '-O',
        metavar='FILENAME',
        dest='overlays',
        action='append',
        default=[],
        help='Extract the contents of FILENAME onto the image',
    )
    parser.add_argument(
        '--script', '-S',
        metavar='FILENAME',
        dest='scripts',
        action='append',
        default=[],
        help='Run a script inside the image',
    )
    parser.add_argument(
        '--no-chroot',
        dest='chroot',
        action='store_false',
        default=True,
        help='Do not chroot into the image to run scripts - DANGEROUS!',
    )
    parser.add_argument(
        '--inject-ssh-key', '-i',
        metavar='FILENAME',
        nargs='?',
        dest='inject_ssh_keys',
        action='append',
        help='Pre-authorize SSH public keys for root',
    )
    parser.add_argument(
        'stagetbz',
        help='Path to stage tarball',
    )
    parser.add_argument(
        'image',
        nargs='?',
        help='Path to destination image file',
    )
    args = parser.parse_args()
    # Default image name: the tarball name minus its compression and .tar
    # extensions, plus ".img".
    if args.image is None:
        basename = os.path.splitext(args.stagetbz)[0]
        if basename.endswith('.tar'):
            basename = basename[:-4]
        args.image = basename + '.img'
    if args.volumes is None:
        args.volumes = ['/']
    return args


def main():
    '''Build an image from the command-line arguments

    With LVM, a /boot volume (if any) becomes partition 1 and the rest of
    the disk a PV holding every other volume; without /boot the whole disk
    is the PV.  Without LVM, each volume becomes its own partition.
    '''
    args = parse_args()
    setup_logging(args.verbose)

    # None means "use LVM with a generated VG name"; other falsy values
    # disable LVM.
    if args.lvm is None:
        args.lvm = gen_vgname()

    volumes = parse_volumes(args.volumes)
    if args.lvm:
        try:
            partitions = [volumes.pop('/boot'), LVM_PV_PART]
        except KeyError:
            partitions = None
    else:
        partitions = volumes.values()

    try:
        image = Image.create(args.image)
    except OSError as e:
        sys.stderr.write('Failed to create image: {}\n'.format(e))
        raise SystemExit(os.EX_CANTCREAT)
    if args.default_fstype:
        image.default_fstype = args.default_fstype
    with image:
        image.resize(args.size)
        if partitions:
            image.partition(partitions)
        if args.lvm:
            image.setup_lvm(args.lvm, volumes.values())
        image.make_filesystems()
        image.extract(args.stagetbz)
        for overlay in args.overlays:
            image.extract(overlay)
        image.write_fstab(tmpfstmp=args.tmpfstmp)
        for script in args.scripts:
            image.run_script(script, chroot=args.chroot)
        if args.inject_ssh_keys:
            inject_ssh_keys(image.mountpoint, args.inject_ssh_keys)


if __name__ == '__main__':
    main()