diff --git a/roles/protonvpn/files/protonvpn-watchdog.py b/roles/protonvpn/files/protonvpn-watchdog.py new file mode 100644 index 0000000..fec97ad --- /dev/null +++ b/roles/protonvpn/files/protonvpn-watchdog.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +import abc +import argparse +import asyncio.subprocess +import json +import logging.handlers +import os +import random +import shlex +import signal +import sys +import time +from http import HTTPStatus +from pathlib import Path +from typing import Any, Dict, Iterable, Optional, Set, Tuple + +import httpx + + +try: + import coloredlogs +except ImportError: + coloredlogs = None + + +log = logging.getLogger('protonvpn_watchdog') + + +Json = Dict[str, Any] + + +class BaseAsyncDaemon(abc.ABC): + def __init__(self) -> None: + self.loop: asyncio.AbstractEventLoop + self.stop_evt: asyncio.Event + self.killed = 0 + + def shutdown(self, signum: int) -> None: + if log.isEnabledFor(logging.INFO): + log.info( + 'Received signal %d (%s), stopping', + signum, + signal.Signals(signum).name, + ) + self.killed = signum + self.stop_evt.set() + + async def main(self) -> None: + self.loop = asyncio.get_running_loop() + self.stop_evt = asyncio.Event() + self.loop.add_signal_handler( + signal.SIGINT, self.shutdown, signal.SIGINT + ) + self.loop.add_signal_handler( + signal.SIGTERM, self.shutdown, signal.SIGTERM + ) + t_run = asyncio.create_task(self.run()) + t_stop = asyncio.create_task(self.stop_evt.wait()) + done, pending = await asyncio.wait( + (t_run, t_stop), return_when=asyncio.FIRST_COMPLETED + ) + rc = 0 + for task in done: + try: + await task + except Exception: + log.exception('Unhandled error:') + rc = 1 + for task in pending: + task.cancel() + try: + await self.stop() + except Exception: # pylint: disable=broad-except + log.exception('Error stopping daemon:') + rc = 1 + if self.killed: + signal.signal(self.killed, signal.SIG_DFL) + os.kill(os.getpid(), self.killed) + else: + raise SystemExit(rc) + + @abc.abstractmethod + async def run(self) -> None: + raise NotImplementedError + + @abc.abstractmethod + async def stop(self) -> None: + raise NotImplementedError + + +class AsyncDaemon(BaseAsyncDaemon): + STRONGSWAN_UNIT = os.environ.get( + 'PROTONVPN_STRONGSWAN_UNIT', + 'strongswan.service', + ) + IKE_SA_NAME = os.environ.get('PROTONVPN_IKE_SA_NAME', 'protonvpn') + SERVER_LIST = os.environ.get( + 'PROTONVPN_SERVER_LIST', '/var/cache/protonvpn/serverlist.json' + ) + SERVER_LIST_URL = os.environ.get( + 'PROTONVPN_SERVERLIST_URL', 'https://api.protonmail.ch/vpn/logicals' + ) + CONFIG = os.environ.get( + 'PROTONVPN_CONFIG', + '/etc/strongswan/swanctl/conf.d/protonvpn.remote_addrs', + ) + + def __init__(self) -> None: + super().__init__() + self.jrnl_proc: asyncio.subprocess.Process + self.bad_servers: Set[str] = set() + + async def get_serverlist(self) -> Optional[Json]: + try: + f = open(self.SERVER_LIST, encoding='utf-8') + except FileNotFoundError: + pass + else: + with f: + st = os.fstat(f.fileno()) + now = time.time() + if st.st_mtime > now - 3600: + log.info('Using cached server list from %s', f.name) + try: + return json.load(f) + except ValueError as e: + log.warning('Failed to parse server list: %s', e) + async with httpx.AsyncClient() as client: + log.info('Fetching server list from %s', self.SERVER_LIST_URL) + r: httpx.Response = await client.get(self.SERVER_LIST_URL) + if r.status_code != HTTPStatus.OK: + log.error( + 'Failed to fetch server list: HTTP status %s', + r.status_code, + ) + return None + try: + data = r.json() + except ValueError as e: + log.error('Failed to parse server list: %s', e) + return None + path = Path(self.SERVER_LIST) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open('w', encoding='utf-8') as f: + log.debug('Writing server list to %s', f.name) + json.dump(data, f) + return data + + def iter_servers(self) -> Iterable[Tuple[int, str, str, str, str]]: + assert self.serverlist + for item in self.serverlist['LogicalServers']: + # Tier 0: Free + # Tier 1: Basic + # Tier 2: Plus + if item['Tier'] != 1: + continue + if item['EntryCountry'] != 'US': + continue + # Features 2: Tor + # Features 4: P2P + if item['Features']: + continue + for server in item['Servers']: + if server['Status'] != 1: + continue + yield ( + item['Score'], + server['EntryIP'], + item['Name'], + item['EntryCountry'], + item['City'], + ) + + async def manage_serverlist(self) -> None: + while 1: + self.serverlist = await self.get_serverlist() + await asyncio.sleep(3600 + random.randint(0, 900)) + + def mark_bad(self, addr: str) -> None: + log.warning('Marking server %s as bad', addr) + self.bad_servers.add(addr) + + async def on_message(self, data: Dict[str, str]) -> None: + log.debug('Got message: %s', data) + message = data.get('MESSAGE') + if not isinstance(message, str): + log.debug('Invalid message: %r', message) + return + if message.startswith('giving up'): + log.error('VPN tunnel failed!') + await self.reconfigure() + + async def run(self) -> None: + log.info('Starting ProtonVPN Watchdog') + asyncio.ensure_future(self.manage_serverlist()) + while 1: + try: + await self.watch_journal() + except Exception: + log.exception('Unhandled excption:') + + async def reconfigure(self) -> None: + if not self.serverlist: + log.error('Cannot reconfigure: no known servers!') + return + fd = os.open(self.CONFIG, os.O_CREAT | os.O_RDWR, 0o644) + with open(fd, 'r+', encoding='utf-8') as f: + line = f.readline() + if line: + key, __, value = line.partition('=') + if not value or key.strip() != 'remote_addrs': + log.warning('Unexpected line in config: %s', line) + else: + self.mark_bad(value.strip()) + f.seek(0) + f.truncate(0) + for _k, addr, name, country, city in sorted(self.iter_servers()): + if addr in self.bad_servers: + continue + log.info( + 'Selected new server: %s (%s) %s/%s', + addr, + name, + country, + city, + ) + f.write(f'remote_addrs = {addr}\n') + break + log.info('Reloading strongSwan connection configuration') + await self.swanctl('--load-conns') + + async def stop(self) -> None: + if self.jrnl_proc.returncode is not None: + self.jrnl_proc.terminate() + + async def swanctl(self, *args: str) -> None: + cmd = ['swanctl'] + cmd += args + if log.isEnabledFor(logging.DEBUG): + log.debug('Running command: %s', list2cmdline(cmd)) + p = await asyncio.create_subprocess_exec( + *cmd, + stdin=asyncio.subprocess.DEVNULL, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.PIPE, + ) + assert p.stderr + async for line in p.stderr: + line = line.rstrip() + if line: + log.error(line.decode(errors='replace')) + await p.wait() + + async def watch_journal(self) -> None: + cmd = [ + 'journalctl', + '--output=json', + '--follow', + '--lines=0', + f'--unit={self.STRONGSWAN_UNIT}', + f'IKE_SA_NAME={self.IKE_SA_NAME}', + ] + if log.isEnabledFor(logging.DEBUG): + log.debug('Running command: %s', list2cmdline(cmd)) + self.jrnl_proc = await asyncio.create_subprocess_exec( + *cmd, + stdin=asyncio.subprocess.DEVNULL, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + assert self.jrnl_proc.stdout + assert self.jrnl_proc.stderr + + async def handle_stdout(): + async for line in self.jrnl_proc.stdout: + try: + data = json.loads(line) + except ValueError as e: + log.error('Failed to parse journal entry: %s', e) + continue + try: + await self.on_message(data) + except Exception: + log.exception('Error handling message:') + continue + + async def handle_stderr(): + async for line in self.jrnl_proc.stderr: + log.error(line.rstrip().decode(errors='replace')) + + asyncio.ensure_future(handle_stdout()) + asyncio.ensure_future(handle_stderr()) + await self.jrnl_proc.wait() + self.stop_evt.set() + + +def list2cmdline(args): + return ' '.join(shlex.quote(a) for a in args) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument('-v', '--verbose', action='count', default=0) + return parser.parse_args() + + +def setup_logging(verbose: int): + if verbose < 1: + level = logging.WARNING + elif verbose < 2: + level = logging.INFO + else: + level = logging.DEBUG + handler: logging.Handler + if sys.stderr.isatty(): + if coloredlogs is not None: + coloredlogs.install(level=level) + return + handler = logging.StreamHandler() + else: + handler = logging.handlers.SysLogHandler(address='/dev/log') + handler.setFormatter( + logging.Formatter( + '%(asctime)s %(name)s: %(levelname)s %(message)s', + '%b %e %H:%M:%S', + ) + ) + + handler.setLevel(level) + logging.root.setLevel(level) + logging.root.addHandler(handler) + + +def main() -> None: + args = parse_args() + setup_logging(args.verbose) + daemon = AsyncDaemon() + asyncio.run(daemon.main()) + + +if __name__ == '__main__': + main() diff --git a/roles/protonvpn/files/protonvpn-watchdog.service b/roles/protonvpn/files/protonvpn-watchdog.service new file mode 100644 index 0000000..299662e --- /dev/null +++ b/roles/protonvpn/files/protonvpn-watchdog.service @@ -0,0 +1,8 @@ +[Unit] +Description=ProtonVPN Watchdog + +[Service] +ExecStart=/usr/local/bin/protonvpn-watchdog -v + +[Install] +WantedBy=multi-user.target diff --git a/roles/protonvpn/handlers/main.yml b/roles/protonvpn/handlers/main.yml index 3e98587..2bae44c 100644 --- a/roles/protonvpn/handlers/main.yml +++ b/roles/protonvpn/handlers/main.yml @@ -1,3 +1,10 @@ +- name: reload systemd + command: + systemctl daemon-reload - name: reload strongswan config command: swanctl --load-all +- name: restart protonvpn-watchdog + service: + name: protonvpn-watchdog + state: restarted diff --git a/roles/protonvpn/tasks/main.yml b/roles/protonvpn/tasks/main.yml index dd9ee60..d663f36 100644 --- a/roles/protonvpn/tasks/main.yml +++ b/roles/protonvpn/tasks/main.yml @@ -16,3 +16,46 @@ tags: - strongswan-config - protonvpn-config +- name: ensure protonvpn remote address is configured + copy: + dest: /etc/strongswan/swanctl/conf.d/protonvpn.remote_addrs + mode: '0640' + content: > + remote_addrs = {{ protonvpn_server }} + force: false + notify: reload strongswan config + tags: + - strongswan-config + - protonvpn-config + +- name: ensure protonvpn-watchdog script is installed + copy: + src: protonvpn-watchdog.py + dest: /usr/local/bin/protonvpn-watchdog + mode: '0755' + notify: restart protonvpn-watchdog + tags: + - protonvpn-watchdog +- name: ensure protonvpn-watchdog systemd unit is installed + copy: + src: protonvpn-watchdog.service + dest: /etc/systemd/system/protonvpn-watchdog.service + mode: '0644' + notify: + - reload systemd + - restart protonvpn-watchdog + tags: + - protonvpn-watchdog + - systemd +- name: ensure protonvpn-watchdog service is enabled + service: + name: protonvpn-watchdog + enabled: true + tags: + - service +- name: ensure protonvpn-watchdog service is running + service: + name: protonvpn-watchdog + state: started + tags: + - service diff --git a/roles/protonvpn/templates/protonvpn.conf.j2 b/roles/protonvpn/templates/protonvpn.conf.j2 index 8704770..880cced 100644 --- a/roles/protonvpn/templates/protonvpn.conf.j2 +++ b/roles/protonvpn/templates/protonvpn.conf.j2 @@ -1,10 +1,10 @@ connections { protonvpn { - local_addrs = %defaultroute - remote_addrs = {{ protonvpn_server }} + local_addrs = %any + include protonvpn.remote_addrs vips = 0.0.0.0,:: keyingtries = 0 - dpd_delay = 30s + dpd_delay = 10s local { auth = eap-mschapv2 eap_id = {{ protonvpn_username }}