diff --git a/stage3img.py b/stage3img.py new file mode 100755 index 0000000..b44c3a8 --- /dev/null +++ b/stage3img.py @@ -0,0 +1,598 @@ +#!/usr/bin/env python3 +import argparse +import collections +import itertools +import logging +import os +import random +import re +import shutil +import string +import subprocess +import sys +import tempfile + + +log = logging.getLogger('stage3img') + + +LVM_PV_PART = { + 'type': 0x8e00, +} + +UNITS = { + 'k': 2 ** 10, + 'kb': 10 ** 3, + 'kib': 2 ** 10, + 'm': 2 ** 20, + 'mb': 10 ** 6, + 'mib': 2 ** 20, + 'g': 2 ** 30, + 'gb': 10 ** 9, + 'gib': 2 ** 30, + 't': 2 ** 40, + 'tb': 10 ** 12, + 'tib': 2 ** 40, + 'p': 2 ** 50, + 'pb': 10 ** 15, + 'pib': 2 ** 50, + 'e': 2 ** 60, + 'eb': 10 ** 18, + 'eib': 2 ** 60, + 'z': 2 ** 70, + 'zb': 10 ** 21, + 'zib': 2 ** 70, + 'y': 2 ** 80, + 'yb': 10 ** 24, + 'yib': 2 ** 80, +} + +SIZE_RE = re.compile( + r'^(?P[0-9]+(?:\.[0-9]+)?)\s*' + r'(?P\w+)?\s*$' +) + +TRY_SSH_KEYS = [ + os.path.expanduser('~/.ssh/id_ed25519.pub'), + os.path.expanduser('~/.ssh/id_rsa.pub'), + os.path.expanduser('~/.ssh/id_ecdsa.pub'), +] + + +class CommandError(Exception): + pass + + +class MaxLevelFilter(logging.Filter): + '''Filter log records above a particular level + + :param max_level: The maximum level of log message to include + ''' + + def __init__(self, max_level=logging.WARNING, *args, **kwargs): + logging.Filter.__init__(self, *args, **kwargs) + self.max_level = max_level + + def filter(self, record): + return record.levelno <= self.max_level + + +class Image(object): + + EXTRA_MOUNTS = [ + ('/tmp', 'tmpfs', '-t', 'tmpfs'), + ('/run', 'tmpfs', '-t', 'tmpfs'), + ('/dev', '/dev', '--rbind'), + ('/proc', 'proc', '-t', 'proc'), + ('/sys', '/sys', '--rbind'), + ] + SCRIPT_ENV = { + 'PATH': os.pathsep.join(( + '/usr/local/sbin', + '/usr/sbin', + '/sbin', + '/usr/local/bin', + '/usr/bin', + '/bin', + )), + } + + default_fstype = 'ext4' + + def __init__(self, filename, fd=None): + self.fd = fd + self.filename = filename + self.tempname = self.filename + '.tmp' + self.loopdev = None + self.partitions = None + self.vgname = None + self.volumes = None + self.mountpoint = None + self._extra_mounted = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + self.unmount() + self.deactivate_vg() + self.detach_loopback() + if self.fd is not None: + os.close(self.fd) + if exc_type: + log.error('Error occurred, removing image') + try: + os.unlink(self.tempname) + except: + log.exception('Failed to remove temporary image file:') + try: + os.unlink(self.filename) + except: + log.exception('Failed to remove image file:') + else: + os.rename(self.tempname, self.filename) + + + @property + def filesystems(self): + if self.partitions: + partnum = 1 + for part in self.partitions: + if 'mountpoint' not in part: + continue + dev = '{}p{}'.format(self.loopdev, partnum) + yield ( + part['mountpoint'], + dev, + part.get('fstype'), + part.get('fsopts', 'defaults'), + ) + partnum += 1 + if self.volumes: + for vol in self.volumes: + if 'mountpoint' not in vol: + continue + dev = '/dev/{}/{}'.format(self.vgname, vol['name']) + yield ( + vol['mountpoint'], + dev, + vol.get('fstype'), + vol.get('fsopts', 'defaults'), + ) + + @classmethod + def create(cls, filename): + fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + return cls(filename, fd) + + def resize(self, size): + with open(self.tempname, 'wb') as f: + f.truncate(size) + + def attach_loopback(self): + self.loopdev = run_cmd('losetup', '-f', '--show', self.tempname) + log.info('Connected {} to {}'.format(self.tempname, self.loopdev)) + + def detach_loopback(self): + if not self.loopdev: + return + log.info('Detaching {}'.format(self.loopdev)) + run_cmd('losetup', '-d', self.loopdev) + + def partition(self, partitions): + if not self.loopdev: + self.attach_loopback() + log.info('Partitioning {} with GPT'.format(self.loopdev)) + cmd = ['sgdisk', '-a', '4096', '-Z', '-g'] + for idx, part in enumerate(partitions): + partnum = idx + 1 + cmd += ('-n', '{}::{}'.format(partnum, part.get('size', ''))) + if 'type' in part: + cmd += ('-t', '{}:{:X}'.format(partnum, part['type'])) + if 'name' in part: + cmd += ('-c', '{}:{}'.format(partnum, part['name'])) + cmd.append(self.loopdev) + run_cmd_logged(*cmd) + self.partitions = partititons + + def setup_lvm(self, vgname, volumes): + if not self.loopdev: + self.attach_loopback() + pvscan(self.loopdev) + if self.partitions: + pv = '{}p{}'.format(self.loopdev, len(self.partitions)) + else: + pv = self.loopdev + pvcreate(pv) + vgcreate(vgname, pv) + self.vgname = vgname + for vol in volumes: + lvcreate(vgname, vol['name'], vol.get('size')) + self.volumes = volumes + + def deactivate_vg(self): + if not self.vgname: + return + run_cmd_logged('vgchange', '-an', self.vgname) + + def make_filesystems(self): + for mountpoint, dev, fstype, fsopts in self.filesystems: + if not fstype: + fstype = self.default_fstype + log.info('Creating {} filesystem on {}'.format(fstype, dev)) + run_cmd('mkfs.{}'.format(fstype), dev) + + def mount(self): + self.mountpoint = tempfile.mkdtemp() + for mountpoint, dev, fstype, fsopts in sorted(self.filesystems): + path = os.path.join(self.mountpoint, mountpoint[1:]) + if not os.path.isdir(path): + os.makedirs(path) + log.info('Mounting {} on {}'.format(dev, path)) + run_cmd_logged('mount', dev, path) + + def mount_extra(self): + if self._extra_mounted: + return + for item in self.EXTRA_MOUNTS: + mountpoint, dev, args = item[0], item[1], item[2:] + path = os.path.join(self.mountpoint, mountpoint[1:]) + if not os.path.isdir(path): + os.makedirs(path) + run_cmd_logged('mount', dev, path, *args) + + def unmount(self): + if not self.mountpoint: + return + log.info('Unmounting {}'.format(self.mountpoint)) + run_cmd_logged('umount', '-R', self.mountpoint) + os.rmdir(self.mountpoint) + self.mountpoint = None + self._mount_extra = False + + def extract(self, filename): + if not self.mountpoint: + self.mount() + log.info('Extracting {} to {}'.format(filename, self.mountpoint)) + run_cmd_logged('tar', '-xha', '--numeric-owner', '-f', filename, + '-C', self.mountpoint) + + def run_script(self, script, chroot=True): + if not self._extra_mounted: + self.mount_extra() + name = os.path.basename(script) + dest = os.path.join(self.mountpoint, 'tmp', name) + log.debug('Copying {} to {}'.format(script, dest)) + shutil.copy(script, dest) + os.chmod(dest, 0o0755) + kwargs = { + 'env': self.SCRIPT_ENV.copy(), + } + if chroot: + cmd = ('chroot', self.mountpoint, os.path.join('/tmp', name)) + else: + cmd = (dest,) + kwargs['env']['IMAGE_ROOT'] = self.mountpoint + kwargs['cwd'] = self.mountpoint + log.info('Running script: {}'.format(name)) + try: + run_cmd_logged(*cmd, **kwargs) + finally: + os.unlink(dest) + + def write_fstab(self, tmpfstmp=True): + tmpl = '{dev}\t{mountpoint}\t{fstype}\t{fsopts}\t0 {passno}\n' + log.info('Writing fstab') + with open(os.path.join(self.mountpoint, 'etc/fstab'), 'w') as fstab: + for mountpoint, dev, fstype, fsopts in sorted(self.filesystems): + line = tmpl.format( + dev=dev, + mountpoint=mountpoint, + fstype=fstype or self.default_fstype, + fsopts=fsopts, + passno=1 if mountpoint == '/' else 2, + ) + log.debug(line.rstrip('\n')) + fstab.write(line) + if tmpfstmp: + line = tmpl.format( + dev='tmpfs', + mountpoint='/tmp', + fstype='tmpfs', + fsopts='defaults', + passno=0, + ) + log.debug(line.rstrip('\n')) + fstab.write(line) + + +def gen_vgname(length=8): + name = random.choice(string.ascii_lowercase) + while len(name) < length - 1: + name += random.choice(string.ascii_letters + string.digits) + name += random.choice(string.ascii_lowercase) + return name + + +def inject_ssh_keys(root, ssh_keys): + ssh_dir = os.path.join(root, 'root', '.ssh') + if not os.path.isdir(ssh_dir): + os.makedirs(ssh_dir) + ssh_keys = set(ssh_keys) + try: + ssh_keys.remove(None) + except KeyError: + pass + else: + for path in TRY_SSH_KEYS: + if os.path.exists(path): + ssh_keys.append(path) + break + with open(os.path.join(ssh_dir, 'authorized_keys'), 'a') as dest: + for keyfile in ssh_keys: + with open(keyfile) as src: + pubkey = src.read().strip() + log.info('Injecting SSH public key: {}'.format(keyfile)) + dest.write(pubkey + '\n') + + +def lvcreate(vg_name, name, size=None): + cmd = ['lvcreate', '-n', name] + if size is not None: + if '%' in size: + cmd += ('-l', size) + else: + cmd += ('-L', size) + else: + cmd += ('-l', '100%FREE') + cmd.append(vg_name) + run_cmd_logged(*cmd) + + +def parse_volumes(volstrings): + pat = re.compile(r'(?[=[;[;[;' + '[;]]]]]', + ) + parser.add_argument( + '--default-fstype', + metavar='FSTYPE', + help='Default filesystem type to use for volumes that do not specify ' + 'one of their own', + ) + parser.add_argument( + '--no-tmpfs-tmp', + dest='tmpfstmp', + action='store_false', + default=True, + help='Do not write an fstab entry for /tmp on tmpfs', + ) + parser.add_argument( + '--overlay', '-O', + metavar='FILENAME', + dest='overlays', + action='append', + default=[], + help='Extract the contents of FILENAME onto the image', + ) + parser.add_argument( + '--script', '-S', + metavar='FILENAME', + dest='scripts', + action='append', + default=[], + help='Run a script inside the image', + ) + parser.add_argument( + '--no-chroot', + dest='chroot', + action='store_false', + default=True, + help='Do not chroot into the image to run scripts - DANGEROUS!', + ) + parser.add_argument( + '--inject-ssh-key', '-i', + metavar='FILENAME', + nargs='?', + dest='inject_ssh_keys', + action='append', + help='Pre-authorize SSH public keys for root', + ) + parser.add_argument( + 'stagetbz', + help='Path to stage tarball', + ) + parser.add_argument( + 'image', + nargs='?', + help='Path to destination image file', + ) + args = parser.parse_args() + if args.image is None: + basename = os.path.splitext(args.stagetbz)[0] + if basename.endswith('.tar'): + basename = basename[:-4] + args.image = basename + '.img' + if args.volumes is None: + args.volumes = ['/'] + return args + + +def main(): + args = parse_args() + + setup_logging(args.verbose) + + if args.lvm is None: + args.lvm = gen_vgname() + + volumes = parse_volumes(args.volumes) + if args.lvm: + try: + partitions = [volumes.pop('/boot'), LVM_PV_PART] + except KeyError: + partitions = None + else: + partitions = volumes.values() + + try: + image = Image.create(args.image) + except OSError as e: + sys.stderr.write('Failed to create image: {}\n'.format(e)) + raise SystemExit(os.EX_CANTCREAT) + if args.default_fstype: + image.default_fstype = args.default_fstype + with image: + image.resize(args.size) + if partitions: + image.partition(partitions) + if args.lvm: + image.setup_lvm(args.lvm, volumes.values()) + image.make_filesystems() + image.extract(args.stagetbz) + for overlay in args.overlays: + image.extract(overlay) + image.write_fstab(tmpfstmp=args.tmpfstmp) + for script in args.scripts: + image.run_script(script, chroot=args.chroot) + if args.inject_ssh_keys: + inject_ssh_keys(image.mountpoint, args.inject_ssh_keys) + + +if __name__ == '__main__': + main()