Skip to content

Commit

Permalink
Add request lifecycles
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Feb 12, 2025
1 parent 187ec95 commit 52ce462
Show file tree
Hide file tree
Showing 10 changed files with 268 additions and 4 deletions.
8 changes: 7 additions & 1 deletion docs/en/docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ hide:

# Release Notes

## 0.12.7
## 0.12.8

### Added

- `Lilya`, `Include`, `Host`, `Path` and `Router` now support `before_request` and `after_request`
life cycles. This can be particularly useful to those who want to perform actions before and after
a request is performed. E.g.: Telemetry.

### Added

Expand Down
2 changes: 1 addition & 1 deletion lilya/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.12.7"
__version__ = "0.12.8"
2 changes: 2 additions & 0 deletions lilya/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from lilya.conf.global_settings import Settings
from lilya.datastructures import State, URLPath
from lilya.middleware.app_settings import ApplicationSettingsMiddleware
from lilya.middleware.asyncexit import AsyncExitStackMiddleware
from lilya.middleware.base import DefineMiddleware
from lilya.middleware.exceptions import ExceptionMiddleware
from lilya.middleware.global_context import GlobalContextMiddleware
Expand Down Expand Up @@ -581,6 +582,7 @@ def build_middleware_stack(self) -> ASGIApp:
*self.custom_middleware,
DefineMiddleware(ApplicationSettingsMiddleware),
DefineMiddleware(ExceptionMiddleware, handlers=exception_handlers, debug=self.debug),
DefineMiddleware(AsyncExitStackMiddleware, debug=self.debug),
]

app = self.router
Expand Down
10 changes: 8 additions & 2 deletions lilya/middleware/asyncexit.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,37 @@
from __future__ import annotations

import traceback
from contextlib import AsyncExitStack

from lilya.protocols.middleware import MiddlewareProtocol
from lilya.types import ASGIApp, Receive, Scope, Send


class AsyncExitStackMiddleware(MiddlewareProtocol):
def __init__(self, app: ASGIApp):
def __init__(self, app: ASGIApp, debug: bool = False) -> None:
"""AsyncExitStack Middleware class.
Args:
app: The 'next' ASGI app to call.
"""
super().__init__(app)
self.app = app
self.debug = debug

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if not AsyncExitStack:
await self.app(scope, receive, send) # pragma: no cover

exception: Exception | None = None
async with AsyncExitStack() as stack:
scope["lilya_astack"] = stack
scope["lilya_asyncexitstack"] = stack
try:
await self.app(scope, receive, send)
except Exception as e:
exception = e

if exception and self.debug:
traceback.print_exception(exception, exception, exception.__traceback__) # type: ignore

if exception:
raise exception
62 changes: 62 additions & 0 deletions lilya/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from lilya._internal._responses import BaseHandler
from lilya._internal._urls import include
from lilya.compat import is_async_callable
from lilya.concurrency import run_in_threadpool
from lilya.conf import settings
from lilya.conf.global_settings import Settings
from lilya.datastructures import URL, Header, ScopeHandler, SendReceiveSniffer, URLPath
Expand Down Expand Up @@ -511,10 +512,23 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
await response(scope, receive, send)
else:
try:
for before_request in self.before_request:
if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
await run_in_threadpool(before_request, scope, receive, send)

if not hasattr(self.app, "__is_controller__"):
await self.app(scope, receive, send)
else:
await self.handle_controller(scope, receive, send)

for after_request in self.after_request:
if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
await run_in_threadpool(after_request, scope, receive, send)

except Exception as ex:
await self.handle_exception_handlers(scope, receive, send, ex)

Expand Down Expand Up @@ -692,7 +706,19 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
None
"""
try:
for before_request in self.before_request:
if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
await run_in_threadpool(before_request, scope, receive, send)

await self.app(scope, receive, send)

for after_request in self.after_request:
if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
await run_in_threadpool(after_request, scope, receive, send)
except Exception as ex:
await self.handle_exception_handlers(scope, receive, send, ex)

Expand Down Expand Up @@ -877,7 +903,19 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
None
"""
try:
for before_request in self.before_request:
if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
await run_in_threadpool(before_request, scope, receive, send)

await self.app(scope, receive, send)

for after_request in self.after_request:
if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
await run_in_threadpool(after_request, scope, receive, send)
except Exception as ex:
await self.handle_exception_handlers(scope, receive, send, ex)

Expand Down Expand Up @@ -1420,8 +1458,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == ScopeType.LIFESPAN:
await self.lifespan(scope, receive, send)
return

for before_request in self.before_request:
if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
await run_in_threadpool(before_request, scope, receive, send)
await self.middleware_stack(scope, receive, send)

for after_request in self.after_request:
if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
await run_in_threadpool(after_request, scope, receive, send)


class Router(BaseRouter):
"""
Expand Down Expand Up @@ -2044,7 +2094,19 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
None
"""
try:
for before_request in self.before_request:
if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
await run_in_threadpool(before_request, scope, receive, send)

await self.app(scope, receive, send)

for after_request in self.after_request:
if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
await run_in_threadpool(after_request, scope, receive, send)
except Exception as ex:
await self.handle_exception_handlers(scope, receive, send, ex)

Expand Down
4 changes: 4 additions & 0 deletions lilya/testclient/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def create_client(
debug: bool = False,
root_path: str = "",
cookies: httpx._types.CookieTypes | None = None,
before_request: Sequence[Callable[..., Any]] | None = None,
after_request: Sequence[Callable[..., Any]] | None = None,
**kwargs: Any,
) -> TestClient:
"""
Expand Down Expand Up @@ -60,6 +62,8 @@ def create_client(
lifespan=lifespan,
redirect_slashes=redirect_slashes,
include_in_schema=include_in_schema,
before_request=before_request,
after_request=after_request,
**kwargs,
),
base_url=base_url,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ testing = [
"ptpython",
"ipdb",
"pdbpp",
"structlog",
]

docs = [
Expand Down
Empty file.
76 changes: 76 additions & 0 deletions tests/request_lifecycles/test_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from structlog import get_logger

from lilya.responses import PlainText
from lilya.routing import Include, Path
from lilya.testclient import create_client

logger = get_logger()


async def before_path_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1
logger.info(f"Before path request: {app.state.app_request}")


async def after_path_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After path request: {app.state.app_request}")


async def before_include_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1
logger.info(f"Before include request: {app.state.app_request}")


async def after_include_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After include request: {app.state.app_request}")


async def before_app_request(scope, receive, send):
app = scope["app"]
app.state.app_request = 1
logger.info(f"Before app request: {app.state.app_request}")


async def after_app_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After app request: {app.state.app_request}")


def test_all_layers_request():
async def index(request):
state = request.app.state
return PlainText(f"State: {state.app_request}")

with create_client(
routes=[
Include(
"/",
routes=[
Path(
"/",
index,
before_request=[before_path_request],
after_request=[after_path_request],
)
],
before_request=[before_include_request],
after_request=[after_include_request],
),
],
before_request=[before_app_request],
after_request=[after_app_request],
) as client:
response = client.get("/")

assert response.status_code == 200
assert response.text == "State: 3"
107 changes: 107 additions & 0 deletions tests/request_lifecycles/test_before_after_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from structlog import get_logger

from lilya.responses import PlainText
from lilya.routing import Include, Path
from lilya.testclient import create_client

logger = get_logger()


async def before_path_request(scope, receive, send):
app = scope["app"]
app.state.app_request = 1
logger.info(f"Before path request: {app.state.app_request}")


async def after_path_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After path request: {app.state.app_request}")


def test_path_before_request():
async def index(request):
state = request.app.state
return PlainText(f"State: {state.app_request}")

with create_client(
routes=[
Path(
"/",
index,
before_request=[before_path_request],
after_request=[after_path_request],
)
],
) as client:
response = client.get("/")

assert response.status_code == 200
assert response.text == "State: 1"


async def before_include_request(scope, receive, send):
app = scope["app"]
app.state.app_request = 1
logger.info(f"Before include request: {app.state.app_request}")


async def after_include_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After include request: {app.state.app_request}")


def test_include_before_request():
async def index(request):
state = request.app.state
return PlainText(f"State: {state.app_request}")

with create_client(
routes=[
Include(
"/",
Path(
"/",
index,
),
before_request=[before_include_request],
after_request=[after_include_request],
)
],
) as client:
response = client.get("/")

assert response.status_code == 200
assert response.text == "State: 1"


async def before_app_request(scope, receive, send):
app = scope["app"]
app.state.app_request = 1
logger.info(f"Before app request: {app.state.app_request}")


async def after_app_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After app request: {app.state.app_request}")


def test_app_before_request():
async def index(request):
state = request.app.state
return PlainText(f"State: {state.app_request}")

with create_client(
routes=[Path("/", index)],
before_request=[before_app_request],
after_request=[after_app_request],
) as client:
response = client.get("/")

assert response.status_code == 200
assert response.text == "State: 1"

0 comments on commit 52ce462

Please sign in to comment.