Occasionally, ProtonVPN servers reject the EAP authentication credentials for no apparent reason. When this happens, the tunnel fails and strongSwan does not restart it automatically, so the watchdog needs to react to this event as well.
343 lines
11 KiB
Python
343 lines
11 KiB
Python
#!/usr/bin/env python3

import abc
import argparse
import asyncio.subprocess
import contextlib
import json
import logging.handlers
import os
import random
import shlex
import signal
import sys
import time
from http import HTTPStatus
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Set, Tuple

import httpx

# Logger shared by every part of the watchdog.
log = logging.getLogger('protonvpn_watchdog')

# Alias for a decoded JSON object such as the ProtonVPN server list.
Json = Dict[str, Any]

# coloredlogs is optional; setup_logging only uses it when it imported
# successfully and stderr is a terminal.
try:
    import coloredlogs
except ImportError:
    coloredlogs = None


class BaseAsyncDaemon(abc.ABC):
    """Skeleton for a signal-aware asyncio daemon.

    Subclasses implement :meth:`run` (the main work) and :meth:`stop`
    (cleanup).  :meth:`main` runs the daemon until :meth:`run` finishes or
    :attr:`stop_evt` is set (by SIGINT/SIGTERM, or by the subclass), then
    calls :meth:`stop` and exits the process.
    """

    # Seconds to wait for cancelled tasks to finish unwinding before
    # stop() is called.  Bounded so a task that ignores cancellation cannot
    # hang shutdown.
    CANCEL_TIMEOUT = 5.0

    def __init__(self) -> None:
        # Both are created in main(), once the event loop is running.
        self.loop: asyncio.AbstractEventLoop
        self.stop_evt: asyncio.Event
        # Number of the signal that requested shutdown, or 0 if none did.
        self.killed = 0

    def shutdown(self, signum: int) -> None:
        """Signal handler: remember *signum* and ask :meth:`main` to stop."""
        if log.isEnabledFor(logging.INFO):
            log.info(
                'Received signal %d (%s), stopping',
                signum,
                signal.Signals(signum).name,
            )
        self.killed = signum
        self.stop_evt.set()

    async def main(self) -> None:
        """Run the daemon until it finishes or is told to stop.

        Does not return normally.  It raises ``SystemExit`` with status 0, or
        1 if an error was logged.  If a signal stopped the daemon, it instead
        restores that signal's default disposition and re-sends it to this
        process, so the exit status reflects the signal.
        """
        self.loop = asyncio.get_running_loop()
        self.stop_evt = asyncio.Event()
        for signum in (signal.SIGINT, signal.SIGTERM):
            self.loop.add_signal_handler(signum, self.shutdown, signum)

        t_run = asyncio.create_task(self.run())
        t_stop = asyncio.create_task(self.stop_evt.wait())
        done, pending = await asyncio.wait(
            (t_run, t_stop), return_when=asyncio.FIRST_COMPLETED
        )
        rc = 0
        for task in done:
            try:
                await task
            except Exception:  # pylint: disable=broad-except
                log.exception('Unhandled error:')
                rc = 1
        for task in pending:
            task.cancel()
        if pending:
            # Let cancelled tasks run their except/finally blocks before
            # stop() tears down shared resources.
            finished, _ = await asyncio.wait(
                pending, timeout=self.CANCEL_TIMEOUT
            )
            for task in finished:
                # Retrieve (and report) errors raised while unwinding.
                if not task.cancelled() and task.exception() is not None:
                    log.error(
                        'Error while cancelling task:',
                        exc_info=task.exception(),
                    )
                    rc = 1
        try:
            await self.stop()
        except Exception:  # pylint: disable=broad-except
            log.exception('Error stopping daemon:')
            rc = 1
        if self.killed:
            signal.signal(self.killed, signal.SIG_DFL)
            os.kill(os.getpid(), self.killed)
        else:
            raise SystemExit(rc)

    @abc.abstractmethod
    async def run(self) -> None:
        """Main work of the daemon; runs until finished or cancelled."""
        raise NotImplementedError

    @abc.abstractmethod
    async def stop(self) -> None:
        """Release resources after :meth:`run` has ended or been cancelled."""
        raise NotImplementedError


class AsyncDaemon(BaseAsyncDaemon):
    """Watch strongSwan's journal and fail over to another ProtonVPN server.

    When a journal entry for the ``IKE_SA_NAME`` connection starts with
    ``giving up`` or contains ``EAP/FAIL``, the server currently named in
    :attr:`CONFIG` is marked bad.  The best-scoring server that is not bad
    is written in its place, and the connection configuration is reloaded
    with ``swanctl --load-conns``.
    """

    STRONGSWAN_UNIT = os.environ.get(
        'PROTONVPN_STRONGSWAN_UNIT',
        'strongswan.service',
    )
    IKE_SA_NAME = os.environ.get('PROTONVPN_IKE_SA_NAME', 'protonvpn')
    SERVER_LIST = os.environ.get(
        'PROTONVPN_SERVER_LIST', '/var/cache/protonvpn/serverlist.json'
    )
    SERVER_LIST_URL = os.environ.get(
        'PROTONVPN_SERVERLIST_URL', 'https://api.protonmail.ch/vpn/logicals'
    )
    CONFIG = os.environ.get(
        'PROTONVPN_CONFIG',
        '/etc/strongswan/swanctl/conf.d/protonvpn.remote_addrs',
    )

    # A cached server list younger than this many seconds is used without
    # contacting the API.
    SERVER_LIST_MAX_AGE = 3600
    # StreamReader buffer limit (bytes) for journalctl output.  It is raised
    # from asyncio's 64 KiB default; a longer line makes readline() raise
    # ValueError.
    JOURNAL_LINE_LIMIT = 1024 * 1024
    # Seconds to wait before restarting the journal watcher after an error.
    RETRY_DELAY = 5
    # Seconds to wait for journalctl to exit after SIGTERM before SIGKILL.
    STOP_TIMEOUT = 5

    def __init__(self) -> None:
        super().__init__()
        # The running `journalctl --follow` process, once started.
        self.jrnl_proc: Optional[asyncio.subprocess.Process] = None
        # Latest server list; None until the first successful load.
        self.serverlist: Optional[Json] = None
        # Entry IPs that failed; skipped when choosing a new server.
        self.bad_servers: Set[str] = set()

    # ---------------------------------------------------------------- servers

    async def get_serverlist(self) -> Optional[Json]:
        """Return the ProtonVPN server list, preferring a fresh local cache.

        A cache file at :attr:`SERVER_LIST` younger than
        :attr:`SERVER_LIST_MAX_AGE` seconds is used as-is.  Otherwise the list
        is downloaded from :attr:`SERVER_LIST_URL` and the cache refreshed.
        If the download fails, an older cached copy is returned when one
        could be parsed.  ``None`` means no list could be obtained.
        """
        stale: Optional[Json] = None
        try:
            f = open(self.SERVER_LIST, encoding='utf-8')
        except FileNotFoundError:
            pass
        else:
            with f:
                age = time.time() - os.fstat(f.fileno()).st_mtime
                try:
                    cached = json.load(f)
                except ValueError as e:
                    log.warning('Failed to parse server list: %s', e)
                else:
                    if age < self.SERVER_LIST_MAX_AGE:
                        log.info('Using cached server list from %s', f.name)
                        return cached
                    stale = cached
        data = await self._fetch_serverlist()
        if data is None and stale is not None:
            log.warning(
                'Falling back to stale cached server list from %s',
                self.SERVER_LIST,
            )
            return stale
        return data

    async def _fetch_serverlist(self) -> Optional[Json]:
        """Download the server list and cache it; None on any failure."""
        log.info('Fetching server list from %s', self.SERVER_LIST_URL)
        try:
            async with httpx.AsyncClient() as client:
                r: httpx.Response = await client.get(self.SERVER_LIST_URL)
        except httpx.HTTPError as e:
            log.error('Failed to fetch server list: %s', e)
            return None
        if r.status_code != HTTPStatus.OK:
            log.error(
                'Failed to fetch server list: HTTP status %s',
                r.status_code,
            )
            return None
        try:
            data = r.json()
        except ValueError as e:
            log.error('Failed to parse server list: %s', e)
            return None
        self._save_serverlist(data)
        return data

    def _save_serverlist(self, data: Json) -> None:
        """Atomically write *data* to the cache file; errors are only logged."""
        path = Path(self.SERVER_LIST)
        # Write to a sibling file and rename so readers never see a
        # partially written cache.
        tmp = path.with_name(path.name + '.tmp')
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            with tmp.open('w', encoding='utf-8') as f:
                log.debug('Writing server list to %s', path)
                json.dump(data, f)
            os.replace(tmp, path)
        except OSError as e:
            log.warning('Failed to cache server list in %s: %s', path, e)

    def iter_servers(self) -> Iterable[Tuple[int, str, str, str, str]]:
        """Yield (score, entry_ip, name, country, city) for eligible servers.

        Eligible means a tier-1 logical server with entry country US and no
        feature flags, and only its physical servers whose Status is 1.
        Requires :attr:`serverlist` to be loaded.
        """
        assert self.serverlist
        for item in self.serverlist['LogicalServers']:
            # Tier 0: Free
            # Tier 1: Basic
            # Tier 2: Plus
            if item['Tier'] != 1:
                continue
            if item['EntryCountry'] != 'US':
                continue
            # Features 2: Tor
            # Features 4: P2P
            if item['Features']:
                continue
            for server in item['Servers']:
                if server['Status'] != 1:
                    continue
                yield (
                    item['Score'],
                    server['EntryIP'],
                    item['Name'],
                    item['EntryCountry'],
                    item['City'],
                )

    async def manage_serverlist(self) -> None:
        """Refresh :attr:`serverlist` periodically, forever.

        A failed refresh keeps the previous list and is retried after about
        a minute instead of the regular hourly (jittered) interval.
        """
        while True:
            try:
                serverlist = await self.get_serverlist()
            except asyncio.CancelledError:
                raise
            except Exception:  # pylint: disable=broad-except
                log.exception('Error updating server list:')
                serverlist = None
            if serverlist is None:
                delay = 60 + random.randint(0, 60)
            else:
                self.serverlist = serverlist
                delay = 3600 + random.randint(0, 900)
            await asyncio.sleep(delay)

    def mark_bad(self, addr: str) -> None:
        """Exclude *addr* from future server selection."""
        log.warning('Marking server %s as bad', addr)
        self.bad_servers.add(addr)

    def _pick_server(
        self, current: Optional[str]
    ) -> Optional[Tuple[str, str, str, str]]:
        """Return (addr, name, country, city) of the best usable server.

        Candidates from :meth:`iter_servers` are sorted as tuples, i.e. by
        ascending score first.  If every candidate is marked bad, the bad
        set is reset to just *current*, so the watchdog keeps cycling
        through servers instead of getting stuck.  Returns ``None`` when
        there are no candidates at all.
        """
        candidates = sorted(self.iter_servers())
        if not candidates:
            return None
        usable = [c for c in candidates if c[1] not in self.bad_servers]
        if not usable:
            log.warning('All eligible servers are marked bad; resetting')
            self.bad_servers = {current} if current else set()
            usable = [c for c in candidates if c[1] != current] or candidates
        _score, addr, name, country, city = usable[0]
        return addr, name, country, city

    async def reconfigure(self) -> None:
        """Replace the failed server in :attr:`CONFIG` and reload strongSwan.

        The ``remote_addrs`` value currently in the file is marked bad.  The
        best usable server is written in its place, then
        ``swanctl --load-conns`` is run.  The file is left untouched if no
        server list is loaded or it has no eligible servers.
        """
        if not self.serverlist:
            log.error('Cannot reconfigure: no known servers!')
            return
        fd = os.open(self.CONFIG, os.O_CREAT | os.O_RDWR, 0o644)
        with open(fd, 'r+', encoding='utf-8') as f:
            current: Optional[str] = None
            line = f.readline()
            if line:
                key, __, value = line.partition('=')
                if not value or key.strip() != 'remote_addrs':
                    log.warning('Unexpected line in config: %s', line)
                else:
                    current = value.strip()
                    self.mark_bad(current)
            choice = self._pick_server(current)
            if choice is None:
                log.error('Cannot reconfigure: no eligible servers in list!')
                return
            addr, name, country, city = choice
            log.info(
                'Selected new server: %s (%s) %s/%s',
                addr,
                name,
                country,
                city,
            )
            f.seek(0)
            f.truncate(0)
            f.write(f'remote_addrs = {addr}\n')
        log.info('Reloading strongSwan connection configuration')
        await self.swanctl('--load-conns')

    # ---------------------------------------------------------------- journal

    async def on_message(self, data: Dict[str, Any]) -> None:
        """Handle one decoded journal entry; reconfigure on tunnel failure."""
        log.debug('Got message: %s', data)
        message = data.get('MESSAGE')
        # MESSAGE may be missing or not a string (journalctl's JSON output
        # can encode some fields as non-strings; see journalctl(1)).
        if not isinstance(message, str):
            log.debug('Invalid message: %r', message)
            return
        # 'EAP/FAIL' covers the server rejecting the EAP credentials.
        # 'giving up' presumably matches strongSwan's retransmit give-up
        # message -- TODO confirm.
        if message.startswith('giving up') or 'EAP/FAIL' in message:
            log.error('VPN tunnel failed!')
            await self.reconfigure()

    async def run(self) -> None:
        """Keep the journal watcher running until the daemon stops."""
        log.info('Starting ProtonVPN Watchdog')
        # Hold a reference: the event loop keeps only weak references to
        # tasks.
        serverlist_task = asyncio.ensure_future(self.manage_serverlist())
        try:
            # watch_journal() sets stop_evt when journalctl exits; stop
            # restarting it from then on.
            while not self.stop_evt.is_set():
                try:
                    await self.watch_journal()
                except asyncio.CancelledError:
                    raise
                except Exception:  # pylint: disable=broad-except
                    log.exception('Unhandled exception:')
                    # Avoid a busy loop if journalctl keeps failing.
                    await asyncio.sleep(self.RETRY_DELAY)
        finally:
            serverlist_task.cancel()

    async def stop(self) -> None:
        """Terminate ``journalctl`` if it is still running."""
        proc = self.jrnl_proc
        if proc is None or proc.returncode is not None:
            return
        # The process may exit between the check and the signal.
        with contextlib.suppress(ProcessLookupError):
            proc.terminate()
        try:
            await asyncio.wait_for(proc.wait(), timeout=self.STOP_TIMEOUT)
        except asyncio.TimeoutError:
            log.warning('journalctl did not exit, killing it')
            with contextlib.suppress(ProcessLookupError):
                proc.kill()

    async def swanctl(self, *args: str) -> None:
        """Run ``swanctl`` with *args*, logging its stderr and any failure."""
        cmd = ['swanctl', *args]
        if log.isEnabledFor(logging.DEBUG):
            log.debug('Running command: %s', list2cmdline(cmd))
        p = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.PIPE,
        )
        assert p.stderr
        async for line in p.stderr:
            line = line.rstrip()
            if line:
                log.error(line.decode(errors='replace'))
        rc = await p.wait()
        if rc != 0:
            log.error(
                'Command exited with status %d: %s', rc, list2cmdline(cmd)
            )

    async def watch_journal(self) -> None:
        """Follow strongSwan's journal for :attr:`IKE_SA_NAME` until it exits.

        Each JSON entry is passed to :meth:`on_message`.  When ``journalctl``
        exits and its output has been drained, :attr:`stop_evt` is set.
        """
        cmd = [
            'journalctl',
            '--output=json',
            '--follow',
            '--lines=0',
            f'--unit={self.STRONGSWAN_UNIT}',
            f'IKE_SA_NAME={self.IKE_SA_NAME}',
        ]
        if log.isEnabledFor(logging.DEBUG):
            log.debug('Running command: %s', list2cmdline(cmd))
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            limit=self.JOURNAL_LINE_LIMIT,
        )
        self.jrnl_proc = proc
        assert proc.stdout
        assert proc.stderr
        readers = [
            asyncio.ensure_future(self._read_journal(proc.stdout)),
            asyncio.ensure_future(self._log_stderr(proc.stderr)),
        ]
        try:
            await proc.wait()
            # Drain output written just before the process exited.
            await asyncio.gather(*readers)
        finally:
            for reader in readers:
                reader.cancel()
        self.stop_evt.set()

    async def _read_journal(self, stream: asyncio.StreamReader) -> None:
        """Decode JSON journal lines from *stream* and dispatch them."""
        while True:
            try:
                line = await stream.readline()
            except ValueError as e:
                # Line exceeded the stream limit; skip it, keep reading.
                log.error('Failed to read journal entry: %s', e)
                continue
            if not line:
                return  # EOF
            try:
                data = json.loads(line)
            except ValueError as e:
                log.error('Failed to parse journal entry: %s', e)
                continue
            try:
                await self.on_message(data)
            except asyncio.CancelledError:
                raise
            except Exception:  # pylint: disable=broad-except
                log.exception('Error handling message:')

    @staticmethod
    async def _log_stderr(stream: asyncio.StreamReader) -> None:
        """Log every line *stream* (journalctl's stderr) produces as an error."""
        while True:
            try:
                line = await stream.readline()
            except ValueError as e:
                log.error('Failed to read journalctl stderr: %s', e)
                continue
            if not line:
                return  # EOF
            log.error(line.rstrip().decode(errors='replace'))


def list2cmdline(args: Iterable[str]) -> str:
    """Render *args* as one shell-quoted command line (used for logging)."""
    quoted = map(shlex.quote, args)
    return ' '.join(quoted)


def parse_args() -> argparse.Namespace:
    """Parse the command line; ``-v`` may be repeated to raise verbosity."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        '-v',
        '--verbose',
        action='count',
        default=0,
    )
    namespace = ap.parse_args()
    return namespace


def setup_logging(verbose: int) -> None:
    """Configure root logging for the requested verbosity.

    ``verbose`` below 1 selects WARNING, 1 selects INFO, 2 or more DEBUG.
    When stderr is a terminal, coloredlogs is used if it imported
    successfully, otherwise a plain stream handler.  When stderr is not a
    terminal, records go to the local syslog socket ``/dev/log``.
    """
    if verbose >= 2:
        level = logging.DEBUG
    elif verbose >= 1:
        level = logging.INFO
    else:
        level = logging.WARNING

    if not sys.stderr.isatty():
        handler: logging.Handler = logging.handlers.SysLogHandler(
            address='/dev/log'
        )
    elif coloredlogs is None:
        handler = logging.StreamHandler()
    else:
        coloredlogs.install(level=level)
        return

    # Timestamp format '%b %e %H:%M:%S' mirrors a classic syslog line
    # prefix.
    fmt = logging.Formatter(
        '%(asctime)s %(name)s: %(levelname)s %(message)s',
        '%b %e %H:%M:%S',
    )
    handler.setFormatter(fmt)
    handler.setLevel(level)
    logging.root.setLevel(level)
    logging.root.addHandler(handler)


def main() -> None:
    """Entry point: parse options, configure logging, run the watchdog."""
    setup_logging(parse_args().verbose)
    asyncio.run(AsyncDaemon().main())


if __name__ == '__main__':
    main()