scripts/stage3img.py

599 lines
17 KiB
Python
Executable File
Raw Blame History

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

#!/usr/bin/env python3
import argparse
import collections
import itertools
import logging
import os
import random
import re
import shutil
import string
import subprocess
import sys
import tempfile
log = logging.getLogger(name='stage3img')

# Partition spec appended after /boot when the image uses LVM; the PV lives
# in this partition. 0x8E00 is the GPT type code sgdisk uses for Linux LVM.
LVM_PV_PART = dict(type=0x8E00)
# Size-suffix multipliers used by parse_size(). For every prefix k, m, g, t,
# p, e, z, y: the bare letter and the "<x>ib" form are binary (powers of
# 1024), while "<x>b" is decimal (powers of 1000).
UNITS = {
    prefix + suffix: multiplier
    for power, prefix in enumerate('kmgtpezy', 1)
    for suffix, multiplier in (
        ('', 2 ** (10 * power)),
        ('b', 10 ** (3 * power)),
        ('ib', 2 ** (10 * power)),
    )
}
# "<number>[.<fraction>] [unit]" with optional surrounding whitespace; the
# unit is looked up (lower-cased) by parse_size().
SIZE_RE = re.compile(r'^(?P<value>[0-9]+(?:\.[0-9]+)?)\s*(?P<unit>\w+)?\s*$')

# Host public keys tried, in order, when a bare --inject-ssh-key is given.
TRY_SSH_KEYS = [
    os.path.expanduser('~/.ssh/' + keyname)
    for keyname in ('id_ed25519.pub', 'id_rsa.pub', 'id_ecdsa.pub')
]
class CommandError(Exception):
    """Raised by run_cmd() when an external command exits non-zero.

    The message is the command's stderr, or its exit status if stderr was
    empty.
    """
class MaxLevelFilter(logging.Filter):
    """Logging filter that passes only records at or below a given level.

    :param max_level: Most severe level allowed through (default WARNING).
        Any further arguments are forwarded to :class:`logging.Filter`.
    """

    def __init__(self, max_level=logging.WARNING, *args, **kwargs):
        super(MaxLevelFilter, self).__init__(*args, **kwargs)
        self.max_level = max_level

    def filter(self, record):
        too_severe = record.levelno > self.max_level
        return not too_severe
class Image(object):
    """Disk image assembled from a stage tarball.

    The image is built in ``<filename>.tmp``. Used as a context manager it
    tears everything down on exit (unmount, deactivate the volume group,
    detach the loop device, close *fd*). It then renames the temporary file
    to *filename* on success, or deletes both files on error.

    :param filename: Final path of the image file.
    :param fd: Optional open descriptor for *filename* (see :meth:`create`);
        closed on exit.
    """

    # (mountpoint inside the image, source, *extra mount arguments) for the
    # pseudo-filesystems that scripts run by run_script() get. The rbind
    # mounts are made rslave so the recursive unmount in unmount() cannot
    # propagate back to the host's /dev and /sys submounts when the host
    # uses shared mount propagation (systemd makes mounts shared by default).
    EXTRA_MOUNTS = [
        ('/tmp', 'tmpfs', '-t', 'tmpfs'),
        ('/run', 'tmpfs', '-t', 'tmpfs'),
        ('/dev', '/dev', '--rbind', '--make-rslave'),
        ('/proc', 'proc', '-t', 'proc'),
        ('/sys', '/sys', '--rbind', '--make-rslave'),
    ]

    # The complete environment given to scripts (run_script() adds
    # IMAGE_ROOT when not chrooting); nothing is inherited from the host.
    SCRIPT_ENV = {
        'PATH': os.pathsep.join((
            '/usr/local/sbin',
            '/usr/sbin',
            '/sbin',
            '/usr/local/bin',
            '/usr/bin',
            '/bin',
        )),
    }

    # Filesystem type for mkfs and fstab when a volume does not name one.
    default_fstype = 'ext4'

    def __init__(self, filename, fd=None):
        self.fd = fd
        self.filename = filename
        self.tempname = self.filename + '.tmp'
        self.loopdev = None          # loop device path once attached
        self.partitions = None       # specs passed to partition()
        self.vgname = None           # volume group created by setup_lvm()
        self.volumes = None          # LV specs passed to setup_lvm()
        self.mountpoint = None       # temp dir the image root is mounted on
        self._extra_mounted = False  # True while EXTRA_MOUNTS are mounted

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Release all resources, then keep or discard the image.

        Every cleanup step is attempted even if an earlier one fails. A
        failing unmount therefore neither skips detaching the loop device
        nor replaces the exception that aborted the build. If cleanup fails
        after an otherwise successful build, the image is removed and the
        first cleanup error is re-raised.
        """
        cleanup_error = None
        for step in (self.unmount, self.deactivate_vg, self.detach_loopback):
            try:
                step()
            except Exception as e:  # cleanup boundary: log, keep going
                log.exception('Cleanup step {} failed:'.format(step.__name__))
                if cleanup_error is None:
                    cleanup_error = e
        if self.fd is not None:
            os.close(self.fd)
            self.fd = None
        if exc_type is not None or cleanup_error is not None:
            log.error('Error occurred, removing image')
            for path, label in ((self.tempname, 'temporary image file'),
                                (self.filename, 'image file')):
                try:
                    os.unlink(path)
                except OSError:
                    log.exception('Failed to remove {}:'.format(label))
        else:
            os.rename(self.tempname, self.filename)
        if exc_type is None and cleanup_error is not None:
            raise cleanup_error

    @property
    def filesystems(self):
        """Yield ``(mountpoint, device, fstype, fsopts)`` for each mountable
        partition and logical volume.

        ``fstype`` is None when unspecified; ``fsopts`` defaults to
        ``'defaults'``. Entries without a mountpoint are skipped.
        """
        if self.partitions:
            # Number by list position, as partition() does, so entries
            # without a mountpoint (e.g. the LVM PV) still take a number.
            for partnum, part in enumerate(self.partitions, 1):
                if 'mountpoint' not in part:
                    continue
                dev = '{}p{}'.format(self.loopdev, partnum)
                yield (
                    part['mountpoint'],
                    dev,
                    part.get('fstype'),
                    part.get('fsopts', 'defaults'),
                )
        if self.volumes:
            for vol in self.volumes:
                if 'mountpoint' not in vol:
                    continue
                dev = '/dev/{}/{}'.format(self.vgname, vol['name'])
                yield (
                    vol['mountpoint'],
                    dev,
                    vol.get('fstype'),
                    vol.get('fsopts', 'defaults'),
                )

    @classmethod
    def create(cls, filename):
        """Exclusively create *filename* and return an Image owning its fd.

        :raises OSError: if the file already exists or cannot be created.
        """
        fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        return cls(filename, fd)

    def resize(self, size):
        """(Re)create the temporary image file with a length of *size* bytes."""
        with open(self.tempname, 'wb') as f:
            f.truncate(size)

    def attach_loopback(self):
        """Attach the temporary image file to the first free loop device."""
        # -P (--partscan) lets the kernel create loopNpM nodes once a
        # partition table is written; filesystems and setup_lvm use them.
        self.loopdev = run_cmd('losetup', '-f', '-P', '--show', self.tempname)
        log.info('Connected {} to {}'.format(self.tempname, self.loopdev))

    def detach_loopback(self):
        """Detach the loop device, if attached."""
        if not self.loopdev:
            return
        log.info('Detaching {}'.format(self.loopdev))
        run_cmd('losetup', '-d', self.loopdev)
        self.loopdev = None

    def partition(self, partitions):
        """Write a new GPT with one partition per entry of *partitions*.

        Partitions are numbered from 1 in list order. Each entry may carry
        ``size`` (sgdisk end spec; absent means sgdisk's default end),
        ``type`` (int or hex-string GPT type code) and ``name``.
        """
        if not self.loopdev:
            self.attach_loopback()
        log.info('Partitioning {} with GPT'.format(self.loopdev))
        # -a: sector alignment; -Z: wipe existing tables; -g: convert to GPT.
        cmd = ['sgdisk', '-a', '4096', '-Z', '-g']
        for partnum, part in enumerate(partitions, 1):
            cmd += ('-n', '{}::{}'.format(partnum, part.get('size', '')))
            if 'type' in part:
                ptype = part['type']
                # Built-in specs (LVM_PV_PART) use ints, but type codes parsed
                # from --volume arrive as strings such as '8300'.
                if isinstance(ptype, int):
                    ptype = '{:X}'.format(ptype)
                cmd += ('-t', '{}:{}'.format(partnum, ptype))
            if 'name' in part:
                cmd += ('-c', '{}:{}'.format(partnum, part['name']))
        cmd.append(self.loopdev)
        run_cmd_logged(*cmd)
        self.partitions = partitions

    def setup_lvm(self, vgname, volumes):
        """Create volume group *vgname* and one logical volume per entry.

        The PV is the last partition if the image is partitioned, otherwise
        the whole loop device. Each volume needs ``name`` and may have
        ``size`` (see lvcreate()).
        """
        if not self.loopdev:
            self.attach_loopback()
        pvscan(self.loopdev)
        if self.partitions:
            pv = '{}p{}'.format(self.loopdev, len(self.partitions))
        else:
            pv = self.loopdev
        pvcreate(pv)
        vgcreate(vgname, pv)
        self.vgname = vgname
        # A volume without a size takes 100%FREE, so create it after the
        # sized ones (stable sort keeps the rest in the given order).
        for vol in sorted(volumes, key=lambda v: v.get('size') is None):
            lvcreate(vgname, vol['name'], vol.get('size'))
        self.volumes = volumes

    def deactivate_vg(self):
        """Deactivate the volume group created by setup_lvm(), if any."""
        if not self.vgname:
            return
        run_cmd_logged('vgchange', '-an', self.vgname)

    def make_filesystems(self):
        """Run ``mkfs.<fstype>`` on every device listed by filesystems."""
        for mountpoint, dev, fstype, fsopts in self.filesystems:
            if not fstype:
                fstype = self.default_fstype
            log.info('Creating {} filesystem on {}'.format(fstype, dev))
            run_cmd('mkfs.{}'.format(fstype), dev)

    def mount(self):
        """Mount every filesystem under a new temporary directory.

        Sorting by mountpoint mounts a parent (e.g. ``/``) before anything
        beneath it (e.g. ``/boot``).
        """
        self.mountpoint = tempfile.mkdtemp()
        for mountpoint, dev, fstype, fsopts in sorted(self.filesystems):
            path = os.path.join(self.mountpoint, mountpoint[1:])
            if not os.path.isdir(path):
                os.makedirs(path)
            log.info('Mounting {} on {}'.format(dev, path))
            run_cmd_logged('mount', dev, path)

    def mount_extra(self):
        """Mount EXTRA_MOUNTS inside the image (once until unmount())."""
        if self._extra_mounted:
            return
        for item in self.EXTRA_MOUNTS:
            mountpoint, dev, args = item[0], item[1], item[2:]
            path = os.path.join(self.mountpoint, mountpoint[1:])
            if not os.path.isdir(path):
                os.makedirs(path)
            run_cmd_logged('mount', dev, path, *args)
        self._extra_mounted = True

    def unmount(self):
        """Recursively unmount the image root and remove its temp dir."""
        if not self.mountpoint:
            return
        log.info('Unmounting {}'.format(self.mountpoint))
        run_cmd_logged('umount', '-R', self.mountpoint)
        os.rmdir(self.mountpoint)
        self.mountpoint = None
        self._extra_mounted = False

    def extract(self, filename):
        """Extract tarball *filename* into the image root (mounting first)."""
        if not self.mountpoint:
            self.mount()
        log.info('Extracting {} to {}'.format(filename, self.mountpoint))
        run_cmd_logged('tar', '-xha', '--numeric-owner', '-f', filename,
                       '-C', self.mountpoint)

    def run_script(self, script, chroot=True):
        """Copy *script* into the image's /tmp and execute it.

        With *chroot* the script runs via ``chroot`` inside the image.
        Otherwise it runs on the host with ``cwd`` and ``IMAGE_ROOT`` set to
        the image root. The copy is removed afterwards.
        """
        if not self.mountpoint:
            self.mount()
        if not self._extra_mounted:
            self.mount_extra()
        name = os.path.basename(script)
        dest = os.path.join(self.mountpoint, 'tmp', name)
        log.debug('Copying {} to {}'.format(script, dest))
        shutil.copy(script, dest)
        os.chmod(dest, 0o0755)
        kwargs = {
            'env': self.SCRIPT_ENV.copy(),
        }
        if chroot:
            cmd = ('chroot', self.mountpoint, os.path.join('/tmp', name))
        else:
            cmd = (dest,)
            kwargs['env']['IMAGE_ROOT'] = self.mountpoint
            kwargs['cwd'] = self.mountpoint
        log.info('Running script: {}'.format(name))
        try:
            run_cmd_logged(*cmd, **kwargs)
        finally:
            os.unlink(dest)

    def write_fstab(self, tmpfstmp=True):
        """Overwrite the image's /etc/fstab from filesystems.

        ``/`` gets fsck pass 1, other filesystems pass 2. If *tmpfstmp*,
        a tmpfs entry for /tmp is appended.
        """
        tmpl = '{dev}\t{mountpoint}\t{fstype}\t{fsopts}\t0 {passno}\n'
        log.info('Writing fstab')
        with open(os.path.join(self.mountpoint, 'etc/fstab'), 'w') as fstab:
            for mountpoint, dev, fstype, fsopts in sorted(self.filesystems):
                line = tmpl.format(
                    dev=dev,
                    mountpoint=mountpoint,
                    fstype=fstype or self.default_fstype,
                    fsopts=fsopts,
                    passno=1 if mountpoint == '/' else 2,
                )
                log.debug(line.rstrip('\n'))
                fstab.write(line)
            if tmpfstmp:
                line = tmpl.format(
                    dev='tmpfs',
                    mountpoint='/tmp',
                    fstype='tmpfs',
                    fsopts='defaults',
                    passno=0,
                )
                log.debug(line.rstrip('\n'))
                fstab.write(line)
def gen_vgname(length=8):
    """Return a random volume-group name of *length* characters (at least 2).

    The first and last characters are lowercase ASCII letters. Those in
    between are ASCII letters or digits.
    """
    inner_chars = string.ascii_letters + string.digits
    head = random.choice(string.ascii_lowercase)
    body = ''.join(random.choice(inner_chars) for _ in range(length - 2))
    tail = random.choice(string.ascii_lowercase)
    return head + body + tail
def inject_ssh_keys(root, ssh_keys):
    """Append SSH public keys to ``root/.ssh/authorized_keys`` in the image.

    :param root: Path of the mounted image root.
    :param ssh_keys: Public-key file paths, written in the given order with
        duplicates dropped. A ``None`` entry (a bare ``-i`` on the command
        line) stands for the first file in TRY_SSH_KEYS that exists on the
        host.
    """
    ssh_dir = os.path.join(root, 'root', '.ssh')
    if not os.path.isdir(ssh_dir):
        # ~/.ssh should be accessible only by its owner.
        os.makedirs(ssh_dir, mode=0o700)
    # De-duplicate while keeping command-line order; the previous set-based
    # version made the authorized_keys order arbitrary, and called
    # set.append(), which does not exist.
    keyfiles = list(collections.OrderedDict.fromkeys(ssh_keys))
    if None in keyfiles:
        keyfiles.remove(None)
        default_key = next(
            (path for path in TRY_SSH_KEYS if os.path.exists(path)), None)
        if default_key is None:
            log.warning('No default SSH public key found; tried: {}'.format(
                ', '.join(TRY_SSH_KEYS)))
        elif default_key not in keyfiles:
            keyfiles.append(default_key)
    with open(os.path.join(ssh_dir, 'authorized_keys'), 'a') as dest:
        for keyfile in keyfiles:
            with open(keyfile) as src:
                pubkey = src.read().strip()
            log.info('Injecting SSH public key: {}'.format(keyfile))
            dest.write(pubkey + '\n')
def lvcreate(vg_name, name, size=None):
    """Run ``lvcreate`` to make logical volume *name* in *vg_name*.

    A *size* containing ``%`` is passed as an extent spec (``-l``). Any
    other size is passed as ``-L``. With no size, the volume takes
    ``100%FREE``.
    """
    if size is None:
        size_args = ('-l', '100%FREE')
    elif '%' in size:
        size_args = ('-l', size)
    else:
        size_args = ('-L', size)
    run_cmd_logged(*(('lvcreate', '-n', name) + size_args + (vg_name,)))
def parse_volumes(volstrings):
    """Parse ``--volume`` specs into an OrderedDict keyed by mountpoint.

    Each spec is
    ``<mountpoint>[=<size>[;<fstype>[;<name>[;<fsopts>[;<type>]]]]]``.
    The mountpoint ends at the first ``=`` not preceded by a backslash (the
    backslash is kept). Empty fields are omitted. A missing name is derived
    from the mountpoint (``/var/log`` -> ``var-log``, ``/`` -> ``root``).
    NOTE(review): fields past the fifth end up under the key ``None``.

    ``/boot`` is inserted at index 0 and ``/`` at index 1 of the list built
    so far; every other volume is appended in order.

    :returns: OrderedDict of mountpoint -> parameter dict; each dict always
        has ``mountpoint`` and ``name``.
    """
    unescaped_eq = re.compile(r'(?<!\\)=')
    fields = ('size', 'fstype', 'name', 'fsopts', 'type')
    priority = {'/boot': 0, '/': 1}
    ordered = []
    for spec in volstrings:
        pieces = unescaped_eq.split(spec, 1)
        mountpoint = pieces[0]
        params = {}
        if len(pieces) > 1:
            values = pieces[1].split(';')
            for key, value in itertools.zip_longest(fields, values):
                if value not in ('', None):
                    params[key] = value
        params['mountpoint'] = mountpoint
        if 'name' not in params:
            if mountpoint.startswith('/'):
                derived = mountpoint[1:].replace('/', '-')
            else:
                derived = mountpoint
            params['name'] = derived or 'root'
        entry = (mountpoint, params)
        position = priority.get(mountpoint)
        if position is None:
            ordered.append(entry)
        else:
            ordered.insert(position, entry)
    return collections.OrderedDict(ordered)
def parse_size(size):
    """Convert a size such as ``'3G'`` or ``'1.5 GiB'`` to a byte count.

    Matching is case-insensitive. A bare number, or the unit ``b``/``byte``/
    ``bytes``, means bytes. Other units are looked up in UNITS. Integers are
    returned unchanged, and fractional byte counts are truncated.

    :raises ValueError: if *size* is not a number with an optional known unit.
    """
    from fractions import Fraction

    if isinstance(size, int):
        return size
    m = SIZE_RE.match(size.lower())
    if not m:
        raise ValueError('Invalid size: {}'.format(size))
    value, unit = m.group('value', 'unit')
    if unit in (None, 'b', 'byte', 'bytes'):
        factor = 1
    else:
        try:
            factor = UNITS[unit]
        except KeyError:
            raise ValueError('Invalid size: {}'.format(size)) from None
    # Exact rational arithmetic: going through float rounds large products,
    # e.g. int(1.0 * 10 ** 24) == 999999999999999983222784.
    return int(Fraction(value) * factor)
def pvcreate(*pvs):
    """Run ``pvcreate`` on the given devices, logging its output."""
    command = ('pvcreate',) + pvs
    run_cmd_logged(*command)
def pvscan(dev=None):
    """Run ``pvscan``, or ``pvscan --cache <dev>`` when *dev* is given."""
    extra = ('--cache', dev) if dev else ()
    run_cmd_logged('pvscan', *extra)
def run_cmd(*args, **kwargs):
    """Run a command to completion and return its stripped stdout as text.

    stdin is connected to /dev/null, so the command cannot read from the
    terminal. Extra keyword arguments (e.g. ``env``, ``cwd``) are passed to
    :class:`subprocess.Popen`. Arguments are converted with ``str()``, so
    path-like objects and numbers are accepted.

    :raises CommandError: if the command exits non-zero; the message is its
        stderr, or the exit status when stderr is empty.
    """
    cmd = [str(arg) for arg in args]
    log.debug('EXEC {}'.format(' '.join(cmd)))
    p = subprocess.Popen(cmd, stdin=subprocess.DEVNULL,
                         stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                         **kwargs)
    out, err = p.communicate()
    if p.returncode != 0:
        # Tool output is not guaranteed to be valid UTF-8 (e.g. file names
        # printed by tar); a decode error must not hide the real failure.
        msg = err.strip().decode(errors='replace')
        if not msg:
            msg = 'Command exited with status {}'.format(p.returncode)
        log.error(msg)
        raise CommandError(msg)
    return out.strip().decode(errors='replace')
def run_cmd_logged(*args, **kwargs):
    """Run a command via run_cmd() and log each stdout line at INFO level."""
    output = run_cmd(*args, **kwargs)
    for output_line in output.splitlines():
        log.info(output_line)
def vgcreate(name, *pvs):
    """Run ``vgcreate`` to build volume group *name* from *pvs*, logging output."""
    command = ('vgcreate', name) + pvs
    run_cmd_logged(*command)
def setup_logging(verbose=0):
    """Attach console handlers to the root logger.

    ERROR and above go to stderr. Records up to WARNING go to stdout, with
    the stdout threshold chosen by *verbose*: 0 -> WARNING, 1 -> INFO,
    2 or more -> DEBUG.
    """
    if verbose >= 2:
        stdout_level = logging.DEBUG
    elif verbose >= 1:
        stdout_level = logging.INFO
    else:
        stdout_level = logging.WARNING

    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    errors = logging.StreamHandler()
    errors.setLevel(logging.ERROR)
    root.addHandler(errors)

    console = logging.StreamHandler(sys.stdout)
    console.addFilter(MaxLevelFilter())
    console.setLevel(stdout_level)
    root.addHandler(console)
def parse_args():
    """Parse the command line and fill in derived defaults.

    With no destination image, the name is derived from the stage tarball by
    dropping its extension (and a following ``.tar``) and adding ``.img``.
    With no ``--volume``, a single ``/`` volume is used.
    """
    option_specs = (
        (('--verbose', '-v'), dict(
            action='count',
            default=0,
            help='Print additional status information',
        )),
        (('--size', '-s'), dict(
            type=parse_size,
            default='3G',
            help='Disk image size',
        )),
        (('--lvm', '-L'), dict(
            metavar='VGNAME',
            nargs='?',
            help='Create volumes using LVM, in the volume group VGNAME. If a '
                 'volume group name is not specified, a unique name will be '
                 'generated.',
        )),
        (('--no-lvm',), dict(
            dest='lvm',
            action='store_false',
            help='Do not use LVM',
        )),
        (('--volume', '-V'), dict(
            metavar='VOL',
            dest='volumes',
            action='append',
            help='Define a new volume, using the format '
                 '<mountpoint>[=<size>[;<fstype>[;<name>[;<fsopts>'
                 '[;<parttype>]]]]]',
        )),
        (('--default-fstype',), dict(
            metavar='FSTYPE',
            help='Default filesystem type to use for volumes that do not specify '
                 'one of their own',
        )),
        (('--no-tmpfs-tmp',), dict(
            dest='tmpfstmp',
            action='store_false',
            default=True,
            help='Do not write an fstab entry for /tmp on tmpfs',
        )),
        (('--overlay', '-O'), dict(
            metavar='FILENAME',
            dest='overlays',
            action='append',
            default=[],
            help='Extract the contents of FILENAME onto the image',
        )),
        (('--script', '-S'), dict(
            metavar='FILENAME',
            dest='scripts',
            action='append',
            default=[],
            help='Run a script inside the image',
        )),
        (('--no-chroot',), dict(
            dest='chroot',
            action='store_false',
            default=True,
            help='Do not chroot into the image to run scripts - DANGEROUS!',
        )),
        (('--inject-ssh-key', '-i'), dict(
            metavar='FILENAME',
            nargs='?',
            dest='inject_ssh_keys',
            action='append',
            help='Pre-authorize SSH public keys for root',
        )),
        (('stagetbz',), dict(
            help='Path to stage tarball',
        )),
        (('image',), dict(
            nargs='?',
            help='Path to destination image file',
        )),
    )
    parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    if args.image is None:
        stem = os.path.splitext(args.stagetbz)[0]
        if stem.endswith('.tar'):
            stem = stem[:-len('.tar')]
        args.image = '{}.img'.format(stem)
    args.volumes = args.volumes if args.volumes is not None else ['/']
    return args
def main():
    """Build a disk image from a stage tarball as described by the CLI."""
    args = parse_args()
    setup_logging(args.verbose)

    # No LVM option, or --lvm without a name, leaves None: use LVM with a
    # generated volume-group name. --no-lvm leaves False.
    if args.lvm is None:
        args.lvm = gen_vgname()

    volumes = parse_volumes(args.volumes)
    if not args.lvm:
        partitions = volumes.values()
    elif '/boot' in volumes:
        # /boot stays a plain partition; the LVM PV is the last partition.
        partitions = [volumes.pop('/boot'), LVM_PV_PART]
    else:
        partitions = None

    try:
        image = Image.create(args.image)
    except OSError as e:
        sys.stderr.write('Failed to create image: {}\n'.format(e))
        raise SystemExit(os.EX_CANTCREAT)
    if args.default_fstype:
        image.default_fstype = args.default_fstype

    with image:
        image.resize(args.size)
        if partitions:
            image.partition(partitions)
        if args.lvm:
            image.setup_lvm(args.lvm, volumes.values())
        image.make_filesystems()
        for tarball in [args.stagetbz] + args.overlays:
            image.extract(tarball)
        image.write_fstab(tmpfstmp=args.tmpfstmp)
        for script in args.scripts:
            image.run_script(script, chroot=args.chroot)
        if args.inject_ssh_keys:
            inject_ssh_keys(image.mountpoint, args.inject_ssh_keys)


if __name__ == '__main__':
    main()