Skip to content

Commit

Permalink
[connection] move asyncio code out of QuicConnection (see #4)
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaine committed Jun 17, 2019
1 parent e886aed commit 6b20cde
Show file tree
Hide file tree
Showing 17 changed files with 736 additions and 610 deletions.
2 changes: 0 additions & 2 deletions aioquic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
from .client import connect # noqa
from .connection import QuicConnection # noqa
from .server import serve # noqa
2 changes: 2 additions & 0 deletions aioquic/asyncio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .client import connect # noqa
from .server import serve # noqa
24 changes: 13 additions & 11 deletions aioquic/client.py → aioquic/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import socket
from typing import AsyncGenerator, List, Optional, TextIO, cast

from ..configuration import QuicConfiguration
from ..connection import QuicConnection
from ..tls import SessionTicket, SessionTicketHandler
from .compat import asynccontextmanager
from .configuration import QuicConfiguration
from .connection import QuicConnection, QuicStreamHandler
from .tls import SessionTicket, SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler

__all__ = ["connect"]

Expand All @@ -23,7 +24,7 @@ async def connect(
session_ticket: Optional[SessionTicket] = None,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stream_handler: Optional[QuicStreamHandler] = None,
) -> AsyncGenerator[QuicConnection, None]:
) -> AsyncGenerator[QuicConnectionProtocol, None]:
"""
Connect to a QUIC server at the given `host` and `port`.
Expand Down Expand Up @@ -69,20 +70,21 @@ async def connect(
if idle_timeout is not None:
configuration.idle_timeout = idle_timeout

connection = QuicConnection(
configuration=configuration, session_ticket_handler=session_ticket_handler
)

# connect
_, protocol = await loop.create_datagram_endpoint(
lambda: QuicConnection(
configuration=configuration,
session_ticket_handler=session_ticket_handler,
stream_handler=stream_handler,
),
lambda: QuicConnectionProtocol(connection, stream_handler=stream_handler),
local_addr=("::", 0),
)
protocol = cast(QuicConnection, protocol)
protocol.connect(addr, protocol_version=protocol_version)
protocol = cast(QuicConnectionProtocol, protocol)
protocol.connect(addr, protocol_version)
await protocol.wait_connected()
try:
yield protocol
finally:
protocol.close()
protocol._send_pending()
await protocol.wait_closed()
File renamed without changes.
189 changes: 189 additions & 0 deletions aioquic/asyncio/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import asyncio
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast

from .. import events
from ..connection import NetworkAddress, QuicConnection

QuicConnectionIdHandler = Callable[[bytes], None]
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]


class QuicConnectionProtocol(asyncio.DatagramProtocol):
def __init__(
self,
connection: QuicConnection,
stream_handler: Optional[QuicStreamHandler] = None,
):
self._connection = connection
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None

if stream_handler is not None:
self._stream_handler = stream_handler
else:
self._stream_handler = lambda r, w: None

def close(self) -> None:
self._connection.close()
self._send_pending()

def connect(self, addr: NetworkAddress, protocol_version: int) -> None:
self._connection.connect(
addr, now=self._loop.time(), protocol_version=protocol_version
)
self._send_pending()

def connection_made(self, transport: asyncio.BaseTransport) -> None:
loop = asyncio.get_event_loop()

self._closed = asyncio.Event()
self._connected_waiter = loop.create_future()
self._loop = loop
self._ping_waiter: Optional[asyncio.Future[None]] = None
self._send_task: Optional[asyncio.Handle] = None
self._stream_readers: Dict[int, asyncio.StreamReader] = {}
self._timer: Optional[asyncio.TimerHandle] = None
self._timer_at: Optional[float] = None
self._transport = cast(asyncio.DatagramTransport, transport)

def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
self._connection.receive_datagram(
cast(bytes, data), addr, now=self._loop.time()
)
self._send_pending()

async def create_stream(
self, is_unidirectional: bool = False
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""
Create a QUIC stream and return a pair of (reader, writer) objects.
The returned reader and writer objects are instances of :class:`asyncio.StreamReader`
and :class:`asyncio.StreamWriter` classes.
"""
stream = self._connection.create_stream(is_unidirectional=is_unidirectional)
return self._create_stream(stream.stream_id)

def request_key_update(self) -> None:
"""
Request an update of the encryption keys.
"""
self._connection.request_key_update()
self._send_pending()

async def ping(self) -> None:
"""
Pings the remote host and waits for the response.
"""
assert self._ping_waiter is None, "already await a ping"
self._ping_waiter = self._loop.create_future()
self._connection.send_ping(id(self._ping_waiter))
self._send_soon()
await asyncio.shield(self._ping_waiter)

async def wait_closed(self) -> None:
"""
Wait for the connection to be closed.
"""
await self._closed.wait()

async def wait_connected(self) -> None:
"""
Wait for the TLS handshake to complete.
"""
await asyncio.shield(self._connected_waiter)

def _create_stream(
self, stream_id: int
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
adapter = QuicStreamAdapter(self, stream_id)
reader = asyncio.StreamReader()
writer = asyncio.StreamWriter(adapter, None, reader, None)
self._stream_readers[stream_id] = reader
return reader, writer

def _handle_timer(self) -> None:
now = max(self._timer_at, self._loop.time())

self._timer = None
self._timer_at = None

self._connection.handle_timer(now=now)

self._send_pending()

def _send_pending(self) -> None:
self._send_task = None

# process events
event = self._connection.next_event()
while event is not None:
if isinstance(event, events.ConnectionIdIssued):
self._connection_id_issued_handler(event.connection_id)
elif isinstance(event, events.ConnectionIdRetired):
self._connection_id_retired_handler(event.connection_id)
elif isinstance(event, events.ConnectionTerminated):
for reader in self._stream_readers.values():
reader.feed_eof()
if not self._connected_waiter.done():
self._connected_waiter.set_exception(ConnectionError)
self._closed.set()
elif isinstance(event, events.HandshakeCompleted):
self._connected_waiter.set_result(None)
elif isinstance(event, events.PongReceived):
waiter = self._ping_waiter
self._ping_waiter = None
waiter.set_result(None)
elif isinstance(event, events.StreamDataReceived):
reader = self._stream_readers.get(event.stream_id, None)
if reader is None:
reader, writer = self._create_stream(event.stream_id)
self._stream_handler(reader, writer)
reader.feed_data(event.data)
if event.end_stream:
reader.feed_eof()

event = self._connection.next_event()

# send datagrams
for data, addr in self._connection.datagrams_to_send(now=self._loop.time()):
self._transport.sendto(data, addr)

# re-arm timer
timer_at = self._connection.get_timer()
if self._timer is not None and self._timer_at != timer_at:
self._timer.cancel()
self._timer = None
if self._timer is None and timer_at is not None:
self._timer = self._loop.call_at(timer_at, self._handle_timer)
self._timer_at = timer_at

def _send_soon(self) -> None:
if self._send_task is None:
self._send_task = self._loop.call_soon(self._send_pending)


class QuicStreamAdapter(asyncio.Transport):
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
self.protocol = protocol
self.stream_id = stream_id

def can_write_eof(self) -> bool:
return True

def get_extra_info(self, name: str, default: Any = None) -> Any:
"""
Get information about the underlying QUIC stream.
"""
if name == "connection":
return self.protocol._connection
elif name == "stream_id":
return self.stream_id

def write(self, data):
self.protocol._connection.send_stream_data(self.stream_id, data)
self.protocol._send_soon()

def write_eof(self):
self.protocol._connection.send_stream_data(self.stream_id, b"", end_stream=True)
self.protocol._send_soon()
58 changes: 35 additions & 23 deletions aioquic/server.py → aioquic/asyncio/server.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
import asyncio
import ipaddress
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Text, TextIO, Union, cast

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

from .buffer import Buffer
from .configuration import QuicConfiguration
from .connection import NetworkAddress, QuicConnection, QuicStreamHandler
from .packet import (
from ..buffer import Buffer
from ..configuration import QuicConfiguration
from ..connection import NetworkAddress, QuicConnection
from ..packet import (
PACKET_TYPE_INITIAL,
encode_quic_retry,
encode_quic_version_negotiation,
pull_quic_header,
)
from .tls import SessionTicketFetcher, SessionTicketHandler
from ..tls import SessionTicketFetcher, SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler

__all__ = ["serve"]

QuicConnectionHandler = Callable[[QuicConnection], None]
QuicConnectionHandler = Callable[[QuicConnectionProtocol], None]


def encode_address(addr: NetworkAddress) -> bytes:
Expand All @@ -39,7 +41,8 @@ def __init__(
stream_handler: Optional[QuicStreamHandler] = None,
) -> None:
self._configuration = configuration
self._connections: Dict[bytes, QuicConnection] = {}
self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
self._loop = asyncio.get_event_loop()
self._session_ticket_fetcher = session_ticket_fetcher
self._session_ticket_handler = session_ticket_handler
self._transport: Optional[asyncio.DatagramTransport] = None
Expand Down Expand Up @@ -81,9 +84,9 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N
)
return

connection = self._connections.get(header.destination_cid, None)
protocol = self._protocols.get(header.destination_cid, None)
original_connection_id: Optional[bytes] = None
if connection is None and header.packet_type == PACKET_TYPE_INITIAL:
if protocol is None and header.packet_type == PACKET_TYPE_INITIAL:
# stateless retry
if self._retry_key is not None:
if not header.token:
Expand Down Expand Up @@ -131,25 +134,34 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N
original_connection_id=original_connection_id,
session_ticket_fetcher=self._session_ticket_fetcher,
session_ticket_handler=self._session_ticket_handler,
stream_handler=self._stream_handler,
)
self._connections[header.destination_cid] = connection
protocol = QuicConnectionProtocol(
connection, stream_handler=self._stream_handler
)
protocol._connection_id_issued_handler = partial(
self._connection_id_issued, protocol=protocol
)
protocol._connection_id_retired_handler = partial(
self._connection_id_retired, protocol=protocol
)

def connection_id_issued(cid: bytes) -> None:
self._connections[cid] = connection
self._protocols[header.destination_cid] = protocol
protocol.connection_made(self._transport)

def connection_id_retired(cid: bytes) -> None:
del self._connections[cid]
self._protocols[connection.host_cid] = protocol
self._connection_handler(protocol)

connection._connection_id_issued_handler = connection_id_issued
connection._connection_id_retired_handler = connection_id_retired
connection.connection_made(self._transport)
if protocol is not None:
protocol.datagram_received(data, addr)

self._connections[connection.host_cid] = connection
self._connection_handler(connection)
def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol):
self._protocols[cid] = protocol

if connection is not None:
connection.datagram_received(data, addr)
def _connection_id_retired(
self, cid: bytes, protocol: QuicConnectionProtocol
) -> None:
assert self._protocols[cid] == protocol
del self._protocols[cid]


async def serve(
Expand Down Expand Up @@ -181,7 +193,7 @@ async def serve(
* ``connection_handler`` is a callback which is invoked whenever a
connection is created. It must be a a function accepting a single
argument: a :class:`~aioquic.QuicConnection`.
argument: a :class:`~aioquic.asyncio.protocol.QuicConnectionProtocol`.
* ``secrets_log_file`` is a file-like object in which to log traffic
secrets. This is useful to analyze traffic captures with Wireshark.
* ``stateless_retry`` specifies whether a stateless retry should be
Expand Down
Loading

0 comments on commit 6b20cde

Please sign in to comment.