Skip to content

Commit

Permalink
Add more testing for classes
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Feb 12, 2025
1 parent 52ce462 commit 09b8dfa
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 0 deletions.
30 changes: 30 additions & 0 deletions lilya/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
else:
try:
for before_request in self.before_request:
if inspect.isclass(before_request):
before_request = before_request()

if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
Expand All @@ -524,6 +527,9 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
await self.handle_controller(scope, receive, send)

for after_request in self.after_request:
if inspect.isclass(after_request):
after_request = after_request()

if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
Expand Down Expand Up @@ -707,6 +713,9 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
"""
try:
for before_request in self.before_request:
if inspect.isclass(before_request):
before_request = before_request()

if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
Expand All @@ -715,6 +724,9 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
await self.app(scope, receive, send)

for after_request in self.after_request:
if inspect.isclass(after_request):
after_request = after_request()

if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
Expand Down Expand Up @@ -904,6 +916,9 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
"""
try:
for before_request in self.before_request:
if inspect.isclass(before_request):
before_request = before_request()

if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
Expand All @@ -912,6 +927,9 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
await self.app(scope, receive, send)

for after_request in self.after_request:
if inspect.isclass(after_request):
after_request = after_request()

if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
Expand Down Expand Up @@ -1460,13 +1478,19 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return

for before_request in self.before_request:
if inspect.isclass(before_request):
before_request = before_request()

if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
await run_in_threadpool(before_request, scope, receive, send)
await self.middleware_stack(scope, receive, send)

for after_request in self.after_request:
if inspect.isclass(after_request):
after_request = after_request()

if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
Expand Down Expand Up @@ -2095,6 +2119,9 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
"""
try:
for before_request in self.before_request:
if inspect.isclass(before_request):
before_request = before_request()

if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
Expand All @@ -2103,6 +2130,9 @@ async def handle_dispatch(self, scope: Scope, receive: Receive, send: Send) -> N
await self.app(scope, receive, send)

for after_request in self.after_request:
if inspect.isclass(after_request):
after_request = after_request()

if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
Expand Down
82 changes: 82 additions & 0 deletions tests/request_lifecycles/test_all_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from structlog import get_logger

from lilya.responses import PlainText
from lilya.routing import Include, Path
from lilya.testclient import create_client

logger = get_logger()


class BeforePathRequest:
async def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1
logger.info(f"Before path request: {app.state.app_request}")


class AfterPathRequest:
async def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After path request: {app.state.app_request}")


class BeforeIncludeRequest:
async def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1
logger.info(f"Before include request: {app.state.app_request}")


class AfterIncludeRequest:
async def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After include request: {app.state.app_request}")


class BeforeAppRequest:
async def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request = 1
logger.info(f"Before app request: {app.state.app_request}")


class AfterAppRequest:
async def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After app request: {app.state.app_request}")


def test_all_layers_request():
async def index(request):
state = request.app.state
return PlainText(f"State: {state.app_request}")

with create_client(
routes=[
Include(
"/",
routes=[
Path(
"/",
index,
before_request=[BeforePathRequest],
after_request=[AfterPathRequest],
)
],
before_request=[BeforeIncludeRequest],
after_request=[AfterIncludeRequest],
),
],
before_request=[BeforeAppRequest],
after_request=[AfterAppRequest],
) as client:
response = client.get("/")

assert response.status_code == 200
assert response.text == "State: 3"
82 changes: 82 additions & 0 deletions tests/request_lifecycles/test_all_class_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from structlog import get_logger

from lilya.responses import PlainText
from lilya.routing import Include, Path
from lilya.testclient import create_client

logger = get_logger()


class BeforePathRequest:
def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1
logger.info(f"Before path request: {app.state.app_request}")


class AfterPathRequest:
def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After path request: {app.state.app_request}")


class BeforeIncludeRequest:
def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1
logger.info(f"Before include request: {app.state.app_request}")


class AfterIncludeRequest:
def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After include request: {app.state.app_request}")


class BeforeAppRequest:
def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request = 1
logger.info(f"Before app request: {app.state.app_request}")


class AfterAppRequest:
def __call__(self, scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After app request: {app.state.app_request}")


def test_all_layers_request():
async def index(request):
state = request.app.state
return PlainText(f"State: {state.app_request}")

with create_client(
routes=[
Include(
"/",
routes=[
Path(
"/",
index,
before_request=[BeforePathRequest],
after_request=[AfterPathRequest],
)
],
before_request=[BeforeIncludeRequest],
after_request=[AfterIncludeRequest],
),
],
before_request=[BeforeAppRequest],
after_request=[AfterAppRequest],
) as client:
response = client.get("/")

assert response.status_code == 200
assert response.text == "State: 3"
76 changes: 76 additions & 0 deletions tests/request_lifecycles/test_all_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from structlog import get_logger

from lilya.responses import PlainText
from lilya.routing import Include, Path
from lilya.testclient import create_client

logger = get_logger()


def before_path_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1
logger.info(f"Before path request: {app.state.app_request}")


def after_path_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After path request: {app.state.app_request}")


def before_include_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1
logger.info(f"Before include request: {app.state.app_request}")


def after_include_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After include request: {app.state.app_request}")


def before_app_request(scope, receive, send):
app = scope["app"]
app.state.app_request = 1
logger.info(f"Before app request: {app.state.app_request}")


def after_app_request(scope, receive, send):
app = scope["app"]
app.state.app_request += 1

logger.info(f"After app request: {app.state.app_request}")


def test_all_layers_request():
async def index(request):
state = request.app.state
return PlainText(f"State: {state.app_request}")

with create_client(
routes=[
Include(
"/",
routes=[
Path(
"/",
index,
before_request=[before_path_request],
after_request=[after_path_request],
)
],
before_request=[before_include_request],
after_request=[after_include_request],
),
],
before_request=[before_app_request],
after_request=[after_app_request],
) as client:
response = client.get("/")

assert response.status_code == 200
assert response.text == "State: 3"

0 comments on commit 09b8dfa

Please sign in to comment.