Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop trio from test suite #2836

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test-suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
runs-on: "ubuntu-latest"

strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]

Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ types-contextvars==2.4.7.3
types-PyYAML==6.0.12.20240917
types-dataclasses==0.6.6
pytest==8.3.4
trio==0.27.0

# Documentation
black==24.10.0
Expand Down
19 changes: 3 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,10 @@
from __future__ import annotations

import functools
from typing import Any, Literal
from typing import Literal

import pytest

from starlette.testclient import TestClient
from tests.types import TestClientFactory


@pytest.fixture
def test_client_factory(
anyio_backend_name: Literal["asyncio", "trio"],
anyio_backend_options: dict[str, Any],
) -> TestClientFactory:
# anyio_backend_name defined by:
# https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on
return functools.partial(
TestClient,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)
def anyio_backend() -> Literal["asyncio"]:
return "asyncio"
92 changes: 32 additions & 60 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket
from tests.types import TestClientFactory


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -84,8 +83,8 @@ async def websocket_endpoint(session: WebSocket) -> None:
)


def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
def test_custom_middleware() -> None:
client = TestClient(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"

Expand All @@ -105,9 +104,7 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
assert text == "Hello, world!"


def test_state_data_across_multiple_middlewares(
test_client_factory: TestClientFactory,
) -> None:
def test_state_data_across_multiple_middlewares() -> None:
expected_value1 = "foo"
expected_value2 = "bar"

Expand Down Expand Up @@ -154,25 +151,25 @@ def homepage(request: Request) -> PlainTextResponse:
],
)

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/")
assert response.text == "OK"
assert response.headers["X-State-Foo"] == expected_value1
assert response.headers["X-State-Bar"] == expected_value2


def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None:
def test_app_middleware_argument() -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage")

app = Starlette(routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)])

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"


def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> None:
def test_fully_evaluated_response() -> None:
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(
Expand All @@ -185,7 +182,7 @@ async def dispatch(

app = Starlette(middleware=[Middleware(CustomMiddleware)])

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"

Expand Down Expand Up @@ -231,10 +228,7 @@ async def dispatch(
),
],
)
def test_contextvars(
test_client_factory: TestClientFactory,
middleware_cls: _MiddlewareFactory[Any],
) -> None:
def test_contextvars(middleware_cls: _MiddlewareFactory[Any]) -> None:
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
Expand All @@ -245,7 +239,7 @@ async def homepage(request: Request) -> PlainTextResponse:

app = Starlette(middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)])

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200, response.content

Expand Down Expand Up @@ -411,9 +405,7 @@ async def send(message: Message) -> None:
assert context_manager_exited.is_set()


def test_app_receives_http_disconnect_while_sending_if_discarded(
test_client_factory: TestClientFactory,
) -> None:
def test_app_receives_http_disconnect_while_sending_if_discarded() -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
Expand Down Expand Up @@ -482,14 +474,12 @@ async def cancel_on_disconnect(

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_app_receives_http_disconnect_after_sending_if_discarded(
test_client_factory: TestClientFactory,
) -> None:
def test_app_receives_http_disconnect_after_sending_if_discarded() -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
Expand Down Expand Up @@ -532,14 +522,12 @@ async def downstream_app(

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_read_request_stream_in_app_after_middleware_calls_stream(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_app_after_middleware_calls_stream() -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b""]
async for chunk in request.stream():
Expand All @@ -564,14 +552,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_stream_in_app_after_middleware_calls_body(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_app_after_middleware_calls_body() -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
async for chunk in request.stream():
Expand All @@ -593,14 +579,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_body_in_app_after_middleware_calls_stream(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_body_in_app_after_middleware_calls_stream() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b""
return PlainTextResponse("Homepage")
Expand All @@ -622,14 +606,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_body_in_app_after_middleware_calls_body(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_body_in_app_after_middleware_calls_body() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
Expand All @@ -648,14 +630,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_stream_in_dispatch_after_app_calls_stream(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_dispatch_after_app_calls_stream() -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
async for chunk in request.stream():
Expand All @@ -680,14 +660,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_stream_in_dispatch_after_app_calls_body(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_dispatch_after_app_calls_body() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
Expand All @@ -709,7 +687,7 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200

Expand Down Expand Up @@ -773,9 +751,7 @@ async def send(msg: Message) -> None:
await rcv_stream.aclose()


def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
Expand All @@ -798,14 +774,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
Expand All @@ -826,7 +800,7 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200

Expand Down Expand Up @@ -917,9 +891,7 @@ async def send(msg: Message) -> None:
await rcv.aclose()


def test_downstream_middleware_modifies_receive(
test_client_factory: TestClientFactory,
) -> None:
def test_downstream_middleware_modifies_receive() -> None:
"""If a downstream middleware modifies receive() the final ASGI app
should see the modified version.
"""
Expand Down Expand Up @@ -952,7 +924,7 @@ async def wrapped_receive() -> Message:

return wrapped_app

client = test_client_factory(ConsumingMiddleware(modifying_middleware(endpoint)))
client = TestClient(ConsumingMiddleware(modifying_middleware(endpoint)))

resp = client.post("/", content=b"foo ")
assert resp.status_code == 200
Expand Down
Loading
Loading