From 3a29535bb19ad564e306ed4617dbda3d1865c22a Mon Sep 17 00:00:00 2001 From: Julian Popescu Date: Wed, 18 Oct 2023 18:35:38 +0200 Subject: [PATCH 1/3] add cli options for ws handler --- chalice/cli/__init__.py | 27 ++++++++++------ chalice/cli/factory.py | 7 +++-- chalice/local.py | 68 +++++++++++++++++++++++++++++++--------- setup.py | 1 + tests/unit/test_local.py | 14 ++++----- 5 files changed, 84 insertions(+), 33 deletions(-) diff --git a/chalice/cli/__init__.py b/chalice/cli/__init__.py index 96ffa0b88..5d73ab1a4 100644 --- a/chalice/cli/__init__.py +++ b/chalice/cli/__init__.py @@ -102,6 +102,8 @@ def _configure_cli_env_vars(): @cli.command() @click.option('--host', default='127.0.0.1') @click.option('--port', default=8000, type=click.INT) +@click.option('--ws-host') +@click.option('--ws-port', type=click.INT) @click.option('--stage', default=DEFAULT_STAGE_NAME, help='Name of the Chalice stage for the local server to use.') @click.option('--autoreload/--no-autoreload', @@ -109,14 +111,18 @@ def _configure_cli_env_vars(): help='Automatically restart server when code changes.') @click.pass_context def local(ctx, host='127.0.0.1', port=8000, stage=DEFAULT_STAGE_NAME, - autoreload=True): - # type: (click.Context, str, int, str, bool) -> None + autoreload=True, ws_host=None, ws_port=None): + # type: (click.Context, str, int, str, bool, Optional[str], Optional[port]) -> None factory = ctx.obj['factory'] # type: CLIFactory from chalice.cli import reloader # We don't create the server here because that will bind the # socket and we only want to do this in the worker process. + if ws_host is None: + ws_host = host + if ws_port is None: + ws_port = port + 1 server_factory = functools.partial( - create_local_server, factory, host, port, stage) + create_local_server, factory, host, port, stage, ws_host, ws_port) # When running `chalice local`, a stdout logger is configured # so you'll see the same stdout logging as you would when # running in lambda. This is configuring the root logger. @@ -133,11 +139,11 @@ def local(ctx, host='127.0.0.1', port=8000, stage=DEFAULT_STAGE_NAME, # recommended way to do this is to use sys.exit() directly, # see: https://github.com/pallets/click/issues/747 sys.exit(rc) - run_local_server(factory, host, port, stage) + run_local_server(factory, host, port, stage, ws_host, ws_port) -def create_local_server(factory, host, port, stage): - # type: (CLIFactory, str, int, str) -> LocalDevServer +def create_local_server(factory, host, port, stage, ws_host, ws_port): + # type: (CLIFactory, str, int, str, str, int) -> LocalDevServer config = factory.create_config_obj( chalice_stage_name=stage ) @@ -146,13 +152,14 @@ def create_local_server(factory, host, port, stage): # there is no point in testing locally. routes = config.chalice_app.routes validate_routes(routes) - server = factory.create_local_server(app_obj, config, host, port) + server = factory.create_local_server( + app_obj, config, host, port, ws_host, ws_port) return server -def run_local_server(factory, host, port, stage): - # type: (CLIFactory, str, int, str) -> None - server = create_local_server(factory, host, port, stage) +def run_local_server(factory, host, port, stage, ws_host, ws_port): + # type: (CLIFactory, str, int, str, str, int) -> None + server = create_local_server(factory, host, port, stage, ws_host, ws_port) server.serve_forever() diff --git a/chalice/cli/factory.py b/chalice/cli/factory.py index 58fc9ae76..2f74aa9b9 100644 --- a/chalice/cli/factory.py +++ b/chalice/cli/factory.py @@ -336,9 +336,12 @@ def load_project_config(self) -> Dict[str, Any]: return json.loads(f.read()) def create_local_server( - self, app_obj: Chalice, config: Config, host: str, port: int + self, app_obj: Chalice, config: Config, + host: str, port: int, + ws_host: str, ws_port: int ) -> local.LocalDevServer: - return local.create_local_server(app_obj, config, host, port) + return local.create_local_server( + app_obj, config, host, port, ws_host, ws_port) def create_package_options(self) -> PackageOptions: """Create the package options that are required to target regions.""" diff --git a/chalice/local.py b/chalice/local.py index 0c13f31ac..602028eaa 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -6,6 +6,7 @@ from __future__ import print_function from __future__ import annotations import re +import socket import threading import time import uuid @@ -18,6 +19,8 @@ from six.moves.BaseHTTPServer import HTTPServer from six.moves.BaseHTTPServer import BaseHTTPRequestHandler from six.moves.socketserver import ThreadingMixIn +import websockets.sync.server +import websockets.exceptions from typing import ( List, Any, @@ -46,8 +49,10 @@ ContextType = Dict[str, Any] HeaderType = Dict[str, Any] ResponseType = Dict[str, Any] -HandlerCls = Callable[..., 'ChaliceRequestHandler'] -ServerCls = Callable[..., 'HTTPServer'] +HttpHandlerCls = Callable[..., 'ChaliceRequestHandler'] +HttpServerCls = Callable[..., 'HTTPServer'] +WsHandlerCls = Callable[..., 'ChaliceWsHandler'] +WsServerCls = Callable[..., 'websockets.sync.server.WebSocketServer'] class Clock(object): @@ -57,10 +62,11 @@ def time(self) -> float: def create_local_server(app_obj: Chalice, config: Config, - host: str, port: int) -> LocalDevServer: + host: str, port: int, + ws_host: str, ws_port: int) -> LocalDevServer: CustomLocalChalice.__bases__ = (LocalChalice, app_obj.__class__) app_obj.__class__ = CustomLocalChalice - return LocalDevServer(app_obj, config, host, port) + return LocalDevServer(app_obj, config, host, port, ws_host, ws_port) class LocalARNBuilder(object): @@ -673,6 +679,21 @@ def _send_headers(self, headers: HeaderType) -> None: self.end_headers() +class ChaliceWsHandler: + def __init__(self, app_object: Chalice, config: Config): + self.app_obj = app_object + self.config = config + + def __call__(self, + websocket: websockets.sync.server.ServerConnection): + print("open") + try: + for message in websocket: + print(message) + except websockets.exceptions.ConnectionClosed: + print("closed") + + class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): """Threading mixin to better support browsers. @@ -689,27 +710,46 @@ class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): class LocalDevServer(object): def __init__(self, app_object: Chalice, - config: Config, host: str, port: int, - handler_cls: HandlerCls = ChaliceRequestHandler, - server_cls: ServerCls = ThreadedHTTPServer) -> None: + config: Config, + host: str, port: int, + ws_host: str, ws_port: int, + http_handler_cls: HttpHandlerCls = ChaliceRequestHandler, + http_server_cls: HttpServerCls = ThreadedHTTPServer, + ws_handler_cls: WsHandlerCls = ChaliceWsHandler, + ws_server_cls: WsServerCls = + websockets.sync.server.serve,) -> None: self.app_object = app_object self.host = host self.port = port - self._wrapped_handler = functools.partial( - handler_cls, app_object=app_object, config=config) - self.server = server_cls((host, port), self._wrapped_handler) + self.ws_host = ws_host + self.ws_port = ws_port + self._wrapped_http_handler = functools.partial( + http_handler_cls, app_object=app_object, config=config) + self.http_server = http_server_cls( + (host, port), self._wrapped_http_handler) + self._ws_handler = ws_handler_cls(app_object, config) + self.ws_server = ws_server_cls(self._ws_handler, ws_host, ws_port) def handle_single_request(self) -> None: - self.server.handle_request() + self.http_server.handle_request() def serve_forever(self) -> None: - print("Serving on http://%s:%s" % (self.host, self.port)) - self.server.serve_forever() + print("Serving on http://%s:%s and ws://%s:%s" % + (self.host, self.port, self.ws_host, self.ws_port)) + threads = [ + threading.Thread(target=self.http_server.serve_forever), + threading.Thread(target=self.ws_server.serve_forever), + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() def shutdown(self) -> None: # This must be called from another thread of else it # will deadlock. - self.server.shutdown() + self.http_server.shutdown() + self.ws_server.shutdown() class HTTPServerThread(threading.Thread): diff --git a/setup.py b/setup.py index 5cada29ed..35735af22 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ def recursive_include(relative_dir): 'jmespath>=0.9.3,<2.0.0', 'pyyaml>=5.3.1,<7.0.0', 'inquirer>=2.7.0,<3.0.0', + 'websockets>=11.0.3,<12.0.0', 'wheel', 'setuptools' ] diff --git a/tests/unit/test_local.py b/tests/unit/test_local.py index b4135f650..41253d71f 100644 --- a/tests/unit/test_local.py +++ b/tests/unit/test_local.py @@ -727,7 +727,7 @@ def test_can_create_lambda_event_for_post_with_formencoded_body(): def test_can_provide_port_to_local_server(sample_app): dev_server = local.create_local_server(sample_app, None, '127.0.0.1', port=23456) - assert dev_server.server.server_port == 23456 + assert dev_server.http_server.server_port == 23456 def test_can_provide_host_to_local_server(sample_app): @@ -1191,8 +1191,8 @@ class TestLocalDevServer(object): def test_can_delegate_to_server(self, sample_app): http_server = mock.Mock(spec=HTTPServer) dev_server = LocalDevServer( - sample_app, Config(), '0.0.0.0', 8000, - server_cls=lambda *args: http_server, + sample_app, Config(), '0.0.0.0', 8000, '0.0.0.0', 8001, + http_server_cls=lambda *args: http_server, ) dev_server.handle_single_request() @@ -1208,15 +1208,15 @@ def args_recorder(*args): provided_args[:] = list(args) LocalDevServer( - sample_app, Config(), '0.0.0.0', 8000, - server_cls=args_recorder, + sample_app, Config(), '0.0.0.0', 8000, '0.0.0.0', 8001, + http_server_cls=args_recorder, ) assert provided_args[0] == ('0.0.0.0', 8000) def test_does_use_daemon_threads(self, sample_app): server = LocalDevServer( - sample_app, Config(), '0.0.0.0', 8000 + sample_app, Config(), '0.0.0.0', 8000, '0.0.0.0', 8001, ) - assert server.server.daemon_threads + assert server.http_server.daemon_threads From 960dfd9b04610b4adf07a57562934b70fe2dd9ec Mon Sep 17 00:00:00 2001 From: Julian Popescu Date: Fri, 20 Oct 2023 18:24:18 +0200 Subject: [PATCH 2/3] add websocket handler implementation --- chalice/app.py | 2 +- chalice/local.py | 262 ++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 247 insertions(+), 17 deletions(-) diff --git a/chalice/app.py b/chalice/app.py index 921534510..7bdf16b62 100644 --- a/chalice/app.py +++ b/chalice/app.py @@ -611,7 +611,7 @@ class WebsocketAPI(object): def __init__(self, env: Optional[MutableMapping] = None) -> None: self.session: Optional[Any] = None self._endpoint: Optional[str] = None - self._client = None + self._client: Optional[Any] = None if env is None: self._env: MutableMapping = os.environ else: diff --git a/chalice/local.py b/chalice/local.py index 602028eaa..6d407a03d 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -5,8 +5,9 @@ """ from __future__ import print_function from __future__ import annotations +import time +import datetime import re -import socket import threading import time import uuid @@ -19,8 +20,6 @@ from six.moves.BaseHTTPServer import HTTPServer from six.moves.BaseHTTPServer import BaseHTTPRequestHandler from six.moves.socketserver import ThreadingMixIn -import websockets.sync.server -import websockets.exceptions from typing import ( List, Any, @@ -31,12 +30,28 @@ Union, ) # noqa +from websockets.sync.server import ( + WebSocketServer, + serve as CreateWebsocketServer, + ServerConnection as WebsocketConnection +) +from websockets.http11 import ( + Request as WebsocketRequest, + Response as WebsocketResponse +) +from websockets.datastructures import Headers as WebsocketHeaders +from websockets.exceptions import ( + ConnectionClosed as WebsocketConnectionClosed +) + from chalice.app import Chalice # noqa from chalice.app import CORSConfig # noqa from chalice.app import ChaliceAuthorizer # noqa from chalice.app import CognitoUserPoolAuthorizer # noqa from chalice.app import RouteEntry # noqa from chalice.app import Request # noqa +from chalice.app import Response # noqa +from chalice.app import WebsocketEvent # noqa from chalice.app import AuthResponse # noqa from chalice.app import BuiltinAuthConfig # noqa from chalice.config import Config # noqa @@ -52,7 +67,7 @@ HttpHandlerCls = Callable[..., 'ChaliceRequestHandler'] HttpServerCls = Callable[..., 'HTTPServer'] WsHandlerCls = Callable[..., 'ChaliceWsHandler'] -WsServerCls = Callable[..., 'websockets.sync.server.WebSocketServer'] +WsServerCls = Callable[..., 'WebSocketServer'] class Clock(object): @@ -679,19 +694,233 @@ def _send_headers(self, headers: HeaderType) -> None: self.end_headers() +class WebsocketClientConnection: + def __init__(self, connection: WebsocketConnection) -> None: + self.connection = connection + self._connected_at = datetime.datetime.utcnow() + self._last_active_at = self._connected_at + + def _touch(self) -> None: + self._last_active_at = datetime.datetime.utcnow() + + def recv(self) -> str | bytes: + message = self.connection.recv() + self._touch() + return message + + def send(self, message: str | bytes) -> None: + self.connection.send(message) + self._touch() + + def close(self) -> None: + self.connection.close() + self._touch() + + def info(self) -> Dict[str, Any]: + try: + source_ip = self.connection.remote_address[0] + except Exception: + source_ip = '' + if self.connection.request is not None: + try: + user_agent = next(iter( + self.connection.request.headers.get_all('User-Agent'))) + except StopIteration: + user_agent = '' + else: + user_agent = '' + return { + 'ConnectedAt': self._connected_at, + 'Identity': { + 'SourceIp': source_ip, + 'UserAgent': user_agent, + }, + 'LastActiveAt': self._last_active_at, + } + + +class WebsocketClientExceptions: + GoneException = Exception + + +class WebsocketClient: + exceptions = WebsocketClientExceptions() + + def __init__(self) -> None: + self._connections: Dict[str, WebsocketClientConnection] = {} + + def _get(self, connection_id) -> WebsocketClientConnection: + try: + return self._connections[connection_id] + except KeyError: + raise self.exceptions.GoneException('Connection not found') + + def _del(self, connection_id: str): + try: + del self._connections[connection_id] + except KeyError: + raise self.exceptions.GoneException('Connection not found') + + def get_connection_id(self, connection: WebsocketConnection) -> str: + return base64.b64encode(connection.id.bytes).decode('ascii') + + def add_connection(self, connection: WebsocketConnection): + self._connections[self.get_connection_id(connection)] = ( + WebsocketClientConnection(connection) + ) + + def receive_message(self, ConnectionId: str) -> str | bytes: + return self._get(ConnectionId).recv() + + def post_to_connection(self, ConnectionId: str, Data: str) -> None: + try: + self._get(ConnectionId).send(Data) + except WebsocketConnectionClosed: + self._del(ConnectionId) + raise self.exceptions.GoneException('Connection closed') + + def delete_connection(self, ConnectionId: str) -> None: + self._get(ConnectionId).close() + self._del(ConnectionId) + + def get_connection(self, ConnectionId: str) -> Dict[str, Any]: + return self._get(ConnectionId).info() + + class ChaliceWsHandler: - def __init__(self, app_object: Chalice, config: Config): + MAX_LAMBDA_EXECUTION_TIME = 900 + + def __init__(self, app_object: Chalice, config: Config, domain_name: str): + app_object.websocket_api.configure(domain_name, config.chalice_stage) + app_object.websocket_api._client = WebsocketClient() self.app_obj = app_object self.config = config + self.domain_name = domain_name + + @property + def _websocket_client(self) -> WebsocketClient: + if isinstance(self.app_obj.websocket_api._client, WebsocketClient): + return self.app_obj.websocket_api._client + else: + raise TypeError("Websocket client is not a WebsocketClient") + + def _get_headers(self, websocket: WebsocketConnection) -> Dict[str, Any]: + if websocket.request is None: + return {"headers": {}, "multiValueHeaders": {}} + headers = dict(websocket.request.headers.raw_items()) + multi_value_headers = { + key: websocket.request.headers.get_all(key) + for key in headers.keys() + } + return {"headers": headers, "multiValueHeaders": multi_value_headers} + + def _get_request_context(self, + websocket: WebsocketConnection) -> Dict[str, Any]: + connection_id = self._websocket_client.get_connection_id(websocket) + connection_info = self._websocket_client.get_connection(connection_id) + return { + "connectionId": connection_id, + "domainName": self.domain_name, + "stage": self.config.chalice_stage, + "messageDirection": "IN", + "identity": { + "sourceIp": connection_info["Identity"]["SourceIp"], + }, + "extendedRequestId": connection_id, + "requestTime": connection_info["ConnectedAt"].isoformat(), + "connectedAt": + int(time.mktime(connection_info["ConnectedAt"].timetuple())), + "requestTimeEpoch": + int(time.mktime(connection_info["ConnectedAt"].timetuple())), + "apiId": "local", - def __call__(self, - websocket: websockets.sync.server.ServerConnection): - print("open") + } + + def _get_base_event(self, + websocket: WebsocketConnection, + include_headers: bool = False) -> Dict[str, Any]: + if include_headers: + event = self._get_headers(websocket) + else: + event = {} + event["requestContext"] = self._get_request_context(websocket) + event["isBase64Encoded"] = False + return event + + def _generate_lambda_context(self) -> LambdaContext: + if self.config.lambda_timeout is None: + timeout = self.MAX_LAMBDA_EXECUTION_TIME * 1000 + else: + timeout = self.config.lambda_timeout * 1000 + return LambdaContext( + function_name=self.config.function_name, + memory_size=self.config.lambda_memory_size, + max_runtime_ms=timeout + ) + + def _handle_event(self, event: Dict[str, Any]) -> Any: + handler = self.app_obj.websocket_handlers[ + event["requestContext"]["routeKey"] + ] + return handler.handler_function(WebsocketEvent( + event, self._generate_lambda_context())) + + def __call__(self, websocket: WebsocketConnection): + connection_id = self._websocket_client.get_connection_id(websocket) try: - for message in websocket: - print(message) - except websockets.exceptions.ConnectionClosed: - print("closed") + while True: + message = self._websocket_client.receive_message(connection_id) + event = self._get_base_event(websocket) + event["requestContext"]["routeKey"] = "$default" + event["requestContext"]["eventType"] = "MESSAGE" + event["requestContext"]["messageId"] = ( + base64.b64encode(uuid.uuid4().bytes).decode('ascii') + ) + if isinstance(message, str): + event["body"] = message + else: + event["body"] = base64.b64encode(message).decode('ascii') + event["isBase64Encoded"] = True + self._handle_event(event) + except WebsocketConnectionClosed as err: + event = self._get_base_event(websocket, True) + event["requestContext"]["routeKey"] = "$disconnect" + event["requestContext"]["eventType"] = "DISCONNECT" + close = err.rcvd or err.sent + if close: + event["requestContext"]["disconnectStatusCode"] = close.code + event["requestContext"]["disconnectReason"] = close.reason + self._handle_event(event) + + def process_request(self, + websocket: WebsocketConnection, + _: WebsocketRequest + ) -> Optional[WebsocketResponse]: + self._websocket_client.add_connection(websocket) + + event = self._get_base_event(websocket, True) + event["requestContext"]["routeKey"] = "$connect" + event["requestContext"]["eventType"] = "CONNECT" + + response = self._handle_event(event) + + if response is None: + return None + + if isinstance(response, Response): + response = response.to_dict() + + body = response.get("body", None) + if isinstance(body, str): + body = body.encode("utf-8") + + return WebsocketResponse( + status_code=response.get('statusCode', 101), + reason_phrase=response.get( + 'statusDescription', 'Switching Protocols'), + headers=WebsocketHeaders(**response.get('headers', {})), + body=body, + ) class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): @@ -716,8 +945,7 @@ def __init__(self, http_handler_cls: HttpHandlerCls = ChaliceRequestHandler, http_server_cls: HttpServerCls = ThreadedHTTPServer, ws_handler_cls: WsHandlerCls = ChaliceWsHandler, - ws_server_cls: WsServerCls = - websockets.sync.server.serve,) -> None: + ws_server_cls: WsServerCls = CreateWebsocketServer,) -> None: self.app_object = app_object self.host = host self.port = port @@ -727,8 +955,10 @@ def __init__(self, http_handler_cls, app_object=app_object, config=config) self.http_server = http_server_cls( (host, port), self._wrapped_http_handler) - self._ws_handler = ws_handler_cls(app_object, config) - self.ws_server = ws_server_cls(self._ws_handler, ws_host, ws_port) + self._ws_handler = ws_handler_cls(app_object, config, ws_host) + self.ws_server = ws_server_cls( + self._ws_handler, ws_host, ws_port, + process_request=self._ws_handler.process_request) def handle_single_request(self) -> None: self.http_server.handle_request() From 119859a1974a5797493be26b87025dfa4967d6b3 Mon Sep 17 00:00:00 2001 From: Julian Popescu Date: Fri, 20 Oct 2023 18:32:49 +0200 Subject: [PATCH 3/3] check feature used flag to start WS server --- chalice/local.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/chalice/local.py b/chalice/local.py index 6d407a03d..49557d6b6 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -964,16 +964,20 @@ def handle_single_request(self) -> None: self.http_server.handle_request() def serve_forever(self) -> None: - print("Serving on http://%s:%s and ws://%s:%s" % - (self.host, self.port, self.ws_host, self.ws_port)) - threads = [ - threading.Thread(target=self.http_server.serve_forever), - threading.Thread(target=self.ws_server.serve_forever), - ] - for thread in threads: - thread.start() - for thread in threads: - thread.join() + if 'WEBSOCKETS' in self.app_object._features_used: + print("Serving on http://%s:%s and ws://%s:%s" % + (self.host, self.port, self.ws_host, self.ws_port)) + threads = [ + threading.Thread(target=self.http_server.serve_forever), + threading.Thread(target=self.ws_server.serve_forever), + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + print("Serving on http://%s:%s" % (self.host, self.port)) + self.http_server.serve_forever() def shutdown(self) -> None: # This must be called from another thread of else it