357 lines
11 KiB
Python
357 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
import abc
|
|
import argparse
|
|
import asyncio.subprocess
|
|
import json
|
|
import logging.handlers
|
|
import os
|
|
import random
|
|
import shlex
|
|
import signal
|
|
import sys
|
|
import time
|
|
from http import HTTPStatus
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Iterable, Optional, Set, Tuple
|
|
|
|
import httpx
|
|
|
|
|
|
try:
    import coloredlogs
except ImportError:
    # coloredlogs is optional; without it setup_logging() installs a plain
    # StreamHandler when stderr is a terminal.
    coloredlogs = None

# Logger used throughout the watchdog.
log = logging.getLogger(name='protonvpn_watchdog')

# Alias for decoded JSON objects such as the ProtonVPN server list.
Json = Dict[str, Any]
|
|
|
|
|
|
class BaseAsyncDaemon(abc.ABC):
    """Skeleton of a signal-aware asyncio daemon.

    Subclasses implement :meth:`run` (the daemon's work) and :meth:`stop`
    (cleanup).  :meth:`main` runs :meth:`run` until it finishes or until
    SIGINT/SIGTERM arrives, then calls :meth:`stop` and terminates the
    process.
    """

    def __init__(self) -> None:
        # Both are created in main(), once the event loop is running.
        self.loop: asyncio.AbstractEventLoop
        self.stop_evt: asyncio.Event
        # Number of the signal that requested shutdown; 0 if none did.
        self.killed = 0

    def shutdown(self, signum: int) -> None:
        """Event-loop signal handler: record *signum* and wake up main()."""
        if log.isEnabledFor(logging.INFO):
            log.info(
                'Received signal %d (%s), stopping',
                signum,
                signal.Signals(signum).name,
            )
        self.killed = signum
        self.stop_evt.set()

    async def main(self) -> None:
        """Run the daemon until :meth:`run` ends or a signal arrives.

        Never returns normally.  Without a signal it raises
        ``SystemExit(0)``, or ``SystemExit(1)`` if :meth:`run` or
        :meth:`stop` raised.  After SIGINT/SIGTERM the signal is re-sent to
        the process with its default disposition, so the parent sees the
        process as killed by that signal.
        """
        self.loop = asyncio.get_running_loop()
        self.stop_evt = asyncio.Event()
        for signum in (signal.SIGINT, signal.SIGTERM):
            self.loop.add_signal_handler(signum, self.shutdown, signum)
        t_run = asyncio.create_task(self.run())
        t_stop = asyncio.create_task(self.stop_evt.wait())
        done, pending = await asyncio.wait(
            (t_run, t_stop), return_when=asyncio.FIRST_COMPLETED
        )
        rc = 0
        for task in done:
            try:
                await task
            except Exception:  # pylint: disable=broad-except
                log.exception('Unhandled error:')
                rc = 1
        for task in pending:
            task.cancel()
        try:
            await self.stop()
        except Exception:  # pylint: disable=broad-except
            log.exception('Error stopping daemon:')
            rc = 1
        if self.killed:
            # Die from the same signal, using its default action.
            signal.signal(self.killed, signal.SIG_DFL)
            os.kill(os.getpid(), self.killed)
            # os.kill() may return before the signal takes effect (e.g. it
            # was delivered to another thread).  Never fall through to a
            # successful exit; 128 + N is the conventional status for a
            # process terminated by signal N.
            raise SystemExit(128 + self.killed)
        raise SystemExit(rc)

    @abc.abstractmethod
    async def run(self) -> None:
        """Do the daemon's work; returning ends the daemon."""
        raise NotImplementedError

    @abc.abstractmethod
    async def stop(self) -> None:
        """Release resources; called once after run() ended or was cancelled."""
        raise NotImplementedError
|
|
|
|
|
|
class AsyncDaemon(BaseAsyncDaemon):
    """Keep the ProtonVPN IKE connection pointed at a working server.

    The daemon follows strongSwan's journal for :attr:`IKE_SA_NAME`.  When a
    message reports a failure, the server currently named in :attr:`CONFIG`
    is marked bad, the best remaining server from the ProtonVPN server list
    is written there as ``remote_addrs = <ip>`` and ``swanctl --load-conns``
    is run.  The server list is cached in :attr:`SERVER_LIST` and refreshed
    in the background.
    """

    # Settings, overridable through the environment (read once, when the
    # class is created).
    STRONGSWAN_UNIT = os.environ.get(
        'PROTONVPN_STRONGSWAN_UNIT',
        'strongswan.service',
    )
    IKE_SA_NAME = os.environ.get('PROTONVPN_IKE_SA_NAME', 'protonvpn')
    SERVER_LIST = os.environ.get(
        'PROTONVPN_SERVER_LIST', '/var/cache/protonvpn/serverlist.json'
    )
    SERVER_LIST_URL = os.environ.get(
        'PROTONVPN_SERVERLIST_URL', 'https://api.protonmail.ch/vpn/logicals'
    )
    CONFIG = os.environ.get(
        'PROTONVPN_CONFIG',
        '/var/lib/protonvpn/remote_addrs',
    )

    # A cached server list younger than this (seconds) is used without
    # refetching; it is also the base interval between refreshes.
    SERVER_LIST_MAX_AGE = 3600
    # Upper bound (seconds) of the random delay added to each refresh.
    SERVER_LIST_JITTER = 900
    # Pause (seconds) before restarting journalctl after watching failed.
    RESTART_DELAY = 5

    def __init__(self) -> None:
        super().__init__()
        # journalctl child started by watch_journal(); None until then.
        self.jrnl_proc: Optional[asyncio.subprocess.Process] = None
        # Current server list; None until one has been loaded.  Initialised
        # here so reconfigure() works even if a failure is logged before the
        # first list has been fetched.
        self.serverlist: Optional[Json] = None
        # Entry IPs that failed; skipped when picking a new server.
        self.bad_servers: Set[str] = set()
        # Reference to the background manage_serverlist() task.
        self._serverlist_task: Optional['asyncio.Task[None]'] = None

    @staticmethod
    def _is_serverlist(data: Any) -> bool:
        """Return True if *data* has the shape iter_servers() relies on."""
        return isinstance(data, dict) and isinstance(
            data.get('LogicalServers'), list
        )

    async def get_serverlist(self) -> Optional[Json]:
        """Return the ProtonVPN server list, preferring a fresh cache.

        The cache file :attr:`SERVER_LIST` is used as-is when it parses and
        was modified less than :attr:`SERVER_LIST_MAX_AGE` seconds ago.
        Otherwise the list is downloaded from :attr:`SERVER_LIST_URL` and
        written back to the cache.  If the download fails, whatever could be
        read from the (stale) cache is returned; ``None`` means no usable
        list is available.
        """
        data: Optional[Json] = None
        mtime = 0.0
        try:
            f = open(self.SERVER_LIST, encoding='utf-8')
        except FileNotFoundError:
            pass
        except OSError as e:
            log.warning('Failed to read server list: %s', e)
        else:
            with f:
                mtime = os.fstat(f.fileno()).st_mtime
                try:
                    cached = json.load(f)
                except ValueError as e:
                    log.warning('Failed to parse server list: %s', e)
                else:
                    if self._is_serverlist(cached):
                        data = cached
                    else:
                        log.warning('Ignoring malformed server list in %s', f.name)
            # A cache that could not be used is refetched even when fresh.
            if data is not None and mtime > time.time() - self.SERVER_LIST_MAX_AGE:
                log.info('Using cached server list from %s', self.SERVER_LIST)
                return data
        async with httpx.AsyncClient() as client:
            log.info('Fetching server list from %s', self.SERVER_LIST_URL)
            try:
                r = await client.get(self.SERVER_LIST_URL)
                r.raise_for_status()
            except httpx.HTTPError as e:
                log.error('Failed to fetch server list: %s', e)
                return data
            try:
                fetched = r.json()
            except ValueError as e:
                log.error('Failed to parse server list: %s', e)
                return data
        if not self._is_serverlist(fetched):
            log.error(
                'Server list from %s has no LogicalServers', self.SERVER_LIST_URL
            )
            return data
        self._save_serverlist(fetched)
        return fetched

    def _save_serverlist(self, data: Json) -> None:
        """Store *data* in :attr:`SERVER_LIST`; failures are only logged."""
        path = Path(self.SERVER_LIST)
        tmp = path.with_name(path.name + '.tmp')
        log.debug('Writing server list to %s', path)
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            with tmp.open('w', encoding='utf-8') as f:
                json.dump(data, f)
            # os.replace() is an atomic rename on POSIX, so readers never
            # see a half-written cache file.
            os.replace(tmp, path)
        except OSError as e:
            log.error('Failed to write server list to %s: %s', path, e)

    def iter_servers(self) -> Iterable[Tuple[int, str, str, str, str]]:
        """Yield ``(score, entry_ip, name, country, city)`` candidates.

        Only logical servers with Tier 1, entry country ``US`` and no
        feature flags are considered, and of those only physical servers
        whose ``Status`` is 1.  Requires :attr:`serverlist` to be set.
        """
        assert self.serverlist
        for item in self.serverlist['LogicalServers']:
            # Tier 0: Free
            # Tier 1: Basic
            # Tier 2: Plus
            if item['Tier'] != 1:
                continue
            if item['EntryCountry'] != 'US':
                continue
            # Features 2: Tor
            # Features 4: P2P
            if item['Features']:
                continue
            for server in item['Servers']:
                if server['Status'] != 1:
                    continue
                yield (
                    item['Score'],
                    server['EntryIP'],
                    item['Name'],
                    item['EntryCountry'],
                    item['City'],
                )

    def _config_needs_repair(self) -> bool:
        """Return True if :attr:`CONFIG` is missing, unreadable or too short."""
        try:
            st = os.stat(self.CONFIG)
        except FileNotFoundError:
            return True
        except OSError as e:
            log.error('Error checking config file attributes: %s', e)
            return True
        # Shorter than this cannot hold a complete 'remote_addrs = ...' line.
        return st.st_size < 10

    async def manage_serverlist(self) -> None:
        """Refresh the server list periodically and repair a broken config.

        Runs forever.  After each refresh, :meth:`reconfigure` is called if
        :attr:`CONFIG` is missing, unreadable or too short.
        """
        while True:
            try:
                serverlist = await self.get_serverlist()
                if serverlist is not None:
                    # Keep the last good list when none is available now.
                    self.serverlist = serverlist
                if self._config_needs_repair():
                    await self.reconfigure()
            except asyncio.CancelledError:
                raise
            except Exception:  # pylint: disable=broad-except
                # An escaping exception would end this background task
                # silently and the list would never be refreshed again.
                log.exception('Error refreshing server list:')
            await asyncio.sleep(
                self.SERVER_LIST_MAX_AGE + random.randint(0, self.SERVER_LIST_JITTER)
            )

    def mark_bad(self, addr: str) -> None:
        """Exclude *addr* from future server selection."""
        log.warning('Marking server %s as bad', addr)
        self.bad_servers.add(addr)

    async def on_message(self, data: Dict[str, str]) -> None:
        """Handle one journal entry.

        Switches servers when the entry's ``MESSAGE`` starts with
        ``giving up`` or contains ``EAP/FAIL``; entries without a string
        ``MESSAGE`` are ignored.
        """
        log.debug('Got message: %s', data)
        message = data.get('MESSAGE')
        if not isinstance(message, str):
            log.debug('Invalid message: %r', message)
            return
        if message.startswith('giving up') or 'EAP/FAIL' in message:
            log.error('VPN tunnel failed!')
            await self.reconfigure()

    async def run(self) -> None:
        """Refresh the server list in the background and follow the journal.

        Watching is restarted after an error; the loop ends once
        :attr:`stop_evt` is set (watch_journal() sets it when journalctl
        exits).
        """
        log.info('Starting ProtonVPN Watchdog')
        # Keep a reference: the event loop holds only weak references to
        # tasks, so an unreferenced task may vanish mid-execution.
        self._serverlist_task = asyncio.ensure_future(self.manage_serverlist())
        while not self.stop_evt.is_set():
            try:
                await self.watch_journal()
            except asyncio.CancelledError:
                raise
            except Exception:  # pylint: disable=broad-except
                log.exception('Unhandled exception:')
                # Back off so a persistent failure (e.g. journalctl missing)
                # does not become a busy loop flooding the log.
                await asyncio.sleep(self.RESTART_DELAY)

    def _pick_server(self) -> Optional[Tuple[str, str, str, str]]:
        """Return ``(addr, name, country, city)`` of the server to use next.

        Candidates are tried in ascending order of the tuples from
        iter_servers() (score first), skipping bad servers.  When every
        candidate is bad, the bad list is reset and the first candidate is
        used, so the daemon keeps cycling instead of running out of
        servers.  Returns None only if there are no candidates at all.
        """
        candidates = sorted(self.iter_servers())
        if not candidates:
            return None
        for _score, addr, name, country, city in candidates:
            if addr not in self.bad_servers:
                return addr, name, country, city
        log.warning('All servers are marked bad; resetting bad server list')
        self.bad_servers.clear()
        _score, addr, name, country, city = candidates[0]
        return addr, name, country, city

    async def reconfigure(self) -> None:
        """Switch to another server and reload strongSwan's connections.

        The address currently in :attr:`CONFIG` (first line, expected as
        ``remote_addrs = <ip>``) is marked bad, the next server is written
        in its place and ``swanctl --load-conns`` is run.  The config is left
        untouched if there is no server list or no candidate server.
        """
        if not self.serverlist:
            log.error('Cannot reconfigure: no known servers!')
            return
        config_dir = os.path.dirname(self.CONFIG)
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)
        fd = os.open(self.CONFIG, os.O_CREAT | os.O_RDWR, 0o644)
        with open(fd, 'r+', encoding='utf-8') as f:
            line = f.readline()
            if line:
                key, __, value = line.partition('=')
                if not value or key.strip() != 'remote_addrs':
                    log.warning('Unexpected line in config: %s', line.rstrip())
                else:
                    self.mark_bad(value.strip())
            choice = self._pick_server()
            if choice is None:
                log.error('Cannot reconfigure: no eligible servers in server list!')
                return
            addr, name, country, city = choice
            log.info(
                'Selected new server: %s (%s) %s/%s',
                addr,
                name,
                country,
                city,
            )
            f.seek(0)
            f.truncate(0)
            f.write(f'remote_addrs = {addr}\n')
        log.info('Reloading strongSwan connection configuration')
        await self.swanctl('--load-conns')

    async def stop(self) -> None:
        """Cancel background work and terminate journalctl if still running."""
        if self._serverlist_task is not None:
            self._serverlist_task.cancel()
        proc = self.jrnl_proc
        # returncode is None while the child has not exited yet.
        if proc is not None and proc.returncode is None:
            try:
                proc.terminate()
            except ProcessLookupError:
                pass  # exited in the meantime

    async def swanctl(self, *args: str) -> None:
        """Run ``swanctl`` with *args*, logging its stderr and failures."""
        cmd = ['swanctl', *args]
        if log.isEnabledFor(logging.DEBUG):
            log.debug('Running command: %s', list2cmdline(cmd))
        p = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.PIPE,
        )
        assert p.stderr
        async for line in p.stderr:
            line = line.rstrip()
            if line:
                log.error(line.decode(errors='replace'))
        rc = await p.wait()
        if rc != 0:
            log.error('swanctl exited with status %d', rc)

    async def watch_journal(self) -> None:
        """Follow strongSwan's journal entries for :attr:`IKE_SA_NAME`.

        Every JSON line from ``journalctl --output=json --follow`` is passed
        to :meth:`on_message`; journalctl's stderr is logged as errors.  When
        journalctl exits, :attr:`stop_evt` is set so the daemon shuts down.
        """
        cmd = [
            'journalctl',
            '--output=json',
            '--follow',
            '--lines=0',
            f'--unit={self.STRONGSWAN_UNIT}',
            f'IKE_SA_NAME={self.IKE_SA_NAME}',
        ]
        if log.isEnabledFor(logging.DEBUG):
            log.debug('Running command: %s', list2cmdline(cmd))
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        self.jrnl_proc = proc
        stdout = proc.stdout
        stderr = proc.stderr
        assert stdout
        assert stderr

        async def handle_stdout() -> None:
            async for line in stdout:
                try:
                    data = json.loads(line)
                except ValueError as e:
                    log.error('Failed to parse journal entry: %s', e)
                    continue
                try:
                    await self.on_message(data)
                except Exception:  # pylint: disable=broad-except
                    log.exception('Error handling message:')

        async def handle_stderr() -> None:
            async for line in stderr:
                log.error(line.rstrip().decode(errors='replace'))

        try:
            # Drain both pipes to EOF while waiting for the process: no
            # trailing output is lost and the readers are owned by (and
            # cancelled with) this coroutine.
            await asyncio.gather(handle_stdout(), handle_stderr(), proc.wait())
        finally:
            # On error or cancellation, do not leave journalctl running.
            if proc.returncode is None:
                try:
                    proc.terminate()
                except ProcessLookupError:
                    pass
        self.stop_evt.set()
|
|
|
|
|
|
def list2cmdline(args):
    """Join *args* into a single string, shell-quoting each element."""
    quoted = map(shlex.quote, args)
    return ' '.join(quoted)
|
|
|
|
|
|
def parse_args(argv: Optional[Iterable[str]] = None) -> argparse.Namespace:
    """Parse command-line options.

    :param argv: arguments to parse; ``None`` (the default) means
        ``sys.argv[1:]``, exactly as with ``ArgumentParser.parse_args``.
    :returns: a namespace whose ``verbose`` attribute counts ``-v`` flags.
    """
    parser = argparse.ArgumentParser(
        description='Switch ProtonVPN servers when the strongSwan tunnel fails.',
    )
    parser.add_argument(
        '-v',
        '--verbose',
        action='count',
        default=0,
        help='increase log verbosity (-v: info, -vv: debug)',
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def setup_logging(verbose: int):
    """Install a root log handler whose threshold follows *verbose*.

    0 selects WARNING, 1 INFO and anything higher DEBUG.  When stderr is a
    terminal, messages go there (through coloredlogs if it was importable);
    otherwise they are sent to the local syslog socket ``/dev/log``.
    """
    level = (
        logging.WARNING if verbose < 1
        else logging.INFO if verbose < 2
        else logging.DEBUG
    )
    handler: logging.Handler
    if not sys.stderr.isatty():
        handler = logging.handlers.SysLogHandler(address='/dev/log')
    elif coloredlogs is None:
        handler = logging.StreamHandler()
    else:
        coloredlogs.install(level=level)
        return
    # Timestamp like 'Jan  2 15:04:05', then logger name, level and message.
    fmt = logging.Formatter(
        '%(asctime)s %(name)s: %(levelname)s %(message)s',
        '%b %e %H:%M:%S',
    )
    handler.setFormatter(fmt)
    handler.setLevel(level)
    root = logging.root
    root.setLevel(level)
    root.addHandler(handler)
|
|
|
|
|
|
def main() -> None:
    """Entry point: parse options, configure logging, run the watchdog."""
    options = parse_args()
    setup_logging(options.verbose)
    asyncio.run(AsyncDaemon().main())
|
|
|
|
|
|
if __name__ == '__main__':
    # Start the watchdog only when executed as a script, not on import.
    main()
|