#!/usr/bin/env python3
"""ProtonVPN watchdog for a strongSwan (swanctl) IKE connection.

Follows the journal of the strongSwan systemd unit, filtered to one IKE SA
name.  When an entry's MESSAGE starts with ``giving up``, the server
currently named in the ``remote_addrs`` config file is marked bad, the
remaining ProtonVPN server with the lowest ``Score`` is written in its place
and ``swanctl --load-conns`` is run.
"""
import abc
import argparse
import asyncio.subprocess
import json
import logging.handlers
import os
import random
import shlex
import signal
import sys
import time
from http import HTTPStatus
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Set, Tuple

import httpx

try:
    import coloredlogs
except ImportError:
    coloredlogs = None

log = logging.getLogger('protonvpn_watchdog')

Json = Dict[str, Any]


class BaseAsyncDaemon(abc.ABC):
    """Skeleton for a signal-aware asyncio daemon.

    Subclasses implement :meth:`run` (the main work) and :meth:`stop`
    (cleanup).  :meth:`main` runs until :meth:`run` returns or SIGINT/SIGTERM
    arrives, then calls :meth:`stop` and exits.
    """

    def __init__(self) -> None:
        self.loop: asyncio.AbstractEventLoop
        self.stop_evt: asyncio.Event
        # Number of the signal that requested shutdown, or 0 if none did.
        self.killed = 0

    def shutdown(self, signum: int) -> None:
        """Signal handler: remember *signum* and wake up :meth:`main`."""
        if log.isEnabledFor(logging.INFO):
            log.info(
                'Received signal %d (%s), stopping',
                signum,
                signal.Signals(signum).name,
            )
        self.killed = signum
        self.stop_evt.set()

    async def main(self) -> None:
        """Run the daemon until run() finishes or a signal arrives.

        Exits with ``SystemExit`` (status 1 if run() or stop() raised, else
        0).  When stopped by a signal, it instead restores that signal's
        default disposition and sends it to itself, so the process
        terminates by that signal.
        """
        self.loop = asyncio.get_running_loop()
        self.stop_evt = asyncio.Event()
        self.loop.add_signal_handler(
            signal.SIGINT, self.shutdown, signal.SIGINT
        )
        self.loop.add_signal_handler(
            signal.SIGTERM, self.shutdown, signal.SIGTERM
        )
        t_run = asyncio.create_task(self.run())
        t_stop = asyncio.create_task(self.stop_evt.wait())
        done, pending = await asyncio.wait(
            (t_run, t_stop), return_when=asyncio.FIRST_COMPLETED
        )
        rc = 0
        for task in done:
            try:
                await task
            except Exception:  # pylint: disable=broad-except
                log.exception('Unhandled error:')
                rc = 1
        for task in pending:
            task.cancel()
        # Let cancelled tasks unwind (run their finally blocks) before the
        # subclass cleans up shared resources in stop().
        await asyncio.gather(*pending, return_exceptions=True)
        try:
            await self.stop()
        except Exception:  # pylint: disable=broad-except
            log.exception('Error stopping daemon:')
            rc = 1
        if self.killed:
            signal.signal(self.killed, signal.SIG_DFL)
            os.kill(os.getpid(), self.killed)
        else:
            raise SystemExit(rc)

    @abc.abstractmethod
    async def run(self) -> None:
        """Do the daemon's work; returning ends the daemon."""
        raise NotImplementedError

    @abc.abstractmethod
    async def stop(self) -> None:
        """Release resources after run() has finished or been cancelled."""
        raise NotImplementedError


class AsyncDaemon(BaseAsyncDaemon):
    """ProtonVPN watchdog: swap the strongSwan peer when the tunnel fails.

    Settings are read from environment variables when the class is defined.
    """

    # systemd unit whose journal is followed.
    STRONGSWAN_UNIT = os.environ.get(
        'PROTONVPN_STRONGSWAN_UNIT',
        'strongswan.service',
    )
    # Value of the IKE_SA_NAME journal field used to filter entries.
    IKE_SA_NAME = os.environ.get('PROTONVPN_IKE_SA_NAME', 'protonvpn')
    # Local JSON cache of the server list.
    SERVER_LIST = os.environ.get(
        'PROTONVPN_SERVER_LIST', '/var/cache/protonvpn/serverlist.json'
    )
    # URL the server list is downloaded from when the cache is stale.
    SERVER_LIST_URL = os.environ.get(
        'PROTONVPN_SERVERLIST_URL', 'https://api.protonmail.ch/vpn/logicals'
    )
    # File holding the single ``remote_addrs = <addr>`` line.
    CONFIG = os.environ.get(
        'PROTONVPN_CONFIG',
        '/etc/strongswan/swanctl/conf.d/protonvpn.remote_addrs',
    )
    # Maximum age in seconds at which the cached server list is still used.
    SERVER_LIST_MAX_AGE = 3600
    # Seconds to wait before restarting a journal watcher that crashed.
    RESTART_DELAY = 5
    # Seconds to wait for journalctl to exit after SIGTERM before SIGKILL.
    STOP_TIMEOUT = 5
    # Maximum journal line length in bytes; asyncio's default is 64 KiB.
    JOURNAL_LINE_LIMIT = 2 ** 20

    def __init__(self) -> None:
        super().__init__()
        # journalctl process started by watch_journal(), if any.
        self.jrnl_proc: Optional[asyncio.subprocess.Process] = None
        # Last successfully loaded server list, or None before the first.
        self.serverlist: Optional[Json] = None
        self.bad_servers: Set[str] = set()
        # Strong references to background tasks so they are not
        # garbage-collected while running.
        self._tasks: Set[asyncio.Task] = set()

    def _spawn(self, coro: Any) -> asyncio.Task:
        """Start *coro* as a tracked background task."""
        task = asyncio.ensure_future(coro)
        self._tasks.add(task)
        task.add_done_callback(self._task_done)
        return task

    def _task_done(self, task: asyncio.Task) -> None:
        """Forget a finished background task and log it if it failed."""
        self._tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            log.error('Background task failed', exc_info=task.exception())

    async def get_serverlist(self) -> Optional[Json]:
        """Return the server list, or ``None`` if it cannot be obtained.

        A cache file younger than SERVER_LIST_MAX_AGE seconds is used as-is.
        Otherwise the list is downloaded from SERVER_LIST_URL and, on
        success, written back to the cache.
        """
        cached = self._read_cached_serverlist()
        if cached is not None:
            return cached
        data = await self._fetch_serverlist()
        if data is not None:
            self._write_cached_serverlist(data)
        return data

    def _read_cached_serverlist(self) -> Optional[Json]:
        """Return the cached list if the file is fresh and parses, else None."""
        try:
            f = open(self.SERVER_LIST, encoding='utf-8')
        except FileNotFoundError:
            return None
        except OSError as e:
            log.warning('Failed to open cached server list: %s', e)
            return None
        with f:
            st = os.fstat(f.fileno())
            if st.st_mtime <= time.time() - self.SERVER_LIST_MAX_AGE:
                return None
            log.info('Using cached server list from %s', f.name)
            try:
                return json.load(f)
            except ValueError as e:
                log.warning('Failed to parse server list: %s', e)
                return None

    async def _fetch_serverlist(self) -> Optional[Json]:
        """Download the server list; log and return None on any failure."""
        log.info('Fetching server list from %s', self.SERVER_LIST_URL)
        try:
            async with httpx.AsyncClient() as client:
                r: httpx.Response = await client.get(self.SERVER_LIST_URL)
        except httpx.HTTPError as e:
            log.error('Failed to fetch server list: %s', e)
            return None
        if r.status_code != HTTPStatus.OK:
            log.error(
                'Failed to fetch server list: HTTP status %s',
                r.status_code,
            )
            return None
        try:
            return r.json()
        except ValueError as e:
            log.error('Failed to parse server list: %s', e)
            return None

    def _write_cached_serverlist(self, data: Json) -> None:
        """Replace the cache file with *data* via a temp file + rename.

        The rename means readers never see a half-written file.  Failures
        are logged and ignored, because the caller already holds the data.
        """
        path = Path(self.SERVER_LIST)
        tmp = path.with_name(path.name + '.tmp')
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            with tmp.open('w', encoding='utf-8') as f:
                log.debug('Writing server list to %s', path)
                json.dump(data, f)
            os.replace(tmp, path)
        except OSError as e:
            log.warning('Failed to cache server list in %s: %s', path, e)

    def iter_servers(self) -> Iterable[Tuple[int, str, str, str, str]]:
        """Yield ``(Score, EntryIP, Name, EntryCountry, City)`` candidates.

        Only logical servers with Tier 1, EntryCountry ``US`` and no
        Features are considered, and only physical servers with Status 1.
        """
        assert self.serverlist
        for item in self.serverlist['LogicalServers']:
            # Tier 0: Free
            # Tier 1: Basic
            # Tier 2: Plus
            if item['Tier'] != 1:
                continue
            if item['EntryCountry'] != 'US':
                continue
            # Features 2: Tor
            # Features 4: P2P
            if item['Features']:
                continue
            for server in item['Servers']:
                if server['Status'] != 1:
                    continue
                yield (
                    item['Score'],
                    server['EntryIP'],
                    item['Name'],
                    item['EntryCountry'],
                    item['City'],
                )

    async def manage_serverlist(self) -> None:
        """Keep ``self.serverlist`` current, refreshing about once an hour.

        A failed refresh keeps the previously loaded list and is retried
        after one to two minutes instead of an hour.
        """
        while True:
            try:
                serverlist = await self.get_serverlist()
            except asyncio.CancelledError:
                raise
            except Exception:  # pylint: disable=broad-except
                log.exception('Error refreshing server list:')
                serverlist = None
            if serverlist is not None:
                self.serverlist = serverlist
                delay = 3600 + random.randint(0, 900)
            else:
                delay = 60 + random.randint(0, 60)
            await asyncio.sleep(delay)

    def mark_bad(self, addr: str) -> None:
        """Exclude *addr* from future server selection."""
        log.warning('Marking server %s as bad', addr)
        self.bad_servers.add(addr)

    async def on_message(self, data: Dict[str, str]) -> None:
        """Handle one parsed journal entry; reconfigure on ``giving up``."""
        log.debug('Got message: %s', data)
        message = data.get('MESSAGE')
        if not isinstance(message, str):
            log.debug('Invalid message: %r', message)
            return
        if message.startswith('giving up'):
            log.error('VPN tunnel failed!')
            await self.reconfigure()

    async def run(self) -> None:
        """Start the server-list refresher and watch the journal.

        A crashed watcher is restarted after RESTART_DELAY seconds.  The loop
        ends once ``stop_evt`` is set, which watch_journal() does when
        journalctl exits.
        """
        log.info('Starting ProtonVPN Watchdog')
        self._spawn(self.manage_serverlist())
        while not self.stop_evt.is_set():
            try:
                await self.watch_journal()
            except asyncio.CancelledError:
                raise
            except Exception:  # pylint: disable=broad-except
                log.exception('Unhandled exception:')
                await asyncio.sleep(self.RESTART_DELAY)

    @staticmethod
    def _parse_remote_addr(line: str) -> Optional[str]:
        """Return the address in a ``remote_addrs = <addr>`` line, or None."""
        if not line:
            return None
        key, __, value = line.partition('=')
        value = value.strip()
        if not value or key.strip() != 'remote_addrs':
            log.warning('Unexpected line in config: %s', line)
            return None
        return value

    def _select_server(
        self, failed: Optional[str]
    ) -> Optional[Tuple[str, str, str, str]]:
        """Pick the lowest-Score candidate not marked bad.

        If every candidate is marked bad, the bad set is first reduced to
        just *failed*, then emptied, so a server is always chosen when any
        candidate exists.  Returns ``(addr, name, country, city)``, or None
        if iter_servers() yields nothing.
        """
        ranked = sorted(self.iter_servers())
        fallbacks = (self.bad_servers, {failed} - {None}, set())
        for excluded in fallbacks:
            for _score, addr, name, country, city in ranked:
                if addr in excluded:
                    continue
                if excluded is not self.bad_servers:
                    log.warning(
                        'All servers were marked bad; resetting bad list'
                    )
                    self.bad_servers = set(excluded)
                return addr, name, country, city
        return None

    async def reconfigure(self) -> None:
        """Replace the failed server in CONFIG and reload strongSwan.

        The server currently in CONFIG is marked bad.  The file is only
        rewritten once a replacement has been chosen, so it is never left
        empty.
        """
        if not self.serverlist:
            log.error('Cannot reconfigure: no known servers!')
            return
        fd = os.open(self.CONFIG, os.O_CREAT | os.O_RDWR, 0o644)
        with open(fd, 'r+', encoding='utf-8') as f:
            current = self._parse_remote_addr(f.readline())
            if current is not None:
                self.mark_bad(current)
            choice = self._select_server(current)
            if choice is None:
                log.error('Cannot reconfigure: no usable servers!')
                return
            addr, name, country, city = choice
            log.info(
                'Selected new server: %s (%s) %s/%s',
                addr,
                name,
                country,
                city,
            )
            f.seek(0)
            f.truncate(0)
            f.write(f'remote_addrs = {addr}\n')
        log.info('Reloading strongSwan connection configuration')
        await self.swanctl('--load-conns')

    async def stop(self) -> None:
        """Cancel background tasks and terminate journalctl if it still runs.

        journalctl gets SIGTERM, then SIGKILL if it has not exited within
        STOP_TIMEOUT seconds.
        """
        for task in list(self._tasks):
            task.cancel()
        proc = self.jrnl_proc
        if proc is None or proc.returncode is not None:
            return
        try:
            proc.terminate()
        except ProcessLookupError:
            return
        try:
            await asyncio.wait_for(proc.wait(), self.STOP_TIMEOUT)
        except asyncio.TimeoutError:
            log.warning('journalctl did not exit, killing it')
            try:
                proc.kill()
            except ProcessLookupError:
                return
            await proc.wait()

    async def swanctl(self, *args: str) -> None:
        """Run ``swanctl`` with *args*, logging its stderr and a bad exit."""
        cmd = ['swanctl', *args]
        if log.isEnabledFor(logging.DEBUG):
            log.debug('Running command: %s', list2cmdline(cmd))
        p = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.PIPE,
        )
        assert p.stderr
        async for line in p.stderr:
            line = line.rstrip()
            if line:
                log.error(line.decode(errors='replace'))
        rc = await p.wait()
        if rc != 0:
            log.error('swanctl exited with status %d', rc)

    async def watch_journal(self) -> None:
        """Follow the strongSwan journal, passing entries to on_message().

        When journalctl exits on its own, this sets ``stop_evt`` (shutting
        the daemon down) and returns.  If output handling fails or this
        coroutine is cancelled, journalctl is terminated.
        """
        cmd = [
            'journalctl',
            '--output=json',
            '--follow',
            '--lines=0',
            f'--unit={self.STRONGSWAN_UNIT}',
            f'IKE_SA_NAME={self.IKE_SA_NAME}',
        ]
        if log.isEnabledFor(logging.DEBUG):
            log.debug('Running command: %s', list2cmdline(cmd))
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            limit=self.JOURNAL_LINE_LIMIT,
        )
        self.jrnl_proc = proc
        assert proc.stdout
        assert proc.stderr
        stdout, stderr = proc.stdout, proc.stderr

        async def handle_stdout() -> None:
            async for line in stdout:
                try:
                    data = json.loads(line)
                except ValueError as e:
                    log.error('Failed to parse journal entry: %s', e)
                    continue
                try:
                    await self.on_message(data)
                except asyncio.CancelledError:
                    raise
                except Exception:  # pylint: disable=broad-except
                    log.exception('Error handling message:')

        async def handle_stderr() -> None:
            async for line in stderr:
                log.error(line.rstrip().decode(errors='replace'))

        readers = [
            asyncio.ensure_future(handle_stdout()),
            asyncio.ensure_future(handle_stderr()),
        ]
        try:
            # Both readers end at EOF, i.e. when journalctl closes its pipes.
            await asyncio.gather(*readers)
            rc = await proc.wait()
        finally:
            for task in readers:
                task.cancel()
            if proc.returncode is None:
                try:
                    proc.terminate()
                except ProcessLookupError:
                    pass
        log.warning('journalctl exited with status %d', rc)
        self.stop_evt.set()


def list2cmdline(args: Iterable[str]) -> str:
    """Return *args* as a shell-quoted command line (for logging only)."""
    return ' '.join(shlex.quote(a) for a in args)


def parse_args() -> argparse.Namespace:
    """Parse command-line options (``-v`` may be repeated)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='count', default=0)
    return parser.parse_args()


def setup_logging(verbose: int) -> None:
    """Configure root logging for the given ``-v`` count.

    0 -> WARNING, 1 -> INFO, 2+ -> DEBUG.  On a TTY, coloredlogs is used
    when installed, otherwise a plain stream handler.  Without a TTY, logs
    go to syslog via /dev/log.
    """
    if verbose < 1:
        level = logging.WARNING
    elif verbose < 2:
        level = logging.INFO
    else:
        level = logging.DEBUG
    handler: logging.Handler
    if sys.stderr.isatty():
        if coloredlogs is not None:
            coloredlogs.install(level=level)
            return
        handler = logging.StreamHandler()
    else:
        handler = logging.handlers.SysLogHandler(address='/dev/log')
    handler.setFormatter(
        logging.Formatter(
            '%(asctime)s %(name)s: %(levelname)s %(message)s',
            '%b %e %H:%M:%S',
        )
    )
    handler.setLevel(level)
    logging.root.setLevel(level)
    logging.root.addHandler(handler)


def main() -> None:
    """Entry point: parse options, set up logging and run the daemon."""
    args = parse_args()
    setup_logging(args.verbose)
    daemon = AsyncDaemon()
    asyncio.run(daemon.main())


if __name__ == '__main__':
    main()