1
0
Fork 0
xactfetch/secretsocket.py

205 lines
5.9 KiB
Python
Executable File

#!/usr/bin/env python3
import asyncio
import logging
import os
import shlex
import signal
import socket
import struct
import tomllib
from pathlib import Path
from typing import Optional
# Logger shared by every component of the server.
log = logging.getLogger('secretsocket')

# Connections whose peer credentials cannot be read are refused unless this
# variable is set to exactly "1".
ALLOW_UNKNOWN_PEER = os.getenv('ALLOW_UNKNOWN_PEER') == '1'
# Explicit socket location; when unset, XDG_RUNTIME_DIR is consulted next.
SECRET_SOCKET_PATH = os.getenv('SECRET_SOCKET_PATH')
XDG_RUNTIME_DIR = os.getenv('XDG_RUNTIME_DIR')
class Secret:
    """Base class for a secret whose value is produced on demand.

    Concrete sources override :meth:`lookup`.
    """

    async def lookup(self) -> Optional[bytes]:
        """Return the secret's raw value, or ``None`` if it is unavailable."""
        raise NotImplementedError()
class EnvSecret(Secret):
    """Secret taken from an environment variable of this process."""

    def __init__(self, env_var: str) -> None:
        # Name of the variable, as text; encoded to bytes at lookup time.
        self.env_var = env_var

    async def lookup(self) -> Optional[bytes]:
        """Return the variable's raw bytes, or ``None`` if it is not set."""
        name = self.env_var.encode('utf-8')
        value = os.environb.get(name)
        return value
class ExecSecret(Secret):
    """Secret produced by running a command and capturing its stdout.

    The command string is split with shell-like quoting rules (``shlex``)
    but executed directly, not through a shell. The child's stdin is
    ``/dev/null``; its stderr is inherited from this process.
    """

    def __init__(self, cmd: str) -> None:
        self.cmd = cmd

    async def lookup(self) -> Optional[bytes]:
        """Run the command and return its standard output verbatim.

        Returns ``None`` (after logging) when the command line cannot be
        parsed or is empty, the program cannot be started, or it exits with
        a non-zero status. A failed command's output is therefore never
        served as the secret. This matches how ``PathSecret`` reports
        failures.
        """
        try:
            args = shlex.split(self.cmd)
        except ValueError as e:
            # e.g. unbalanced quotes in the configured command.
            log.error('Failed to parse secret command %r: %s', self.cmd, e)
            return None
        if not args:
            log.error('Empty secret command')
            return None
        try:
            proc = await asyncio.create_subprocess_exec(
                *args,
                stdin=asyncio.subprocess.DEVNULL,
                stdout=asyncio.subprocess.PIPE,
            )
        except OSError as e:
            # Executable missing, not executable, etc.
            log.error('Failed to run secret command %r: %s', self.cmd, e)
            return None
        stdout, _ = await proc.communicate()
        if proc.returncode != 0:
            log.error(
                'Secret command %r exited with status %d',
                self.cmd,
                proc.returncode,
            )
            return None
        return stdout
class PathSecret(Secret):
    """Secret read from the contents of a file (``~`` is expanded)."""

    def __init__(self, path: Path) -> None:
        self.path = path

    async def lookup(self) -> Optional[bytes]:
        """Return the file's raw bytes, or ``None`` if it cannot be read.

        Opening and reading both run in a worker thread, so a slow
        filesystem does not stall the event loop. An OSError at either
        step is logged and reported as ``None`` rather than propagated.
        """
        try:
            return await asyncio.to_thread(self.path.expanduser().read_bytes)
        except OSError as e:
            log.error('Failed to read secret from %s: %s', self.path, e)
            return None
class StringSecret(Secret):
    """Secret whose value is a literal string given in the configuration."""

    def __init__(self, value: str) -> None:
        self.value = value

    async def lookup(self) -> Optional[bytes]:
        """Return the literal value encoded as UTF-8."""
        encoded = self.value.encode('utf-8')
        return encoded
class SecretServer:
    """Serve secrets over a stream connection, one secret per request line.

    Protocol: the client sends a secret name terminated by ``\\n``. The
    server replies with the secret's value followed by ``\\n``. The reply is
    an empty line if the name is unknown or its lookup produced nothing.
    Only peers running as this process's effective UID are served.
    """

    def __init__(self, path: Optional[Path] = None) -> None:
        # Secrets config file; None lets load_secrets pick its default.
        self.path = path

    async def handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ):
        """Authenticate the connecting peer, then answer its requests.

        Used as the ``client_connected_cb`` of ``asyncio.start_unix_server``.
        The connection is closed on every exit path. This covers refusal,
        malformed requests, lookup errors and the client vanishing
        mid-exchange.
        """
        try:
            allowed, pid = self._check_peer(writer)
            if allowed:
                await self._serve_requests(reader, writer, pid)
        except ConnectionError as e:
            # Peer reset the connection / closed it while we were writing.
            log.debug('Client connection lost: %s', e)
        finally:
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                # The peer may already have dropped the connection; there is
                # nothing further to clean up.
                pass

    def _check_peer(
        self, writer: asyncio.StreamWriter
    ) -> tuple[bool, Optional[int]]:
        """Return ``(allowed, pid)`` for the peer connected to *writer*.

        ``pid`` is ``None`` when the peer credentials could not be read.
        """
        try:
            sock = writer.get_extra_info('socket')
            pid, uid, gid = get_socket_peercred(sock)
        except Exception as e:
            log.error('Failed to get peer credentials: %s', e)
            pid, uid, gid = None, None, None
        # %s rather than %d: all three are None when credentials are unknown.
        log.debug('Client connected (pid %s, uid %s, gid %s)', pid, uid, gid)
        if uid is None:
            if ALLOW_UNKNOWN_PEER:
                log.warning('Handling connection from unknown peer')
                return True, None
            log.error('Refusing to handle connection from unknown peer')
            return False, None
        my_uid = os.getresuid()[1]  # effective UID
        if uid != my_uid:
            log.error(
                'Refusing to handle connection from PID %d: '
                'peer UID %d does not match %d',
                pid,
                uid,
                my_uid,
            )
            return False, pid
        return True, pid

    async def _serve_requests(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        pid: Optional[int],
    ) -> None:
        """Answer newline-terminated requests until EOF or an error."""
        while True:
            try:
                line = await reader.readuntil(b'\n')
            except asyncio.IncompleteReadError:
                # EOF (possibly mid-line): the client is finished.
                break
            except asyncio.LimitOverrunError:
                log.error('Client %s sent an over-long request', pid)
                return
            try:
                key = line.rstrip(b'\n').decode()
            except UnicodeDecodeError:
                log.error('Client %s sent a request that is not UTF-8', pid)
                return
            # %s: pid is None for unknown peers allowed via ALLOW_UNKNOWN_PEER.
            log.info('Client %s requested secret %s', pid, key)
            try:
                secret = await self.get_secret(key)
            except Exception:
                log.exception('Failed to get secret:')
                return
            writer.write(secret + b'\n')
            await writer.drain()
        log.debug('Client disconnected')

    async def get_secret(self, key: str) -> bytes:
        """Return the value of secret *key*, or ``b''`` if unavailable.

        The configuration is re-read on every call, so edits take effect
        without restarting the server. Unknown names and lookups that yield
        nothing are logged as warnings. Exceptions raised while loading the
        configuration or performing the lookup propagate to the caller.
        """
        secrets = await load_secrets(self.path)
        if secret := secrets.get(key):
            if value := await secret.lookup():
                return value
            else:
                log.warning('Lookup of secret %s failed', key)
        else:
            log.warning('Unknown secret: %s', key)
        return b''
def get_socket_peercred(sock: socket.socket) -> tuple[int, int, int]:
    """Return ``(pid, uid, gid)`` of the process on the other end of *sock*.

    Uses ``SO_PEERCRED`` and assumes the Linux ``struct ucred`` layout
    (pid, uid, gid). Raises whatever ``getsockopt`` raises (e.g. OSError),
    or AttributeError where ``socket.SO_PEERCRED`` does not exist.
    """
    # pid_t is signed, but uid_t/gid_t are unsigned 32-bit: unpacking them
    # as signed 'i' would turn UIDs >= 2**31 negative and make them compare
    # unequal to os.geteuid()/os.getresuid().
    ucred = struct.Struct('=iII')
    cred = sock.getsockopt(socket.SOL_SOCKET, socket.SO_PEERCRED, ucred.size)
    return ucred.unpack(cred)
async def load_secrets(path: Optional[Path] = None) -> dict[str, Secret]:
    """Parse the TOML secrets configuration into ``Secret`` objects.

    Each top-level key names a secret and must map to a table. The first
    of these keys present in the table selects the source: ``env``
    (environment variable), ``exec`` (command whose stdout is the secret),
    ``path`` (file contents), ``string`` (literal value). Entries that are
    not tables, or have none of those keys, are skipped with a warning.

    *path* defaults to ``secrets.toml`` in the current directory. If the
    file cannot be opened, the error is logged and an empty dict is
    returned. A TOML syntax error (``tomllib.TOMLDecodeError``)
    propagates to the caller.
    """
    if path is None:
        path = Path('secrets.toml')
    secrets: dict[str, Secret] = {}
    try:
        f = path.open('rb')
    except OSError as e:
        log.error('Failed to load secrets: %s', e)
        return secrets
    with f:
        config = await asyncio.to_thread(tomllib.load, f)
    for key, value in config.items():
        # A bare value such as ``name = "x"`` would otherwise reach the
        # ``in`` tests below as a substring check (str) or raise TypeError
        # (int, bool, ...), aborting the whole load.
        if not isinstance(value, dict):
            log.warning(
                'Unsupported configuration for secret %s: %r', key, value
            )
        elif 'env' in value:
            secrets[key] = EnvSecret(value['env'])
        elif 'exec' in value:
            secrets[key] = ExecSecret(value['exec'])
        elif 'path' in value:
            secrets[key] = PathSecret(Path(value['path']))
        elif 'string' in value:
            secrets[key] = StringSecret(value['string'])
        else:
            log.warning(
                'Unsupported configuration for secret %s: %r', key, value
            )
    return secrets
def shutdown(signum, server):
    """Signal callback: record which signal arrived, then close *server*.

    Closing the server ends ``server.wait_closed()`` in ``main``.
    """
    log.info('Received signal %d, shutting down', signum)
    server.close()
async def main():
    """Serve secrets on a Unix socket until SIGINT or SIGTERM arrives.

    The socket path is chosen in this order:
    1. ``SECRET_SOCKET_PATH``, if set;
    2. otherwise ``$XDG_RUNTIME_DIR/secretsocket/.ss``;
    3. otherwise ``/tmp/.secretsocket``.

    A stale socket file at that path is removed before binding.
    """
    logging.basicConfig(level=logging.DEBUG)
    if SECRET_SOCKET_PATH:
        sock_path = Path(SECRET_SOCKET_PATH)
    elif XDG_RUNTIME_DIR:
        sock_path = Path(XDG_RUNTIME_DIR) / 'secretsocket/.ss'
    else:
        sock_path = Path('/tmp/.secretsocket')
    # exist_ok replaces the racy exists()-then-mkdir(). parents=True allows
    # a nested SECRET_SOCKET_PATH. Mode 0700 applies only to a newly
    # created final directory and keeps other users away from the socket.
    sock_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True)
    # missing_ok avoids a check-then-act race and also removes a dangling
    # symlink, which Path.exists() reports as absent.
    sock_path.unlink(missing_ok=True)
    ss = SecretServer()
    # start_unix_server begins serving immediately (start_serving=True).
    server = await asyncio.start_unix_server(ss.handle_client, path=sock_path)
    async with server:
        loop = asyncio.get_running_loop()
        for signum in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(signum, shutdown, signum, server)
        await server.wait_closed()


if __name__ == '__main__':
    asyncio.run(main())