Skip to content

Commit

Permalink
Add wrap middleware and wrap permission (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil authored Feb 9, 2025
1 parent db3c00f commit 0e16c78
Show file tree
Hide file tree
Showing 7 changed files with 410 additions and 13 deletions.
10 changes: 10 additions & 0 deletions docs/en/docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ hide:

# Release Notes

## 0.12.7

### Added

- Added `version` to Lilya client.

### Changed

- Declaring `DefinePermission` became optional as Lilya automatically wraps if not provided.
- Declaring `DefineMiddleware` became optional as Lilya automatically wraps if not provided.

## 0.12.6

Expand Down
22 changes: 22 additions & 0 deletions lilya/_internal/_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

from typing import Any, cast

from lilya.middleware.base import DefineMiddleware


def wrap_middleware(
middleware: DefineMiddleware | Any,
) -> DefineMiddleware:
"""
Wraps the given middleware into a DefineMiddleware instance if it is not already one.
Or else it will assume its a Lilya permission and wraps it.
Args:
permission (Union["BasePermission", Any]): The permission to be wrapped.
Returns:
BasePermission: The wrapped permission instance.
"""
if isinstance(middleware, DefineMiddleware):
return middleware
return DefineMiddleware(cast(Any, middleware))
22 changes: 22 additions & 0 deletions lilya/_internal/_permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

from typing import Any, cast

from lilya.permissions.base import DefinePermission


def wrap_permission(
permission: DefinePermission | Any,
) -> DefinePermission:
"""
Wraps the given permission into a BasePermission instance if it is not already one.
Or else it will assume its a Lilya permission and wraps it.
Args:
permission (Union["BasePermission", Any]): The permission to be wrapped.
Returns:
BasePermission: The wrapped permission instance.
"""
if isinstance(permission, DefinePermission):
return permission
return DefinePermission(cast(Any, permission))
13 changes: 11 additions & 2 deletions lilya/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from functools import cached_property
from typing import Annotated, Any, ClassVar, cast

from lilya._internal._middleware import wrap_middleware
from lilya._internal._module_loading import import_string
from lilya._internal._permissions import wrap_permission
from lilya._utils import is_class_and_subclass
from lilya.conf import __lazy_settings__, settings as lilya_settings
from lilya.conf.exceptions import FieldException
Expand Down Expand Up @@ -389,8 +391,15 @@ async def create_user(request: Request):
self.debug = self.__load_settings_value("debug", debug, is_boolean=True)

self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
self.custom_middleware = self.__load_settings_value("middleware", middleware)
self.custom_permissions = self.__load_settings_value("permissions", permissions)
self.custom_middleware = [
wrap_middleware(middleware)
for middleware in self.__load_settings_value("middleware", middleware) or []
]

self.custom_permissions = [
wrap_permission(permission)
for permission in self.__load_settings_value("permissions", permissions) or []
]

self.state = State()
self.middleware_stack: ASGIApp | None = None
Expand Down
52 changes: 41 additions & 11 deletions lilya/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from lilya import status
from lilya._internal._events import AsyncLifespan, handle_lifespan_events
from lilya._internal._middleware import wrap_middleware
from lilya._internal._module_loading import import_string
from lilya._internal._path import (
clean_path,
Expand All @@ -17,6 +18,7 @@
parse_path,
replace_params,
)
from lilya._internal._permissions import wrap_permission
from lilya._internal._responses import BaseHandler
from lilya._internal._urls import include
from lilya.compat import is_async_callable
Expand Down Expand Up @@ -318,12 +320,20 @@ def __init__(
else:
self.app = handler

self.middleware = middleware
self.permissions = permissions
if middleware is not None:
self.middleware = [wrap_middleware(mid) for mid in middleware]
else:
self.middleware = middleware

self.permissions = permissions if permissions is not None else []
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)

self.wrapped_permissions = [
wrap_permission(permission) for permission in permissions or []
]

self._apply_middleware(self.middleware)
self._apply_permissions(self.permissions)
self._apply_permissions(self.wrapped_permissions)

if self.methods is not None:
self.methods = [method.upper() for method in self.methods]
Expand Down Expand Up @@ -547,12 +557,20 @@ def __init__(
else:
self.app = handler

self.middleware = middleware
self.permissions = permissions
if middleware is not None:
self.middleware = [wrap_middleware(mid) for mid in middleware]
else:
self.middleware = middleware

self.permissions = permissions if permissions is not None else []
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)

self.wrapped_permissions = [
wrap_permission(permission) for permission in permissions or []
]

self._apply_middleware(self.middleware)
self._apply_permissions(self.permissions)
self._apply_permissions(self.wrapped_permissions)

self.path_regex, self.path_format, self.param_convertors, self.path_start = compile_path(
self.path
Expand Down Expand Up @@ -740,8 +758,12 @@ def __init__(
self.permissions = permissions if permissions is not None else []
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)

self._apply_middleware(middleware)
self._apply_permissions(permissions)
self.wrapped_permissions = [
wrap_permission(permission) for permission in permissions or []
]

self._apply_middleware(self.middleware)
self._apply_permissions(self.wrapped_permissions)

def _apply_middleware(self, middleware: Sequence[DefineMiddleware] | None) -> None:
"""
Expand Down Expand Up @@ -1000,8 +1022,12 @@ def __init__(
self.permission_started = False
self.is_sub_router = is_sub_router

self.wrapped_permissions = [
wrap_permission(permission) for permission in permissions or []
]

self._apply_middleware(self.middleware)
self._apply_permissions(self.permissions)
self._apply_permissions(self.wrapped_permissions)
self._set_settings_app(self.settings_module, self)

def _apply_middleware(self, middleware: Sequence[DefineMiddleware] | None) -> None:
Expand Down Expand Up @@ -1869,8 +1895,12 @@ def __init__(
self.permissions = permissions if permissions is not None else []
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)

self._apply_middleware(middleware)
self._apply_permissions(permissions)
self.wrapped_permissions = [
wrap_permission(permission) for permission in permissions or []
]

self._apply_middleware(self.middleware)
self._apply_permissions(self.wrapped_permissions)

self.name = name
self.include_in_schema = include_in_schema
Expand Down
99 changes: 99 additions & 0 deletions tests/middleware/test_compression_wrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from lilya.apps import Lilya
from lilya.middleware.compression import GZipMiddleware
from lilya.responses import PlainText, StreamingResponse
from lilya.routing import Path


def test_gzip_responses(test_client_factory):
def homepage():
return PlainText("x" * 4000, status_code=200)

app = Lilya(
routes=[Path("/", handler=homepage)],
middleware=[GZipMiddleware],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "gzip"
assert int(response.headers["Content-Length"]) < 4000


def test_gzip_not_in_accept_encoding(test_client_factory):
def homepage(request):
return PlainText("x" * 4000, status_code=200)

app = Lilya(
routes=[Path("/", handler=homepage)],
middleware=[GZipMiddleware],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "identity"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert int(response.headers["Content-Length"]) == 4000


def test_gzip_ignored_for_small_responses(test_client_factory):
def homepage(request):
return PlainText("OK", status_code=200)

app = Lilya(
routes=[Path("/", handler=homepage)],
middleware=[GZipMiddleware],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "OK"
assert "Content-Encoding" not in response.headers
assert int(response.headers["Content-Length"]) == 2


def test_gzip_streaming_response(test_client_factory):
def homepage():
async def generator(bytes, count):
for _ in range(count):
yield bytes

streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200)

app = Lilya(
routes=[Path("/", handler=homepage)],
middleware=[GZipMiddleware],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "gzip"
assert "Content-Length" not in response.headers


def test_gzip_ignored_for_responses_with_encoding_set(test_client_factory):
def homepage():
async def generator(bytes, count):
for _ in range(count):
yield bytes

streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200, headers={"Content-Encoding": "text"})

app = Lilya(
routes=[Path("/", handler=homepage)],
middleware=[GZipMiddleware],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip, text"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "text"
assert "Content-Length" not in response.headers
Loading

0 comments on commit 0e16c78

Please sign in to comment.