# configpolicy/roles/protonvpn/files/protonvpn-watchdog.py
# (343 lines, 11 KiB, Python)

#!/usr/bin/env python3
import abc
import argparse
import asyncio.subprocess
import json
import logging.handlers
import os
import random
import shlex
import signal
import sys
import time
from http import HTTPStatus
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Set, Tuple
import httpx
try:
    # Optional dependency: setup_logging() uses it for colourised output on a
    # terminal and falls back to a plain StreamHandler when it is missing.
    import coloredlogs
except ImportError:
    coloredlogs = None
# Logger shared by every part of the watchdog.
log = logging.getLogger('protonvpn_watchdog')
# Alias for a decoded JSON object (server list, journal entry).
Json = Dict[str, Any]
class BaseAsyncDaemon(abc.ABC):
    """Skeleton for an asyncio daemon that stops on SIGINT or SIGTERM.

    Subclasses implement run() (the daemon's work) and stop() (cleanup).
    main() waits until run() finishes or ``stop_evt`` is set, cancels
    whichever of the two is still pending, then awaits stop().  If a signal
    triggered the shutdown, that signal is re-sent to the process with its
    default disposition restored.  Otherwise main() raises SystemExit with
    status 0, or 1 when run() or stop() raised.
    """

    def __init__(self) -> None:
        # Both are assigned in main(), once an event loop is running.
        self.loop: asyncio.AbstractEventLoop
        self.stop_evt: asyncio.Event
        # Number of the signal that requested shutdown; 0 if none did.
        self.killed = 0

    def shutdown(self, signum: int) -> None:
        """Signal callback: record *signum* and request shutdown."""
        if log.isEnabledFor(logging.INFO):
            signame = signal.Signals(signum).name
            log.info('Received signal %d (%s), stopping', signum, signame)
        self.killed = signum
        self.stop_evt.set()

    async def main(self) -> None:
        """Run the daemon until completion or a shutdown request, then exit."""
        self.loop = asyncio.get_running_loop()
        self.stop_evt = asyncio.Event()
        for signum in (signal.SIGINT, signal.SIGTERM):
            self.loop.add_signal_handler(signum, self.shutdown, signum)
        worker = asyncio.create_task(self.run())
        stopper = asyncio.create_task(self.stop_evt.wait())
        finished, unfinished = await asyncio.wait(
            (worker, stopper), return_when=asyncio.FIRST_COMPLETED
        )
        exit_code = 0
        for fut in finished:
            try:
                await fut
            except Exception:  # pylint: disable=broad-except
                log.exception('Unhandled error:')
                exit_code = 1
        for fut in unfinished:
            fut.cancel()
        try:
            await self.stop()
        except Exception:  # pylint: disable=broad-except
            log.exception('Error stopping daemon:')
            exit_code = 1
        if not self.killed:
            raise SystemExit(exit_code)
        # Restore the default action and re-deliver the signal to ourselves,
        # so the process ends the way that signal would have ended it.
        signal.signal(self.killed, signal.SIG_DFL)
        os.kill(os.getpid(), self.killed)

    @abc.abstractmethod
    async def run(self) -> None:
        """Do the daemon's work; returning ends the daemon."""
        raise NotImplementedError

    @abc.abstractmethod
    async def stop(self) -> None:
        """Clean up; awaited once after run() has finished or been cancelled."""
        raise NotImplementedError
class AsyncDaemon(BaseAsyncDaemon):
    """ProtonVPN watchdog: switch to another server when the tunnel gives up.

    ``journalctl`` is followed for the strongSwan unit, filtered on the IKE SA
    name.  When a journal message starts with ``giving up``, the address
    currently written in ``CONFIG`` is marked bad.  The lowest-``Score``
    candidate from the ProtonVPN server list that is not marked bad is
    written in its place, and ``swanctl --load-conns`` is run.

    Each class-level setting below can be overridden by an environment
    variable.
    """

    # systemd unit whose journal is followed.
    STRONGSWAN_UNIT = os.environ.get(
        'PROTONVPN_STRONGSWAN_UNIT',
        'strongswan.service',
    )
    # Value matched against the journal's IKE_SA_NAME field.
    IKE_SA_NAME = os.environ.get('PROTONVPN_IKE_SA_NAME', 'protonvpn')
    # Local JSON cache of the server list.
    SERVER_LIST = os.environ.get(
        'PROTONVPN_SERVER_LIST', '/var/cache/protonvpn/serverlist.json'
    )
    # URL the server list is downloaded from when the cache is stale.
    SERVER_LIST_URL = os.environ.get(
        'PROTONVPN_SERVERLIST_URL', 'https://api.protonmail.ch/vpn/logicals'
    )
    # File whose first line is ``remote_addrs = <address>``; it is rewritten
    # on failover.
    CONFIG = os.environ.get(
        'PROTONVPN_CONFIG',
        '/etc/strongswan/swanctl/conf.d/protonvpn.remote_addrs',
    )
    # Maximum age (seconds) of a cached server list, and base refresh period.
    SERVER_LIST_MAX_AGE = 3600
    # Seconds to wait before restarting journalctl after an unexpected error,
    # so a persistent failure (e.g. journalctl missing) cannot busy-loop.
    JOURNAL_RETRY_DELAY = 5

    def __init__(self) -> None:
        super().__init__()
        # The journalctl child, once watch_journal() has started one.
        self.jrnl_proc: Optional[asyncio.subprocess.Process] = None
        # Addresses that already failed; reconfigure() skips them.
        self.bad_servers: Set[str] = set()
        # Latest server list, or None until manage_serverlist() loads one.
        # Initialised here so an early reconfigure() cannot hit AttributeError.
        self.serverlist: Optional[Json] = None
        # Strong references to background tasks.  The event loop keeps only
        # weak references, so otherwise-unreferenced tasks may be collected.
        self._tasks: Set['asyncio.Future[Any]'] = set()

    def _spawn(self, coro: Any) -> 'asyncio.Future[Any]':
        """Schedule *coro* as a task and keep a reference to it until it ends."""
        task = asyncio.ensure_future(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    async def get_serverlist(self) -> Optional[Json]:
        """Return the ProtonVPN server list, or ``None`` if it is unavailable.

        A cached copy at ``SERVER_LIST`` younger than ``SERVER_LIST_MAX_AGE``
        seconds is used if it parses as JSON.  Otherwise the list is downloaded
        from ``SERVER_LIST_URL`` and written back to the cache.  Transport
        errors, non-200 responses and unparseable bodies are logged and yield
        ``None``.  If writing the cache fails, the error is logged and the
        downloaded data is still returned.
        """
        try:
            f = open(self.SERVER_LIST, encoding='utf-8')
        except FileNotFoundError:
            pass
        else:
            with f:
                st = os.fstat(f.fileno())
                if st.st_mtime > time.time() - self.SERVER_LIST_MAX_AGE:
                    log.info('Using cached server list from %s', f.name)
                    try:
                        return json.load(f)
                    except ValueError as e:
                        log.warning('Failed to parse server list: %s', e)
        log.info('Fetching server list from %s', self.SERVER_LIST_URL)
        try:
            async with httpx.AsyncClient() as client:
                r: httpx.Response = await client.get(self.SERVER_LIST_URL)
        except httpx.HTTPError as e:
            log.error('Failed to fetch server list: %s', e)
            return None
        if r.status_code != HTTPStatus.OK:
            log.error(
                'Failed to fetch server list: HTTP status %s',
                r.status_code,
            )
            return None
        try:
            data = r.json()
        except ValueError as e:
            log.error('Failed to parse server list: %s', e)
            return None
        self._write_cache(data)
        return data

    def _write_cache(self, data: Json) -> None:
        """Write *data* to ``SERVER_LIST`` atomically; log, don't raise, on error."""
        path = Path(self.SERVER_LIST)
        tmp = path.with_name(path.name + '.tmp')
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            with tmp.open('w', encoding='utf-8') as f:
                log.debug('Writing server list to %s', path)
                json.dump(data, f)
            # Rename over the old file so readers never see a partial write.
            os.replace(tmp, path)
        except OSError as e:
            log.warning('Failed to write server list cache %s: %s', path, e)

    def iter_servers(self) -> Iterable[Tuple[int, str, str, str, str]]:
        """Yield ``(Score, EntryIP, Name, EntryCountry, City)`` for candidates.

        Candidates are servers with ``Status`` 1 that belong to logical
        servers with Tier 1, EntryCountry ``US`` and no Features set.
        ``self.serverlist`` must already be loaded.
        """
        assert self.serverlist
        for item in self.serverlist['LogicalServers']:
            # Tier 0: Free
            # Tier 1: Basic
            # Tier 2: Plus
            if item['Tier'] != 1:
                continue
            if item['EntryCountry'] != 'US':
                continue
            # Features 2: Tor
            # Features 4: P2P
            if item['Features']:
                continue
            for server in item['Servers']:
                if server['Status'] != 1:
                    continue
                yield (
                    item['Score'],
                    server['EntryIP'],
                    item['Name'],
                    item['EntryCountry'],
                    item['City'],
                )

    async def manage_serverlist(self) -> None:
        """Keep ``self.serverlist`` fresh, forever.

        After a successful load, the list is reloaded about every
        ``SERVER_LIST_MAX_AGE`` seconds, plus up to 900 s of jitter.  A failed
        load keeps the previous list and is retried after 60-120 seconds.
        """
        while True:
            try:
                serverlist = await self.get_serverlist()
            except asyncio.CancelledError:
                raise
            except Exception:  # pylint: disable=broad-except
                log.exception('Error loading server list:')
                serverlist = None
            if serverlist is not None:
                self.serverlist = serverlist
                delay = self.SERVER_LIST_MAX_AGE + random.randint(0, 900)
            else:
                delay = 60 + random.randint(0, 60)
            await asyncio.sleep(delay)

    def mark_bad(self, addr: str) -> None:
        """Exclude *addr* from future server selection."""
        log.warning('Marking server %s as bad', addr)
        self.bad_servers.add(addr)

    async def on_message(self, data: Json) -> None:
        """Handle one journal entry decoded from journalctl's JSON output.

        Only the ``MESSAGE`` field is examined.  A string value starting with
        ``giving up`` triggers reconfigure().  Non-string values, which the
        journal may emit, are logged at debug level and ignored.
        """
        log.debug('Got message: %s', data)
        message = data.get('MESSAGE')
        if not isinstance(message, str):
            log.debug('Invalid message: %r', message)
            return
        if message.startswith('giving up'):
            log.error('VPN tunnel failed!')
            await self.reconfigure()

    async def run(self) -> None:
        """Start the server-list refresher, then watch the journal.

        Unexpected errors are logged, and journalctl is restarted after
        ``JOURNAL_RETRY_DELAY`` seconds.  If journalctl exits by itself and
        no shutdown signal was received, RuntimeError is raised so the daemon
        exits with a non-zero status.
        """
        log.info('Starting ProtonVPN Watchdog')
        self._spawn(self.manage_serverlist())
        while True:
            try:
                await self.watch_journal()
            except asyncio.CancelledError:
                raise
            except Exception:  # pylint: disable=broad-except
                log.exception('Unhandled exception:')
                await asyncio.sleep(self.JOURNAL_RETRY_DELAY)
            else:
                # watch_journal() has already set stop_evt; do not respawn.
                if self.killed:
                    return
                raise RuntimeError('journalctl exited unexpectedly')

    async def reconfigure(self) -> None:
        """Replace the server in ``CONFIG`` with the best untried one.

        The address currently in ``CONFIG`` (if its first line is a
        ``remote_addrs`` assignment) is marked bad.  The file is rewritten
        only once a replacement has been chosen, then strongSwan's
        connections are reloaded.
        """
        if not self.serverlist:
            log.error('Cannot reconfigure: no known servers!')
            return
        fd = os.open(self.CONFIG, os.O_CREAT | os.O_RDWR, 0o644)
        with open(fd, 'r+', encoding='utf-8') as f:
            line = f.readline()
            if line:
                key, __, value = line.partition('=')
                if not value or key.strip() != 'remote_addrs':
                    log.warning('Unexpected line in config: %s', line.rstrip())
                else:
                    self.mark_bad(value.strip())
            choice = self._pick_server()
            if choice is None:
                # Leave the existing config untouched rather than blank it.
                log.error('Cannot reconfigure: no candidate servers!')
                return
            addr, name, country, city = choice
            log.info(
                'Selected new server: %s (%s) %s/%s',
                addr,
                name,
                country,
                city,
            )
            f.seek(0)
            f.truncate(0)
            f.write(f'remote_addrs = {addr}\n')
        log.info('Reloading strongSwan connection configuration')
        await self.swanctl('--load-conns')

    def _pick_server(self) -> Optional[Tuple[str, str, str, str]]:
        """Return ``(addr, name, country, city)`` of the best usable server.

        Candidates are taken in ascending sort order of iter_servers() tuples,
        i.e. lowest Score first, and addresses in ``bad_servers`` are skipped.
        If every candidate is bad, ``bad_servers`` is cleared and the best
        candidate overall is returned.  Returns ``None`` only when there are
        no candidates at all.
        """
        candidates = sorted(self.iter_servers())
        if not candidates:
            return None
        for _score, addr, name, country, city in candidates:
            if addr not in self.bad_servers:
                return addr, name, country, city
        log.warning('All servers are marked bad; starting over')
        self.bad_servers.clear()
        _score, addr, name, country, city = candidates[0]
        return addr, name, country, city

    async def stop(self) -> None:
        """Terminate journalctl if it is still running, and wait for it."""
        proc = self.jrnl_proc
        # returncode stays None for as long as the child is running.
        if proc is None or proc.returncode is not None:
            return
        try:
            proc.terminate()
        except ProcessLookupError:
            # It exited between the returncode check and terminate().
            return
        await proc.wait()

    async def swanctl(self, *args: str) -> None:
        """Run ``swanctl`` with *args*, logging its stderr and a non-zero exit."""
        cmd = ['swanctl', *args]
        if log.isEnabledFor(logging.DEBUG):
            log.debug('Running command: %s', list2cmdline(cmd))
        p = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.PIPE,
        )
        assert p.stderr
        async for line in p.stderr:
            line = line.rstrip()
            if line:
                log.error(line.decode(errors='replace'))
        rc = await p.wait()
        if rc != 0:
            log.error('%s exited with status %d', list2cmdline(cmd), rc)

    async def watch_journal(self) -> None:
        """Follow strongSwan's journal and pass each entry to on_message().

        Runs ``journalctl --output=json --follow --lines=0`` for
        ``STRONGSWAN_UNIT`` entries matching ``IKE_SA_NAME``.  Returns after
        journalctl exits and both of its output streams have been consumed,
        and sets ``stop_evt``.
        """
        cmd = [
            'journalctl',
            '--output=json',
            '--follow',
            '--lines=0',
            f'--unit={self.STRONGSWAN_UNIT}',
            f'IKE_SA_NAME={self.IKE_SA_NAME}',
        ]
        if log.isEnabledFor(logging.DEBUG):
            log.debug('Running command: %s', list2cmdline(cmd))
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        self.jrnl_proc = proc
        stdout = proc.stdout
        stderr = proc.stderr
        assert stdout
        assert stderr

        async def handle_stdout() -> None:
            async for line in stdout:
                try:
                    data = json.loads(line)
                except ValueError as e:
                    log.error('Failed to parse journal entry: %s', e)
                    continue
                try:
                    await self.on_message(data)
                except Exception:  # pylint: disable=broad-except
                    log.exception('Error handling message:')

        async def handle_stderr() -> None:
            async for line in stderr:
                log.error(line.rstrip().decode(errors='replace'))

        readers = [self._spawn(handle_stdout()), self._spawn(handle_stderr())]
        rc = await proc.wait()
        log.warning('journalctl exited with status %d', rc)
        # The pipes should reach EOF once journalctl has exited; let the
        # readers finish processing anything still buffered.
        await asyncio.gather(*readers)
        self.stop_evt.set()
def list2cmdline(args: Iterable[str]) -> str:
    """Join *args* into one shell-quoted command line (used in log messages)."""
    return ' '.join(map(shlex.quote, args))
def parse_args(argv: Optional[Iterable[str]] = None) -> argparse.Namespace:
    """Parse the command line.

    :param argv: arguments to parse.  ``None`` (the default) lets argparse
        read ``sys.argv[1:]``, so existing callers behave as before.
    :returns: a namespace whose ``verbose`` attribute counts ``-v`` flags.
    """
    parser = argparse.ArgumentParser(description='ProtonVPN strongSwan watchdog')
    parser.add_argument(
        '-v',
        '--verbose',
        action='count',
        default=0,
        help='increase log verbosity (-v: INFO, -vv: DEBUG)',
    )
    return parser.parse_args(argv)
def setup_logging(verbose: int) -> None:
    """Configure root logging for a ``-v`` count.

    A count below 1 selects WARNING, 1 selects INFO, and 2 or more selects
    DEBUG.  When stderr is a terminal, coloredlogs is used if it was
    importable, otherwise a plain StreamHandler.  When stderr is not a
    terminal, records are sent to the syslog socket ``/dev/log``.
    """
    level = (
        logging.WARNING if verbose < 1
        else logging.INFO if verbose < 2
        else logging.DEBUG
    )
    handler: logging.Handler
    if not sys.stderr.isatty():
        handler = logging.handlers.SysLogHandler(address='/dev/log')
    elif coloredlogs is None:
        handler = logging.StreamHandler()
    else:
        coloredlogs.install(level=level)
        return
    # The timestamp layout appears to mimic the classic syslog
    # "Mmm dd hh:mm:ss tag:" header; %e (space-padded day) is not portable
    # to every strftime.
    handler.setFormatter(
        logging.Formatter(
            '%(asctime)s %(name)s: %(levelname)s %(message)s',
            '%b %e %H:%M:%S',
        )
    )
    handler.setLevel(level)
    logging.root.setLevel(level)
    logging.root.addHandler(handler)
def main() -> None:
    """Entry point: parse options, set up logging and run the watchdog."""
    verbosity = parse_args().verbose
    setup_logging(verbosity)
    asyncio.run(AsyncDaemon().main())
# Start the daemon only when executed as a script, not when imported.
if __name__ == '__main__':
    main()