aiomarionette/aiomarionette.py

326 lines
9.7 KiB
Python

'''Firefox Marionette protocol for asyncio

This module provides an asynchronous implementation of the Firefox
Marionette protocol over TCP sockets.

>>> async with Marionette() as mn:
...     await mn.connect()
...     await mn.navigate('https://getfirefox.com/')
'''
import dataclasses
import abc
import asyncio
import json
import logging
import random
import socket
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
# Module-level logger used for all diagnostics emitted by this module.
log = logging.getLogger(__name__)
# Accepted values for the ``type`` argument of ``new_window``.
WindowType = Union[Literal['window'], Literal['tab']]
@dataclasses.dataclass
class WindowRect:
    '''Location and dimensions of a browser window'''

    x: int
    '''x coordinate of the window position, in pixels'''

    y: int
    '''y coordinate of the window position, in pixels'''

    height: int
    '''Window height, in pixels'''

    width: int
    '''Window width, in pixels'''
class MarionetteException(Exception):
    '''Common base class for all Marionette-related errors'''
class _BaseRPC(metaclass=abc.ABCMeta):
    '''Abstract interface for objects that can issue Marionette commands'''

    # pylint: disable=too-few-public-methods

    @abc.abstractmethod
    async def _send_message(self, command: str, **kwargs: Any) -> Any:
        '''Send ``command`` with ``kwargs`` as its parameters

        Concrete subclasses must override this with an implementation
        that delivers the command and returns the reply.
        '''
class WebDriverBase(_BaseRPC, metaclass=abc.ABCMeta):
    '''WebDriver protocol implementation

    Every method issues a single ``WebDriver:*`` command through
    ``_send_message``, which concrete subclasses must implement.
    '''

    async def close_window(self) -> List[str]:
        '''Close the current window

        If the last window is closed, the session will be ended.

        :returns: List of handles of remaining windows
        '''

        windows: List[str] = await self._send_message('WebDriver:CloseWindow')
        return windows

    async def fullscreen(self) -> WindowRect:
        '''Enter or exit fullscreen for the current window

        :returns: New :py:class:`WindowRect` with current size/location
        '''

        res: Dict[str, int] = await self._send_message(
            'WebDriver:FullscreenWindow'
        )
        # Wrap the raw reply so the declared return type actually holds,
        # exactly as get_window_rect and set_window_rect do.
        return WindowRect(**res)

    async def get_title(self) -> str:
        '''Get the current window title

        :returns: The ``value`` field of the server's reply
        '''

        res = await self._send_message('WebDriver:GetTitle')
        title: str = res['value']
        return title

    async def get_url(self) -> str:
        '''Get the URL of the current window

        :returns: The ``value`` field of the server's reply
        '''

        res = await self._send_message('WebDriver:GetCurrentURL')
        url: str = res['value']
        return url

    async def get_window_rect(self) -> WindowRect:
        '''Get the current window position and dimensions

        :returns: :py:class:`WindowRect` built from the server's reply
        '''

        res: Dict[str, int] = await self._send_message(
            'WebDriver:GetWindowRect'
        )
        return WindowRect(**res)

    async def get_window_handles(self) -> List[str]:
        '''Get a list of handles for all open windows'''

        handles: List[str]
        handles = await self._send_message('WebDriver:GetWindowHandles')
        return handles

    async def navigate(self, url: str) -> None:
        '''Navigate to the specified location

        :param url: Location to load in the current window
        '''

        await self._send_message('WebDriver:Navigate', url=url)

    async def new_window(
        self,
        type: Optional[WindowType] = None,  # pylint: disable=redefined-builtin
        focus: bool = False,
        private: bool = False,
    ) -> str:
        '''Open a new window or tab

        :param type: Either ``'window'`` or ``'tab'``
        :param focus: Give the new window focus
        :param private: Open a new private window
        :returns: The ``handle`` field of the server's reply
        '''

        res = await self._send_message(
            'WebDriver:NewWindow',
            type=type,
            focus=focus,
            private=private,
        )
        handle: str = res['handle']
        return handle

    async def set_window_rect(
        self,
        x: Optional[int] = None,
        y: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> WindowRect:
        '''Move and/or resize the current window

        :param x: Position x coordinate in pixels
        :param y: Position y coordinate in pixels
        :param height: Height in pixels
        :param width: Width in pixels
        :returns: :py:class:`WindowRect` built from the server's reply
        :raises ValueError: If none of the four values is given
        '''

        # Only a call that supplies nothing at all is rejected locally;
        # any other combination is forwarded to the server as given.
        if all(value is None for value in (x, y, height, width)):
            raise ValueError('x and y OR height and width need values')
        res: Dict[str, int] = await self._send_message(
            'WebDriver:SetWindowRect',
            x=x,
            y=y,
            height=height,
            width=width,
        )
        return WindowRect(**res)

    async def refresh(self) -> None:
        '''Refresh the current window content'''

        await self._send_message('WebDriver:Refresh')

    async def switch_to_window(self, handle: str, focus: bool = True) -> None:
        '''Switch to the specified window

        :param handle: Window handle
        :param focus: Give the selected window focus
        '''

        await self._send_message(
            'WebDriver:SwitchToWindow', handle=handle, focus=focus
        )
class Marionette(WebDriverBase):
    '''Firefox Marionette session

    This class implements the WebDriver protocol; see
    :py:class:`WebDriverBase` for available methods.
    '''

    def __init__(self, host: str = 'localhost', port: int = 2828) -> None:
        self.host = host
        '''Socket host'''
        self.port = port
        '''Socket port'''
        self._transport: Optional[asyncio.Transport] = None
        # Futures awaiting a reply, keyed by message ID.  The key -1 is
        # reserved for the first message received after connecting.
        self._waiting: Dict[int, asyncio.Future[Any]] = {}
        self.session: Optional[Dict[str, Any]] = None
        '''Marionette session information'''

    async def __aenter__(self) -> 'Marionette':
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    async def close(self) -> None:
        '''Close the connection

        Requests that are still awaiting a reply fail with
        :py:class:`MarionetteException` instead of being left pending
        forever.
        '''

        self._fail_pending(MarionetteException('Connection closed'))
        if self._transport is not None:
            self._transport.close()
            self._transport = None
        self.session = None

    async def connect(self) -> None:
        '''Connect to the Marionette socket and begin a session'''

        hello = await self._connect()
        log.info(
            'Connected to Marionette server '
            '(protocol: %s, application type: %s)',
            hello.get('marionetteProtocol'),
            hello.get('applicationType'),
        )
        self.session = await self._send_message(
            'WebDriver:NewSession', strictFileInteractibility=True
        )

    async def _connect(self) -> Dict[str, Any]:
        '''Open the socket and wait for the server's first message

        Connection failures are logged and retried once per second,
        indefinitely.

        :returns: The first message received on the new connection
        '''

        loop = asyncio.get_running_loop()
        while True:
            log.info(
                'Connecting to Marionette server %s on port %d',
                self.host,
                self.port,
            )
            try:
                # NOTE(review): only IPv6 addresses are tried here;
                # confirm this is intended for IPv4-only listeners.
                transport, _protocol = await loop.create_connection(
                    lambda: _MarionetteProtocol(self),
                    host=self.host,
                    port=self.port,
                    family=socket.AF_INET6,
                )
            except OSError as e:
                log.error('Failed to connect to Marionette server: %s', e)
                await asyncio.sleep(1)
                continue
            fut = self._waiting[-1] = loop.create_future()
            try:
                hello: Dict[str, Any] = await fut
            except (OSError, EOFError) as e:
                # Delivered through _connection_error, i.e. the
                # connection was lost before the first message arrived.
                log.error('Failed to connect to Marionette server: %s', e)
                await asyncio.sleep(1)
                continue
            except BaseException:
                # Cancellation or close(): do not leak the socket or
                # leave a stale hello future registered.
                self._waiting.pop(-1, None)
                transport.close()
                raise
            self._transport = cast(asyncio.Transport, transport)
            return hello

    def _new_message_id(self) -> int:
        '''Pick a random message ID not used by any pending request'''

        while True:
            msgid = random.randint(0, 65535)
            # Reusing the ID of an in-flight request would overwrite its
            # future and leave that caller waiting forever.
            if msgid not in self._waiting:
                return msgid

    async def _send_message(self, command: str, **kwargs: Any) -> Any:
        '''Send a command and wait for the matching response

        :param command: Marionette command name
        :param kwargs: Command parameters, sent as a JSON object
        :returns: The result element of the server's response
        :raises MarionetteException: If not connected, if the server
            reports an error, or if the connection is closed while
            waiting
        '''

        if self._transport is None:
            raise MarionetteException('Not connected to Marionette server')
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        msgid = self._new_message_id()
        msg = json.dumps([0, msgid, command, kwargs or None])
        log.debug('Sending message: %r', msg)
        # json.dumps escapes non-ASCII characters by default, so the
        # character count equals the encoded byte length.
        self._transport.write(f'{len(msg)}:{msg}'.encode('utf-8'))
        self._waiting[msgid] = fut
        return await fut

    def _message_received(self, data: bytes) -> None:
        '''Decode one message payload and resolve the future waiting on it

        :param data: JSON-encoded message body, without the length prefix
        '''

        try:
            message = json.loads(data)
        except ValueError as e:
            log.error('Got invalid message from Marionette server: %s', e)
            return
        hello = self._waiting.pop(-1, None)
        # The waiter may already be cancelled (e.g. by a timeout); setting
        # a result on a finished future raises InvalidStateError.
        if hello is not None and not hello.done():
            hello.set_result(message)
        if not isinstance(message, list):
            return
        if message[0] != 1:
            log.error('Unsupported message type: %r', message[0])
            return
        try:
            fut = self._waiting.pop(message[1])
        except KeyError:
            log.warning(
                'Nothing waiting for response to message ID %s', message[1]
            )
            return
        if fut.done():
            # The caller stopped waiting; drop the late response.
            return
        if message[2] is not None:
            fut.set_exception(MarionetteException(message[2]))
        else:
            fut.set_result(message[3])

    def _fail_pending(self, exc: BaseException) -> None:
        '''Fail every pending request with ``exc`` and forget it'''

        while self._waiting:
            _msgid, fut = self._waiting.popitem()
            if not fut.done():
                fut.set_exception(exc)

    def _connection_error(self, exc: Exception) -> None:
        '''Propagate a connection failure to all pending requests'''

        self._fail_pending(exc)
class _MarionetteProtocol(asyncio.Protocol):
    '''asyncio protocol that splits the byte stream into messages

    Received bytes are buffered and cut into ``<length>:<payload>``
    frames; every complete payload is passed to the owning
    :py:class:`Marionette` instance.
    '''

    def __init__(self, marionette: Marionette) -> None:
        self.marionette = marionette
        self.buf = bytearray()

    def data_received(self, data: bytes) -> None:
        '''Buffer ``data`` and dispatch every complete frame it finishes'''

        self.buf += data
        while self.buf:
            sep = self.buf.find(b':')
            if sep < 0:
                # Length prefix not fully received yet
                return
            size = int(self.buf[:sep])
            start = sep + 1
            stop = start + size
            if stop > len(self.buf):
                # Payload not fully received yet
                return
            payload = self.buf[start:stop]
            self.buf = self.buf[stop:]
            log.debug('Received message: %s', payload)
            # pylint: disable=protected-access
            self.marionette._message_received(payload)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        '''Report the loss of the connection to the owning session

        A clean shutdown (``exc`` is ``None``) is reported as
        :py:class:`EOFError`.
        '''

        log.error('Connection lost, cancelling outstanding requests')
        reason = EOFError() if exc is None else exc
        # pylint: disable=protected-access
        self.marionette._connection_error(reason)