diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 14567a26e..dd00da433 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -13,6 +13,7 @@ jobs: runs-on: "ubuntu-latest" strategy: + fail-fast: false matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] diff --git a/requirements.txt b/requirements.txt index 83c34fc9c..6f65fd509 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,6 @@ types-contextvars==2.4.7.3 types-PyYAML==6.0.12.20240917 types-dataclasses==0.6.6 pytest==8.3.4 -trio==0.27.0 # Documentation black==24.10.0 diff --git a/tests/conftest.py b/tests/conftest.py index 4db3ae018..dc1046201 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,23 +1,10 @@ from __future__ import annotations -import functools -from typing import Any, Literal +from typing import Literal import pytest -from starlette.testclient import TestClient -from tests.types import TestClientFactory - @pytest.fixture -def test_client_factory( - anyio_backend_name: Literal["asyncio", "trio"], - anyio_backend_options: dict[str, Any], -) -> TestClientFactory: - # anyio_backend_name defined by: - # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on - return functools.partial( - TestClient, - backend=anyio_backend_name, - backend_options=anyio_backend_options, - ) +def anyio_backend() -> Literal["asyncio"]: + return "asyncio" diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 7232cfd18..60d3a0cfe 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -19,7 +19,6 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket -from tests.types import TestClientFactory class CustomMiddleware(BaseHTTPMiddleware): @@ -84,8 +83,8 @@ async def websocket_endpoint(session: WebSocket) -> None: ) -def test_custom_middleware(test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_custom_middleware() -> None: + client = TestClient(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" @@ -105,9 +104,7 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None: assert text == "Hello, world!" -def test_state_data_across_multiple_middlewares( - test_client_factory: TestClientFactory, -) -> None: +def test_state_data_across_multiple_middlewares() -> None: expected_value1 = "foo" expected_value2 = "bar" @@ -154,25 +151,25 @@ def homepage(request: Request) -> PlainTextResponse: ], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "OK" assert response.headers["X-State-Foo"] == expected_value1 assert response.headers["X-State-Bar"] == expected_value2 -def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None: +def test_app_middleware_argument() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage") app = Starlette(routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" -def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> None: +def test_fully_evaluated_response() -> None: # Test for https://github.com/encode/starlette/issues/1022 class CustomMiddleware(BaseHTTPMiddleware): async def dispatch( @@ -185,7 +182,7 @@ async def dispatch( app = Starlette(middleware=[Middleware(CustomMiddleware)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/does_not_exist") assert response.text == "Custom" @@ -231,10 +228,7 @@ async def dispatch( ), ], ) -def test_contextvars( - test_client_factory: TestClientFactory, - middleware_cls: _MiddlewareFactory[Any], -) -> None: +def test_contextvars(middleware_cls: _MiddlewareFactory[Any]) -> None: # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating # contextvars (it propagates them forwards but not backwards) @@ -245,7 +239,7 @@ async def homepage(request: Request) -> PlainTextResponse: app = Starlette(middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.status_code == 200, response.content @@ -411,9 +405,7 @@ async def send(message: Message) -> None: assert context_manager_exited.is_set() -def test_app_receives_http_disconnect_while_sending_if_discarded( - test_client_factory: TestClientFactory, -) -> None: +def test_app_receives_http_disconnect_while_sending_if_discarded() -> None: class DiscardingMiddleware(BaseHTTPMiddleware): async def dispatch( self, @@ -482,14 +474,12 @@ async def cancel_on_disconnect( app = DiscardingMiddleware(downstream_app) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/does_not_exist") assert response.text == "Custom" -def test_app_receives_http_disconnect_after_sending_if_discarded( - test_client_factory: TestClientFactory, -) -> None: +def test_app_receives_http_disconnect_after_sending_if_discarded() -> None: class DiscardingMiddleware(BaseHTTPMiddleware): async def dispatch( self, @@ -532,14 +522,12 @@ async def downstream_app( app = DiscardingMiddleware(downstream_app) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/does_not_exist") assert response.text == "Custom" -def test_read_request_stream_in_app_after_middleware_calls_stream( - test_client_factory: TestClientFactory, -) -> None: +def test_read_request_stream_in_app_after_middleware_calls_stream() -> None: async def homepage(request: Request) -> PlainTextResponse: expected = [b""] async for chunk in request.stream(): @@ -564,14 +552,12 @@ async def dispatch( middleware=[Middleware(ConsumingMiddleware)], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/", content=b"a") assert response.status_code == 200 -def test_read_request_stream_in_app_after_middleware_calls_body( - test_client_factory: TestClientFactory, -) -> None: +def test_read_request_stream_in_app_after_middleware_calls_body() -> None: async def homepage(request: Request) -> PlainTextResponse: expected = [b"a", b""] async for chunk in request.stream(): @@ -593,14 +579,12 @@ async def dispatch( middleware=[Middleware(ConsumingMiddleware)], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/", content=b"a") assert response.status_code == 200 -def test_read_request_body_in_app_after_middleware_calls_stream( - test_client_factory: TestClientFactory, -) -> None: +def test_read_request_body_in_app_after_middleware_calls_stream() -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"" return PlainTextResponse("Homepage") @@ -622,14 +606,12 @@ async def dispatch( middleware=[Middleware(ConsumingMiddleware)], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/", content=b"a") assert response.status_code == 200 -def test_read_request_body_in_app_after_middleware_calls_body( - test_client_factory: TestClientFactory, -) -> None: +def test_read_request_body_in_app_after_middleware_calls_body() -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") @@ -648,14 +630,12 @@ async def dispatch( middleware=[Middleware(ConsumingMiddleware)], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/", content=b"a") assert response.status_code == 200 -def test_read_request_stream_in_dispatch_after_app_calls_stream( - test_client_factory: TestClientFactory, -) -> None: +def test_read_request_stream_in_dispatch_after_app_calls_stream() -> None: async def homepage(request: Request) -> PlainTextResponse: expected = [b"a", b""] async for chunk in request.stream(): @@ -680,14 +660,12 @@ async def dispatch( middleware=[Middleware(ConsumingMiddleware)], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/", content=b"a") assert response.status_code == 200 -def test_read_request_stream_in_dispatch_after_app_calls_body( - test_client_factory: TestClientFactory, -) -> None: +def test_read_request_stream_in_dispatch_after_app_calls_body() -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") @@ -709,7 +687,7 @@ async def dispatch( middleware=[Middleware(ConsumingMiddleware)], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/", content=b"a") assert response.status_code == 200 @@ -773,9 +751,7 @@ async def send(msg: Message) -> None: await rcv_stream.aclose() -def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( - test_client_factory: TestClientFactory, -) -> None: +def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next() -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") @@ -798,14 +774,12 @@ async def dispatch( middleware=[Middleware(ConsumingMiddleware)], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/", content=b"a") assert response.status_code == 200 -def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( - test_client_factory: TestClientFactory, -) -> None: +def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next() -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") @@ -826,7 +800,7 @@ async def dispatch( middleware=[Middleware(ConsumingMiddleware)], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/", content=b"a") assert response.status_code == 200 @@ -917,9 +891,7 @@ async def send(msg: Message) -> None: await rcv.aclose() -def test_downstream_middleware_modifies_receive( - test_client_factory: TestClientFactory, -) -> None: +def test_downstream_middleware_modifies_receive() -> None: """If a downstream middleware modifies receive() the final ASGI app should see the modified version. """ @@ -952,7 +924,7 @@ async def wrapped_receive() -> Message: return wrapped_app - client = test_client_factory(ConsumingMiddleware(modifying_middleware(endpoint))) + client = TestClient(ConsumingMiddleware(modifying_middleware(endpoint))) resp = client.post("/", content=b"foo ") assert resp.status_code == 200 diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 0d987263e..9e4016bbf 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -4,12 +4,10 @@ from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -from tests.types import TestClientFactory +from starlette.testclient import TestClient -def test_cors_allow_all( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_all() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -27,7 +25,7 @@ def homepage(request: Request) -> PlainTextResponse: ], ) - client = test_client_factory(app) + client = TestClient(app) # Test pre-flight response headers = { @@ -68,9 +66,7 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-origin" not in response.headers -def test_cors_allow_all_except_credentials( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_all_except_credentials() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -87,7 +83,7 @@ def homepage(request: Request) -> PlainTextResponse: ], ) - client = test_client_factory(app) + client = TestClient(app) # Test pre-flight response headers = { @@ -119,9 +115,7 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-origin" not in response.headers -def test_cors_allow_specific_origin( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_specific_origin() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -136,7 +130,7 @@ def homepage(request: Request) -> PlainTextResponse: ], ) - client = test_client_factory(app) + client = TestClient(app) # Test pre-flight response headers = { @@ -168,9 +162,7 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-origin" not in response.headers -def test_cors_disallowed_preflight( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_disallowed_preflight() -> None: def homepage(request: Request) -> None: pass # pragma: no cover @@ -185,7 +177,7 @@ def homepage(request: Request) -> None: ], ) - client = test_client_factory(app) + client = TestClient(app) # Test pre-flight response headers = { @@ -209,9 +201,7 @@ def homepage(request: Request) -> None: assert response.text == "Disallowed CORS headers" -def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed( - test_client_factory: TestClientFactory, -) -> None: +def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed() -> None: def homepage(request: Request) -> None: return # pragma: no cover @@ -227,7 +217,7 @@ def homepage(request: Request) -> None: ], ) - client = test_client_factory(app) + client = TestClient(app) # Test pre-flight response headers = { @@ -244,9 +234,7 @@ def homepage(request: Request) -> None: assert response.headers["vary"] == "Origin" -def test_cors_preflight_allow_all_methods( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_preflight_allow_all_methods() -> None: def homepage(request: Request) -> None: pass # pragma: no cover @@ -255,7 +243,7 @@ def homepage(request: Request) -> None: middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])], ) - client = test_client_factory(app) + client = TestClient(app) headers = { "Origin": "https://example.org", @@ -268,9 +256,7 @@ def homepage(request: Request) -> None: assert method in response.headers["access-control-allow-methods"] -def test_cors_allow_all_methods( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_all_methods() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -285,7 +271,7 @@ def homepage(request: Request) -> PlainTextResponse: middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])], ) - client = test_client_factory(app) + client = TestClient(app) headers = {"Origin": "https://example.org"} @@ -297,9 +283,7 @@ def homepage(request: Request) -> PlainTextResponse: assert response.status_code == 200 -def test_cors_allow_origin_regex( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_origin_regex() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -315,7 +299,7 @@ def homepage(request: Request) -> PlainTextResponse: ], ) - client = test_client_factory(app) + client = TestClient(app) # Test standard response headers = {"Origin": "https://example.org"} @@ -369,9 +353,7 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-origin" not in response.headers -def test_cors_allow_origin_regex_fullmatch( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_origin_regex_fullmatch() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -386,7 +368,7 @@ def homepage(request: Request) -> PlainTextResponse: ], ) - client = test_client_factory(app) + client = TestClient(app) # Test standard response headers = {"Origin": "https://subdomain.example.org"} @@ -404,9 +386,7 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-origin" not in response.headers -def test_cors_credentialed_requests_return_specific_origin( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_credentialed_requests_return_specific_origin() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -414,7 +394,7 @@ def homepage(request: Request) -> PlainTextResponse: routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], ) - client = test_client_factory(app) + client = TestClient(app) # Test credentialed request headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} @@ -425,9 +405,7 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-credentials" not in response.headers -def test_cors_vary_header_defaults_to_origin( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_vary_header_defaults_to_origin() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -438,16 +416,14 @@ def homepage(request: Request) -> PlainTextResponse: headers = {"Origin": "https://example.org"} - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers=headers) assert response.status_code == 200 assert response.headers["vary"] == "Origin" -def test_cors_vary_header_is_not_set_for_non_credentialed_request( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_vary_header_is_not_set_for_non_credentialed_request() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) @@ -455,16 +431,14 @@ def homepage(request: Request) -> PlainTextResponse: routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding" -def test_cors_vary_header_is_properly_set_for_credentialed_request( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_vary_header_is_properly_set_for_credentialed_request() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) @@ -472,16 +446,14 @@ def homepage(request: Request) -> PlainTextResponse: routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" -def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) @@ -491,16 +463,14 @@ def homepage(request: Request) -> PlainTextResponse: ], middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"Origin": "https://example.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" -def test_cors_allowed_origin_does_not_leak_between_credentialed_requests( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allowed_origin_does_not_leak_between_credentialed_requests() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -518,7 +488,7 @@ def homepage(request: Request) -> PlainTextResponse: ], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "*" assert "access-control-allow-credentials" not in response.headers diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index 0b0f7e51d..ea9d8b772 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -8,13 +8,11 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route +from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send -from tests.types import TestClientFactory -def test_handler( - test_client_factory: TestClientFactory, -) -> None: +def test_handler() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") @@ -22,49 +20,49 @@ def error_500(request: Request, exc: Exception) -> JSONResponse: return JSONResponse({"detail": "Server Error"}, status_code=500) app = ServerErrorMiddleware(app, handler=error_500) - client = test_client_factory(app, raise_server_exceptions=False) + client = TestClient(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} -def test_debug_text(test_client_factory: TestClientFactory) -> None: +def test_debug_text() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = test_client_factory(app, raise_server_exceptions=False) + client = TestClient(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.headers["content-type"].startswith("text/plain") assert "RuntimeError: Something went wrong" in response.text -def test_debug_html(test_client_factory: TestClientFactory) -> None: +def test_debug_html() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = test_client_factory(app, raise_server_exceptions=False) + client = TestClient(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 assert response.headers["content-type"].startswith("text/html") assert "RuntimeError" in response.text -def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None: +def test_debug_after_response_sent() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(b"", status_code=204) await response(scope, receive, send) raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): client.get("/") -def test_debug_not_http(test_client_factory: TestClientFactory) -> None: +def test_debug_not_http() -> None: """ DebugMiddleware should just pass through any non-http messages as-is. """ @@ -75,12 +73,12 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: app = ServerErrorMiddleware(app) with pytest.raises(RuntimeError): - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/"): pass # pragma: no cover -def test_background_task(test_client_factory: TestClientFactory) -> None: +def test_background_task() -> None: accessed_error_handler = False def error_handler(request: Request, exc: Exception) -> Any: @@ -99,7 +97,7 @@ async def endpoint(request: Request) -> Response: exception_handlers={Exception: error_handler}, ) - client = test_client_factory(app, raise_server_exceptions=False) + client = TestClient(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 204 assert accessed_error_handler diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index b20a7cb84..773ac4b54 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -4,10 +4,10 @@ from starlette.requests import Request from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse from starlette.routing import Route -from tests.types import TestClientFactory +from starlette.testclient import TestClient -def test_gzip_responses(test_client_factory: TestClientFactory) -> None: +def test_gzip_responses() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("x" * 4000, status_code=200) @@ -16,7 +16,7 @@ def homepage(request: Request) -> PlainTextResponse: middleware=[Middleware(GZipMiddleware)], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -24,7 +24,7 @@ def homepage(request: Request) -> PlainTextResponse: assert int(response.headers["Content-Length"]) < 4000 -def test_gzip_not_in_accept_encoding(test_client_factory: TestClientFactory) -> None: +def test_gzip_not_in_accept_encoding() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("x" * 4000, status_code=200) @@ -33,7 +33,7 @@ def homepage(request: Request) -> PlainTextResponse: middleware=[Middleware(GZipMiddleware)], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"accept-encoding": "identity"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -41,9 +41,7 @@ def homepage(request: Request) -> PlainTextResponse: assert int(response.headers["Content-Length"]) == 4000 -def test_gzip_ignored_for_small_responses( - test_client_factory: TestClientFactory, -) -> None: +def test_gzip_ignored_for_small_responses() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK", status_code=200) @@ -52,7 +50,7 @@ def homepage(request: Request) -> PlainTextResponse: middleware=[Middleware(GZipMiddleware)], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "OK" @@ -60,7 +58,7 @@ def homepage(request: Request) -> PlainTextResponse: assert int(response.headers["Content-Length"]) == 2 -def test_gzip_streaming_response(test_client_factory: TestClientFactory) -> None: +def test_gzip_streaming_response() -> None: def homepage(request: Request) -> StreamingResponse: async def generator(bytes: bytes, count: int) -> ContentStream: for index in range(count): @@ -74,7 +72,7 @@ async def generator(bytes: bytes, count: int) -> ContentStream: middleware=[Middleware(GZipMiddleware)], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -82,9 +80,7 @@ async def generator(bytes: bytes, count: int) -> ContentStream: assert "Content-Length" not in response.headers -def test_gzip_ignored_for_responses_with_encoding_set( - test_client_factory: TestClientFactory, -) -> None: +def test_gzip_ignored_for_responses_with_encoding_set() -> None: def homepage(request: Request) -> StreamingResponse: async def generator(bytes: bytes, count: int) -> ContentStream: for index in range(count): @@ -98,7 +94,7 @@ async def generator(bytes: bytes, count: int) -> ContentStream: middleware=[Middleware(GZipMiddleware)], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"accept-encoding": "gzip, text"}) assert response.status_code == 200 assert response.text == "x" * 4000 diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py index 66014e7e5..e6bee078e 100644 --- a/tests/middleware/test_https_redirect.py +++ b/tests/middleware/test_https_redirect.py @@ -4,10 +4,10 @@ from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -from tests.types import TestClientFactory +from starlette.testclient import TestClient -def test_https_redirect_middleware(test_client_factory: TestClientFactory) -> None: +def test_https_redirect_middleware() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK", status_code=200) @@ -16,26 +16,26 @@ def homepage(request: Request) -> PlainTextResponse: middleware=[Middleware(HTTPSRedirectMiddleware)], ) - client = test_client_factory(app, base_url="https://testserver") + client = TestClient(app, base_url="https://testserver") response = client.get("/") assert response.status_code == 200 - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = test_client_factory(app, base_url="http://testserver:80") + client = TestClient(app, base_url="http://testserver:80") response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = test_client_factory(app, base_url="http://testserver:443") + client = TestClient(app, base_url="http://testserver:443") response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = test_client_factory(app, base_url="http://testserver:123") + client = TestClient(app, base_url="http://testserver:123") response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver:123/" diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index b4f3c64fa..e596f6243 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -7,7 +7,6 @@ from starlette.responses import JSONResponse from starlette.routing import Mount, Route from starlette.testclient import TestClient -from tests.types import TestClientFactory def view_session(request: Request) -> JSONResponse: @@ -25,7 +24,7 @@ async def clear_session(request: Request) -> JSONResponse: return JSONResponse({"session": request.session}) -def test_session(test_client_factory: TestClientFactory) -> None: +def test_session() -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), @@ -34,7 +33,7 @@ def test_session(test_client_factory: TestClientFactory) -> None: ], middleware=[Middleware(SessionMiddleware, secret_key="example")], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/view_session") assert response.json() == {"session": {}} @@ -58,7 +57,7 @@ def test_session(test_client_factory: TestClientFactory) -> None: assert response.json() == {"session": {}} -def test_session_expires(test_client_factory: TestClientFactory) -> None: +def test_session_expires() -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), @@ -66,7 +65,7 @@ def test_session_expires(test_client_factory: TestClientFactory) -> None: ], middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=-1)], ) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} @@ -77,12 +76,12 @@ def test_session_expires(test_client_factory: TestClientFactory) -> None: expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header) assert expired_session_match is not None expired_session_value = expired_session_match[1] - client = test_client_factory(app, cookies={"session": expired_session_value}) + client = TestClient(app, cookies={"session": expired_session_value}) response = client.get("/view_session") assert response.json() == {"session": {}} -def test_secure_session(test_client_factory: TestClientFactory) -> None: +def test_secure_session() -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), @@ -91,8 +90,8 @@ def test_secure_session(test_client_factory: TestClientFactory) -> None: ], middleware=[Middleware(SessionMiddleware, secret_key="example", https_only=True)], ) - secure_client = test_client_factory(app, base_url="https://testserver") - unsecure_client = test_client_factory(app, base_url="http://testserver") + secure_client = TestClient(app, base_url="https://testserver") + unsecure_client = TestClient(app, base_url="http://testserver") response = unsecure_client.get("/view_session") assert response.json() == {"session": {}} @@ -119,7 +118,7 @@ def test_secure_session(test_client_factory: TestClientFactory) -> None: assert response.json() == {"session": {}} -def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None: +def test_session_cookie_subpath() -> None: second_app = Starlette( routes=[ Route("/update_session", endpoint=update_session, methods=["POST"]), @@ -127,7 +126,7 @@ def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None: middleware=[Middleware(SessionMiddleware, secret_key="example", path="/second_app")], ) app = Starlette(routes=[Mount("/second_app", app=second_app)]) - client = test_client_factory(app, base_url="http://testserver/second_app") + client = TestClient(app, base_url="http://testserver/second_app") response = client.post("/update_session", json={"some": "data"}) assert response.status_code == 200 cookie = response.headers["set-cookie"] @@ -137,7 +136,7 @@ def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None: assert cookie_path == "/second_app" -def test_invalid_session_cookie(test_client_factory: TestClientFactory) -> None: +def test_invalid_session_cookie() -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), @@ -145,18 +144,18 @@ def test_invalid_session_cookie(test_client_factory: TestClientFactory) -> None: ], middleware=[Middleware(SessionMiddleware, secret_key="example")], ) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} # we expect it to not raise an exception if we provide a bogus session cookie - client = test_client_factory(app, cookies={"session": "invalid"}) + client = TestClient(app, cookies={"session": "invalid"}) response = client.get("/view_session") assert response.json() == {"session": {}} -def test_session_cookie(test_client_factory: TestClientFactory) -> None: +def test_session_cookie() -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), @@ -164,7 +163,7 @@ def test_session_cookie(test_client_factory: TestClientFactory) -> None: ], middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=None)], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} @@ -178,7 +177,7 @@ def test_session_cookie(test_client_factory: TestClientFactory) -> None: assert response.json() == {"session": {}} -def test_domain_cookie(test_client_factory: TestClientFactory) -> None: +def test_domain_cookie() -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), @@ -186,7 +185,7 @@ def test_domain_cookie(test_client_factory: TestClientFactory) -> None: ], middleware=[Middleware(SessionMiddleware, secret_key="example", domain=".example.com")], ) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py index 5b8b217c3..18530f62a 100644 --- a/tests/middleware/test_trusted_host.py +++ b/tests/middleware/test_trusted_host.py @@ -4,10 +4,10 @@ from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -from tests.types import TestClientFactory +from starlette.testclient import TestClient -def test_trusted_host_middleware(test_client_factory: TestClientFactory) -> None: +def test_trusted_host_middleware() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK", status_code=200) @@ -16,15 +16,15 @@ def homepage(request: Request) -> PlainTextResponse: middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"])], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.status_code == 200 - client = test_client_factory(app, base_url="http://subdomain.testserver") + client = TestClient(app, base_url="http://subdomain.testserver") response = client.get("/") assert response.status_code == 200 - client = test_client_factory(app, base_url="http://invalidhost") + client = TestClient(app, base_url="http://invalidhost") response = client.get("/") assert response.status_code == 400 @@ -35,7 +35,7 @@ def test_default_allowed_hosts() -> None: assert middleware.allowed_hosts == ["*"] -def test_www_redirect(test_client_factory: TestClientFactory) -> None: +def test_www_redirect() -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK", status_code=200) @@ -44,7 +44,7 @@ def homepage(request: Request) -> PlainTextResponse: middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])], ) - client = test_client_factory(app, base_url="https://example.com") + client = TestClient(app, base_url="https://example.com") response = client.get("/") assert response.status_code == 200 assert response.url == "https://www.example.com/" diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 3511c89c9..b0c40eb9c 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -6,7 +6,7 @@ from starlette._utils import collapse_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ -from tests.types import TestClientFactory +from starlette.testclient import TestClient WSGIResponse = Iterable[bytes] StartResponse = Callable[..., Any] @@ -65,41 +65,41 @@ def return_exc_info( return [output] -def test_wsgi_get(test_client_factory: TestClientFactory) -> None: +def test_wsgi_get() -> None: app = WSGIMiddleware(hello_world) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello World!\n" -def test_wsgi_post(test_client_factory: TestClientFactory) -> None: +def test_wsgi_post() -> None: app = WSGIMiddleware(echo_body) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/", json={"example": 123}) assert response.status_code == 200 assert response.text == '{"example":123}' -def test_wsgi_exception(test_client_factory: TestClientFactory) -> None: +def test_wsgi_exception() -> None: # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError), collapse_excgroups(): client.get("/") -def test_wsgi_exc_info(test_client_factory: TestClientFactory) -> None: +def test_wsgi_exc_info() -> None: # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(return_exc_info) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): response = client.get("/") app = WSGIMiddleware(return_exc_info) - client = test_client_factory(app, raise_server_exceptions=False) + client = TestClient(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.text == "Internal Server Error" diff --git a/tests/test_applications.py b/tests/test_applications.py index 310eef6b4..eaabc710b 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -23,7 +23,6 @@ from starlette.testclient import TestClient, WebSocketDenialResponse from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket -from tests.types import TestClientFactory async def error_500(request: Request, exc: HTTPException) -> JSONResponse: @@ -137,8 +136,8 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> @pytest.fixture -def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]: - with test_client_factory(app) as client: +def client() -> Generator[TestClient, None, None]: + with TestClient(app) as client: yield client @@ -180,8 +179,8 @@ def test_mounted_route_path_params(client: TestClient) -> None: assert response.text == "Hello, tomchristie!" -def test_subdomain_route(test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app, base_url="https://foo.example.org/") +def test_subdomain_route() -> None: + client = TestClient(app, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 @@ -210,8 +209,8 @@ def test_405(client: TestClient) -> None: assert response.json() == {"detail": "Custom message"} -def test_500(test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app, raise_server_exceptions=False) +def test_500() -> None: + client = TestClient(app, raise_server_exceptions=False) response = client.get("/500") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} @@ -245,8 +244,8 @@ def test_websocket_raise_custom_exception(client: TestClient) -> None: } -def test_middleware(test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app, base_url="http://incorrecthost") +def test_middleware() -> None: + client = TestClient(app, base_url="http://incorrecthost") response = client.get("/func") assert response.status_code == 400 assert response.text == "Invalid host header" @@ -278,7 +277,7 @@ def test_routes() -> None: ] -def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_app_mount(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -289,7 +288,7 @@ def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None ] ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/static/example.txt") assert response.status_code == 200 @@ -300,7 +299,7 @@ def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None assert response.text == "Method Not Allowed" -def test_app_debug(test_client_factory: TestClientFactory) -> None: +def test_app_debug() -> None: async def homepage(request: Request) -> None: raise RuntimeError() @@ -311,14 +310,14 @@ async def homepage(request: Request) -> None: ) app.debug = True - client = test_client_factory(app, raise_server_exceptions=False) + client = TestClient(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert "RuntimeError" in response.text assert app.debug -def test_app_add_route(test_client_factory: TestClientFactory) -> None: +def test_app_add_route() -> None: async def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, World!") @@ -328,13 +327,13 @@ async def homepage(request: Request) -> PlainTextResponse: ] ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, World!" -def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None: +def test_app_add_websocket_route() -> None: async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") @@ -345,14 +344,14 @@ async def websocket_endpoint(session: WebSocket) -> None: WebSocketRoute("/ws", endpoint=websocket_endpoint), ] ) - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" -def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None: +def test_app_add_event_handler() -> None: startup_complete = False cleanup_complete = False @@ -372,14 +371,14 @@ def run_cleanup() -> None: assert not startup_complete assert not cleanup_complete - with test_client_factory(app): + with TestClient(app): assert startup_complete assert not cleanup_complete assert startup_complete assert cleanup_complete -def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None: +def test_app_async_cm_lifespan() -> None: startup_complete = False cleanup_complete = False @@ -394,7 +393,7 @@ async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]: assert not startup_complete assert not cleanup_complete - with test_client_factory(app): + with TestClient(app): assert startup_complete assert not cleanup_complete assert startup_complete @@ -411,7 +410,7 @@ async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]: @deprecated_lifespan -def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None: +def test_app_async_gen_lifespan() -> None: startup_complete = False cleanup_complete = False @@ -425,7 +424,7 @@ async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]: assert not startup_complete assert not cleanup_complete - with test_client_factory(app): + with TestClient(app): assert startup_complete assert not cleanup_complete assert startup_complete @@ -433,7 +432,7 @@ async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]: @deprecated_lifespan -def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None: +def test_app_sync_gen_lifespan() -> None: startup_complete = False cleanup_complete = False @@ -447,7 +446,7 @@ def lifespan(app: ASGIApp) -> Generator[None, None, None]: assert not startup_complete assert not cleanup_complete - with test_client_factory(app): + with TestClient(app): assert startup_complete assert not cleanup_complete assert startup_complete @@ -494,7 +493,7 @@ async def startup() -> None: ... # pragma: no cover assert len(record) == 1 -def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None: +def test_middleware_stack_init() -> None: class NoOpMiddleware: def __init__(self, app: ASGIApp): self.app = app @@ -520,23 +519,23 @@ def get_app() -> ASGIApp: app = get_app() - with test_client_factory(app): + with TestClient(app): pass assert SimpleInitializableMiddleware.counter == 1 - test_client_factory(app).get("/foo") + TestClient(app).get("/foo") assert SimpleInitializableMiddleware.counter == 1 app = get_app() - test_client_factory(app).get("/foo") + TestClient(app).get("/foo") assert SimpleInitializableMiddleware.counter == 2 -def test_middleware_args(test_client_factory: TestClientFactory) -> None: +def test_middleware_args() -> None: calls: list[str] = [] class MiddlewareWithArgs: @@ -552,13 +551,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: app.add_middleware(MiddlewareWithArgs, "foo") app.add_middleware(MiddlewareWithArgs, "bar") - with test_client_factory(app): + with TestClient(app): pass assert calls == ["bar", "foo"] -def test_middleware_factory(test_client_factory: TestClientFactory) -> None: +def test_middleware_factory() -> None: calls: list[str] = [] def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp: @@ -575,7 +574,7 @@ def get_middleware_factory() -> Callable[[ASGIApp, str], ASGIApp]: app.add_middleware(_middleware_factory, arg="foo") app.add_middleware(get_middleware_factory(), "bar") - with test_client_factory(app): + with TestClient(app): pass assert calls == ["bar", "foo"] diff --git a/tests/test_authentication.py b/tests/test_authentication.py index ddd2ad805..2d1ac37f7 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -16,8 +16,8 @@ from starlette.requests import HTTPConnection, Request from starlette.responses import JSONResponse, Response from starlette.routing import Route, WebSocketRoute +from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect -from tests.types import TestClientFactory AsyncEndpoint = Callable[..., Awaitable[Response]] SyncEndpoint = Callable[..., Response] @@ -209,8 +209,8 @@ def foo() -> None: # pragma: no cover pass -def test_user_interface(test_client_factory: TestClientFactory) -> None: - with test_client_factory(app) as client: +def test_user_interface() -> None: + with TestClient(app) as client: response = client.get("/") assert response.status_code == 200 assert response.json() == {"authenticated": False, "user": ""} @@ -220,8 +220,8 @@ def test_user_interface(test_client_factory: TestClientFactory) -> None: assert response.json() == {"authenticated": True, "user": "tomchristie"} -def test_authentication_required(test_client_factory: TestClientFactory) -> None: - with test_client_factory(app) as client: +def test_authentication_required() -> None: + with TestClient(app) as client: response = client.get("/dashboard") assert response.status_code == 403 @@ -270,10 +270,8 @@ def test_authentication_required(test_client_factory: TestClientFactory) -> None assert response.text == "Invalid basic auth credentials" -def test_websocket_authentication_required( - test_client_factory: TestClientFactory, -) -> None: - with test_client_factory(app) as client: +def test_websocket_authentication_required() -> None: + with TestClient(app) as client: with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/ws"): pass # pragma: no cover @@ -303,8 +301,8 @@ def test_websocket_authentication_required( } -def test_authentication_redirect(test_client_factory: TestClientFactory) -> None: - with test_client_factory(app) as client: +def test_authentication_redirect() -> None: + with TestClient(app) as client: response = client.get("/admin") assert response.status_code == 200 url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin"})) @@ -344,8 +342,8 @@ def control_panel(request: Request) -> JSONResponse: ) -def test_custom_on_error(test_client_factory: TestClientFactory) -> None: - with test_client_factory(other_app) as client: +def test_custom_on_error() -> None: + with TestClient(other_app) as client: response = client.get("/control-panel", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} diff --git a/tests/test_background.py b/tests/test_background.py index 48f348769..961810e16 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -2,11 +2,11 @@ from starlette.background import BackgroundTask, BackgroundTasks from starlette.responses import Response +from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send -from tests.types import TestClientFactory -def test_async_task(test_client_factory: TestClientFactory) -> None: +def test_async_task() -> None: TASK_COMPLETE = False async def async_task() -> None: @@ -19,13 +19,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE -def test_sync_task(test_client_factory: TestClientFactory) -> None: +def test_sync_task() -> None: TASK_COMPLETE = False def sync_task() -> None: @@ -38,13 +38,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE -def test_multiple_tasks(test_client_factory: TestClientFactory) -> None: +def test_multiple_tasks() -> None: TASK_COUNTER = 0 def increment(amount: int) -> None: @@ -59,15 +59,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("tasks initiated", media_type="text/plain", background=tasks) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "tasks initiated" assert TASK_COUNTER == 1 + 2 + 3 -def test_multi_tasks_failure_avoids_next_execution( - test_client_factory: TestClientFactory, -) -> None: +def test_multi_tasks_failure_avoids_next_execution() -> None: TASK_COUNTER = 0 def increment() -> None: @@ -83,7 +81,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("tasks initiated", media_type="text/plain", background=tasks) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(Exception): client.get("/") assert TASK_COUNTER == 1 diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index d620984c1..20b5a0e7b 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -9,7 +9,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route -from tests.types import TestClientFactory +from starlette.testclient import TestClient @pytest.mark.anyio @@ -30,9 +30,7 @@ async def task2() -> None: assert not task2_finished.is_set() -def test_accessing_context_from_threaded_sync_endpoint( - test_client_factory: TestClientFactory, -) -> None: +def test_accessing_context_from_threaded_sync_endpoint() -> None: ctxvar: ContextVar[bytes] = ContextVar("ctxvar") ctxvar.set(b"data") @@ -40,7 +38,7 @@ def endpoint(request: Request) -> Response: return Response(ctxvar.get()) app = Starlette(routes=[Route("/", endpoint)]) - client = test_client_factory(app) + client = TestClient(app) resp = client.get("/") assert resp.content == b"data" diff --git a/tests/test_convertors.py b/tests/test_convertors.py index d8430c672..dd8981e24 100644 --- a/tests/test_convertors.py +++ b/tests/test_convertors.py @@ -9,7 +9,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route, Router -from tests.types import TestClientFactory +from starlette.testclient import TestClient @pytest.fixture(scope="module", autouse=True) @@ -49,8 +49,8 @@ def datetime_convertor(request: Request) -> JSONResponse: ) -def test_datetime_convertor(test_client_factory: TestClientFactory, app: Router) -> None: - client = test_client_factory(app) +def test_datetime_convertor(app: Router) -> None: + client = TestClient(app) response = client.get("/datetime/2020-01-01T00:00:00") assert response.json() == {"datetime": "2020-01-01T00:00:00"} @@ -60,7 +60,7 @@ def test_datetime_convertor(test_client_factory: TestClientFactory, app: Router) @pytest.mark.parametrize("param, status_code", [("1.0", 200), ("1-0", 404)]) -def test_default_float_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None: +def test_default_float_convertor(param: str, status_code: int) -> None: def float_convertor(request: Request) -> JSONResponse: param = request.path_params["param"] assert isinstance(param, float) @@ -68,7 +68,7 @@ def float_convertor(request: Request) -> JSONResponse: app = Router(routes=[Route("/{param:float}", endpoint=float_convertor)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get(f"/{param}") assert response.status_code == status_code @@ -83,7 +83,7 @@ def float_convertor(request: Request) -> JSONResponse: ("not-a-uuid", 404), ], ) -def test_default_uuid_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None: +def test_default_uuid_convertor(param: str, status_code: int) -> None: def uuid_convertor(request: Request) -> JSONResponse: param = request.path_params["param"] assert isinstance(param, UUID) @@ -91,6 +91,6 @@ def uuid_convertor(request: Request) -> JSONResponse: app = Router(routes=[Route("/{param:uuid}", endpoint=uuid_convertor)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get(f"/{param}") assert response.status_code == status_code diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 76163873c..3458abe0d 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -8,7 +8,6 @@ from starlette.routing import Route, Router from starlette.testclient import TestClient from starlette.websockets import WebSocket -from tests.types import TestClientFactory class Homepage(HTTPEndpoint): @@ -23,8 +22,8 @@ async def get(self, request: Request) -> PlainTextResponse: @pytest.fixture -def client(test_client_factory: TestClientFactory) -> Iterator[TestClient]: - with test_client_factory(app) as client: +def client() -> Iterator[TestClient]: + with TestClient(app) as client: yield client @@ -47,27 +46,25 @@ def test_http_endpoint_route_method(client: TestClient) -> None: assert response.headers["allow"] == "GET" -def test_websocket_endpoint_on_connect(test_client_factory: TestClientFactory) -> None: +def test_websocket_endpoint_on_connect() -> None: class WebSocketApp(WebSocketEndpoint): async def on_connect(self, websocket: WebSocket) -> None: assert websocket["subprotocols"] == ["soap", "wamp"] await websocket.accept(subprotocol="wamp") - client = test_client_factory(WebSocketApp) + client = TestClient(WebSocketApp) with client.websocket_connect("/ws", subprotocols=["soap", "wamp"]) as websocket: assert websocket.accepted_subprotocol == "wamp" -def test_websocket_endpoint_on_receive_bytes( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_endpoint_on_receive_bytes() -> None: class WebSocketApp(WebSocketEndpoint): encoding = "bytes" async def on_receive(self, websocket: WebSocket, data: bytes) -> None: await websocket.send_bytes(b"Message bytes was: " + data) - client = test_client_factory(WebSocketApp) + client = TestClient(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_bytes(b"Hello, world!") _bytes = websocket.receive_bytes() @@ -78,16 +75,14 @@ async def on_receive(self, websocket: WebSocket, data: bytes) -> None: websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_endpoint_on_receive_json() -> None: class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_json({"message": data}) - client = test_client_factory(WebSocketApp) + client = TestClient(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() @@ -98,32 +93,28 @@ async def on_receive(self, websocket: WebSocket, data: str) -> None: websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json_binary( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_endpoint_on_receive_json_binary() -> None: class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_json({"message": data}, mode="binary") - client = test_client_factory(WebSocketApp) + client = TestClient(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"message": {"hello": "world"}} -def test_websocket_endpoint_on_receive_text( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_endpoint_on_receive_text() -> None: class WebSocketApp(WebSocketEndpoint): encoding = "text" async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_text(f"Message text was: {data}") - client = test_client_factory(WebSocketApp) + client = TestClient(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() @@ -134,28 +125,26 @@ async def on_receive(self, websocket: WebSocket, data: str) -> None: websocket.send_bytes(b"Hello world") -def test_websocket_endpoint_on_default(test_client_factory: TestClientFactory) -> None: +def test_websocket_endpoint_on_default() -> None: class WebSocketApp(WebSocketEndpoint): encoding = None async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_text(f"Message text was: {data}") - client = test_client_factory(WebSocketApp) + client = TestClient(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() assert _text == "Message text was: Hello, world!" -def test_websocket_endpoint_on_disconnect( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_endpoint_on_disconnect() -> None: class WebSocketApp(WebSocketEndpoint): async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: assert close_code == 1001 await websocket.close(code=close_code) - client = test_client_factory(WebSocketApp) + client = TestClient(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.close(code=1001) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index fe5da0ba6..d6b0cb54d 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -9,7 +9,6 @@ from starlette.routing import Route, Router, WebSocketRoute from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send -from tests.types import TestClientFactory def raise_runtime_error(request: Request) -> None: @@ -74,8 +73,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @pytest.fixture -def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]: - with test_client_factory(app) as client: +def client() -> Generator[TestClient, None, None]: + with TestClient(app) as client: yield client @@ -109,7 +108,7 @@ def test_websockets_should_raise(client: TestClient) -> None: pass # pragma: no cover -def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None: +def test_handled_exc_after_response(client: TestClient) -> None: # A 406 HttpException is raised *after* the response has already been sent. # The exception middleware should raise a RuntimeError. with pytest.raises(RuntimeError, match="Caught handled exception, but response already started."): @@ -117,13 +116,13 @@ def test_handled_exc_after_response(test_client_factory: TestClientFactory, clie # If `raise_server_exceptions=False` then the test client will still allow # us to see the response as it will have been seen by the client. - allow_200_client = test_client_factory(app, raise_server_exceptions=False) + allow_200_client = TestClient(app, raise_server_exceptions=False) response = allow_200_client.get("/handled_exc_after_response") assert response.status_code == 200 assert response.text == "OK" -def test_force_500_response(test_client_factory: TestClientFactory) -> None: +def test_force_500_response() -> None: # use a sentinel variable to make sure we actually # make it into the endpoint and don't get a 500 # from an incorrect ASGI app signature or something @@ -134,7 +133,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: called = True raise RuntimeError() - force_500_client = test_client_factory(app, raise_server_exceptions=False) + force_500_client = TestClient(app, raise_server_exceptions=False) response = force_500_client.get("/") assert called assert response.status_code == 500 diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 70beebc3f..ecf60bae0 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -13,8 +13,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount +from starlette.testclient import TestClient from starlette.types import ASGIApp, Receive, Scope, Send -from tests.types import TestClientFactory class ForceMultipartDict(dict[typing.Any, typing.Any]): @@ -127,18 +127,22 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: return app -def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_multipart_request_data( + tmpdir: Path, +) -> None: + client = TestClient(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) assert response.json() == {"some": "data"} -def test_multipart_request_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_multipart_request_files( + tmpdir: Path, +) -> None: path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") - client = test_client_factory(app) + client = TestClient(app) with open(path, "rb") as f: response = client.post("/", files={"test": f}) assert response.json() == { @@ -151,12 +155,14 @@ def test_multipart_request_files(tmpdir: Path, test_client_factory: TestClientFa } -def test_multipart_request_files_with_content_type(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_multipart_request_files_with_content_type( + tmpdir: Path, +) -> None: path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") - client = test_client_factory(app) + client = TestClient(app) with open(path, "rb") as f: response = client.post("/", files={"test": ("test.txt", f, "text/plain")}) assert response.json() == { @@ -169,7 +175,9 @@ def test_multipart_request_files_with_content_type(tmpdir: Path, test_client_fac } -def test_multipart_request_multiple_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_multipart_request_multiple_files( + tmpdir: Path, +) -> None: path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -178,7 +186,7 @@ def test_multipart_request_multiple_files(tmpdir: Path, test_client_factory: Tes with open(path2, "wb") as file: file.write(b"") - client = test_client_factory(app) + client = TestClient(app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post("/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")}) assert response.json() == { @@ -197,7 +205,9 @@ def test_multipart_request_multiple_files(tmpdir: Path, test_client_factory: Tes } -def test_multipart_request_multiple_files_with_headers(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_multipart_request_multiple_files_with_headers( + tmpdir: Path, +) -> None: path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -206,7 +216,7 @@ def test_multipart_request_multiple_files_with_headers(tmpdir: Path, test_client with open(path2, "wb") as file: file.write(b"") - client = test_client_factory(app_with_headers) + client = TestClient(app_with_headers) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", @@ -234,7 +244,9 @@ def test_multipart_request_multiple_files_with_headers(tmpdir: Path, test_client } -def test_multi_items(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_multi_items( + tmpdir: Path, +) -> None: path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -243,7 +255,7 @@ def test_multi_items(tmpdir: Path, test_client_factory: TestClientFactory) -> No with open(path2, "wb") as file: file.write(b"") - client = test_client_factory(multi_items_app) + client = TestClient(multi_items_app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", @@ -269,8 +281,10 @@ def test_multi_items(tmpdir: Path, test_client_factory: TestClientFactory) -> No } -def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_multipart_request_mixed_files_and_data( + tmpdir: Path, +) -> None: + client = TestClient(app) response = client.post( "/", data=( @@ -303,8 +317,10 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor } -def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_multipart_request_with_charset_for_filename( + tmpdir: Path, +) -> None: + client = TestClient(app) response = client.post( "/", data=( @@ -327,8 +343,10 @@ def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_f } -def test_multipart_request_without_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_multipart_request_without_charset_for_filename( + tmpdir: Path, +) -> None: + client = TestClient(app) response = client.post( "/", data=( @@ -351,8 +369,10 @@ def test_multipart_request_without_charset_for_filename(tmpdir: Path, test_clien } -def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_multipart_request_with_encoded_value( + tmpdir: Path, +) -> None: + client = TestClient(app) response = client.post( "/", data=( @@ -367,38 +387,50 @@ def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory: assert response.json() == {"value": "Transférer"} -def test_urlencoded_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_urlencoded_request_data( + tmpdir: Path, +) -> None: + client = TestClient(app) response = client.post("/", data={"some": "data"}) assert response.json() == {"some": "data"} -def test_no_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_no_request_data( + tmpdir: Path, +) -> None: + client = TestClient(app) response = client.post("/") assert response.json() == {} -def test_urlencoded_percent_encoding(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_urlencoded_percent_encoding( + tmpdir: Path, +) -> None: + client = TestClient(app) response = client.post("/", data={"some": "da ta"}) assert response.json() == {"some": "da ta"} -def test_urlencoded_percent_encoding_keys(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_urlencoded_percent_encoding_keys( + tmpdir: Path, +) -> None: + client = TestClient(app) response = client.post("/", data={"so me": "data"}) assert response.json() == {"so me": "data"} -def test_urlencoded_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app_read_body) +def test_urlencoded_multi_field_app_reads_body( + tmpdir: Path, +) -> None: + client = TestClient(app_read_body) response = client.post("/", data={"some": "data", "second": "key pair"}) assert response.json() == {"some": "data", "second": "key pair"} -def test_multipart_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app_read_body) +def test_multipart_multi_field_app_reads_body( + tmpdir: Path, +) -> None: + client = TestClient(app_read_body) response = client.post("/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART) assert response.json() == {"some": "data", "second": "key pair"} @@ -423,9 +455,8 @@ def test_user_safe_decode_ignores_wrong_charset() -> None: def test_missing_boundary_parameter( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) with expectation: res = client.post( "/", @@ -451,9 +482,8 @@ def test_missing_boundary_parameter( def test_missing_name_parameter_on_content_disposition( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) with expectation: res = client.post( "/", @@ -479,9 +509,8 @@ def test_missing_name_parameter_on_content_disposition( def test_too_many_fields_raise( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) fields = [] for i in range(1001): fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n") @@ -506,9 +535,8 @@ def test_too_many_fields_raise( def test_too_many_files_raise( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) fields = [] for i in range(1001): fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n' "\r\n") @@ -533,9 +561,8 @@ def test_too_many_files_raise( def test_too_many_files_single_field_raise( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) fields = [] for i in range(1001): # This uses the same field name "N" for all files, equivalent to a @@ -562,9 +589,8 @@ def test_too_many_files_single_field_raise( def test_too_many_files_and_fields_raise( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) fields = [] for i in range(1001): fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n") @@ -593,9 +619,8 @@ def test_too_many_files_and_fields_raise( def test_max_fields_is_customizable_low_raises( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) fields = [] for i in range(2): fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n") @@ -623,9 +648,8 @@ def test_max_fields_is_customizable_low_raises( def test_max_files_is_customizable_low_raises( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) fields = [] for i in range(2): fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n") @@ -640,8 +664,8 @@ def test_max_files_is_customizable_low_raises( assert res.text == "Too many files. Maximum number of files is 1." -def test_max_fields_is_customizable_high(test_client_factory: TestClientFactory) -> None: - client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000)) +def test_max_fields_is_customizable_high() -> None: + client = TestClient(make_app_max_parts(max_fields=2000, max_files=2000)) fields = [] for i in range(2000): fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n") @@ -674,9 +698,8 @@ def test_max_fields_is_customizable_high(test_client_factory: TestClientFactory) def test_max_part_size_exceeds_limit( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR" multipart_data = ( @@ -714,9 +737,8 @@ def test_max_part_size_exceeds_limit( def test_max_part_size_exceeds_custom_limit( app: ASGIApp, expectation: typing.ContextManager[Exception], - test_client_factory: TestClientFactory, ) -> None: - client = test_client_factory(app) + client = TestClient(app) boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR" multipart_data = ( diff --git a/tests/test_requests.py b/tests/test_requests.py index 7e2c608dc..f61a61c35 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -10,18 +10,18 @@ from starlette.datastructures import URL, Address, State from starlette.requests import ClientDisconnect, Request from starlette.responses import JSONResponse, PlainTextResponse, Response +from starlette.testclient import TestClient from starlette.types import Message, Receive, Scope, Send -from tests.types import TestClientFactory -def test_request_url(test_client_factory: TestClientFactory) -> None: +def test_request_url() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = {"method": request.method, "url": str(request.url)} response = JSONResponse(data) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/123?a=abc") assert response.json() == {"method": "GET", "url": "http://testserver/123?a=abc"} @@ -29,14 +29,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.json() == {"method": "GET", "url": "https://example.org:123/"} -def test_request_query_params(test_client_factory: TestClientFactory) -> None: +def test_request_query_params() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) params = dict(request.query_params) response = JSONResponse({"params": params}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/?a=123&b=456") assert response.json() == {"params": {"a": "123", "b": "456"}} @@ -45,14 +45,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: any(module in sys.modules for module in ("brotli", "brotlicffi")), reason='urllib3 includes "br" to the "accept-encoding" headers.', ) -def test_request_headers(test_client_factory: TestClientFactory) -> None: +def test_request_headers() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) headers = dict(request.headers) response = JSONResponse({"headers": headers}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"host": "example.org"}) assert response.json() == { "headers": { @@ -79,14 +79,14 @@ def test_request_client(scope: Scope, expected_client: Address | None) -> None: assert client == expected_client -def test_request_body(test_client_factory: TestClientFactory) -> None: +def test_request_body() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.json() == {"body": ""} @@ -98,7 +98,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.json() == {"body": "abc"} -def test_request_stream(test_client_factory: TestClientFactory) -> None: +def test_request_stream() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) body = b"" @@ -107,7 +107,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.json() == {"body": ""} @@ -119,33 +119,33 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.json() == {"body": "abc"} -def test_request_form_urlencoded(test_client_factory: TestClientFactory) -> None: +def test_request_form_urlencoded() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) form = await request.form() response = JSONResponse({"form": dict(form)}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/", data={"abc": "123 @"}) assert response.json() == {"form": {"abc": "123 @"}} -def test_request_form_context_manager(test_client_factory: TestClientFactory) -> None: +def test_request_form_context_manager() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) async with request.form() as form: response = JSONResponse({"form": dict(form)}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/", data={"abc": "123 @"}) assert response.json() == {"form": {"abc": "123 @"}} -def test_request_body_then_stream(test_client_factory: TestClientFactory) -> None: +def test_request_body_then_stream() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) body = await request.body() @@ -155,13 +155,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/", data="abc") # type: ignore assert response.json() == {"body": "abc", "stream": "abc"} -def test_request_stream_then_body(test_client_factory: TestClientFactory) -> None: +def test_request_stream_then_body() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) chunks = b"" @@ -174,20 +174,20 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/", data="abc") # type: ignore assert response.json() == {"body": "", "stream": "abc"} -def test_request_json(test_client_factory: TestClientFactory) -> None: +def test_request_json() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = await request.json() response = JSONResponse({"json": data}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": {"a": "123"}} @@ -203,7 +203,7 @@ def test_request_scope_interface() -> None: assert len(request) == 3 -def test_request_raw_path(test_client_factory: TestClientFactory) -> None: +def test_request_raw_path() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) path = request.scope["path"] @@ -211,14 +211,12 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = PlainTextResponse(f"{path}, {raw_path}") await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/he%2Fllo") assert response.text == "/he/llo, b'/he%2Fllo'" -def test_request_without_setting_receive( - test_client_factory: TestClientFactory, -) -> None: +def test_request_without_setting_receive() -> None: """ If Request is instantiated without the receive channel, then .body() is not available. @@ -233,7 +231,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"json": data}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": "Receive channel not available"} @@ -266,7 +264,7 @@ async def receiver() -> Message: ) -def test_request_is_disconnected(test_client_factory: TestClientFactory) -> None: +def test_request_is_disconnected() -> None: """ If a client disconnect occurs after reading request body then request will be set disconnected properly. @@ -283,7 +281,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) disconnected_after_response = await request.is_disconnected() - client = test_client_factory(app) + client = TestClient(app) response = client.post("/", content="foo") assert response.json() == {"body": "foo", "disconnected": False} assert disconnected_after_response @@ -303,19 +301,19 @@ def test_request_state_object() -> None: s.new -def test_request_state(test_client_factory: TestClientFactory) -> None: +def test_request_state() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) request.state.example = 123 response = JSONResponse({"state.example": request.state.example}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/123?a=abc") assert response.json() == {"state.example": 123} -def test_request_cookies(test_client_factory: TestClientFactory) -> None: +def test_request_cookies() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) mycookie = request.cookies.get("mycookie") @@ -327,14 +325,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "Hello, world!" response = client.get("/") assert response.text == "Hello, cookies!" -def test_cookie_lenient_parsing(test_client_factory: TestClientFactory) -> None: +def test_cookie_lenient_parsing() -> None: """ The following test is based on a cookie set by Okta, a well-known authorization service. It turns out that it's common practice to set cookies that would be @@ -361,7 +359,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"cookie": tough_cookie}) result = response.json() assert len(result["cookies"]) == 4 @@ -390,17 +388,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ("a=b; h=i; a=c", {"a": "c", "h": "i"}), ], ) -def test_cookies_edge_cases( - set_cookie: str, - expected: dict[str, str], - test_client_factory: TestClientFactory, -) -> None: +def test_cookies_edge_cases(set_cookie: str, expected: dict[str, str]) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"cookie": set_cookie}) result = response.json() assert result["cookies"] == expected @@ -432,7 +426,6 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_cookies_invalid( set_cookie: str, expected: dict[str, str], - test_client_factory: TestClientFactory, ) -> None: """ Cookie strings that are against the RFC6265 spec but which browsers will send if set @@ -444,20 +437,20 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"cookie": set_cookie}) result = response.json() assert result["cookies"] == expected -def test_chunked_encoding(test_client_factory: TestClientFactory) -> None: +def test_chunked_encoding() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) def post_body() -> Iterator[bytes]: yield b"foo" @@ -467,7 +460,7 @@ def post_body() -> Iterator[bytes]: assert response.json() == {"body": "foobar"} -def test_request_send_push_promise(test_client_factory: TestClientFactory) -> None: +def test_request_send_push_promise() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: # the server is push-enabled scope["extensions"]["http.response.push"] = {} @@ -478,14 +471,12 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"json": "OK"}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.json() == {"json": "OK"} -def test_request_send_push_promise_without_push_extension( - test_client_factory: TestClientFactory, -) -> None: +def test_request_send_push_promise_without_push_extension() -> None: """ If server does not support the `http.response.push` extension, .send_push_promise() does nothing. @@ -498,14 +489,12 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"json": "OK"}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.json() == {"json": "OK"} -def test_request_send_push_promise_without_setting_send( - test_client_factory: TestClientFactory, -) -> None: +def test_request_send_push_promise_without_setting_send() -> None: """ If Request is instantiated without the send channel, then .send_push_promise() is not available. @@ -524,7 +513,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"json": data}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.json() == {"json": "Send channel not available"} @@ -595,12 +584,12 @@ async def rcv() -> Message: await s1.__anext__() -def test_request_url_outside_starlette_context(test_client_factory: TestClientFactory) -> None: +def test_request_url_outside_starlette_context() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) request.url_for("index") - client = test_client_factory(app) + client = TestClient(app) with pytest.raises( RuntimeError, match="The `url_for` method can only be used inside a Starlette application or with a router.", @@ -608,7 +597,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: client.get("/") -def test_request_url_starlette_context(test_client_factory: TestClientFactory) -> None: +def test_request_url_starlette_context() -> None: from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.routing import Route @@ -631,6 +620,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: app = Starlette(routes=[Route("/home", homepage)], middleware=[Middleware(CustomMiddleware)]) - client = test_client_factory(app) + client = TestClient(app) client.get("/home") assert url_for == URL("http://testserver/home") diff --git a/tests/test_responses.py b/tests/test_responses.py index d5ed83499..5d4b4fea5 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -17,41 +17,40 @@ from starlette.responses import FileResponse, JSONResponse, RedirectResponse, Response, StreamingResponse from starlette.testclient import TestClient from starlette.types import Message, Receive, Scope, Send -from tests.types import TestClientFactory -def test_text_response(test_client_factory: TestClientFactory) -> None: +def test_text_response() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("hello, world", media_type="text/plain") await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "hello, world" -def test_bytes_response(test_client_factory: TestClientFactory) -> None: +def test_bytes_response() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(b"xxxxx", media_type="image/png") await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.content == b"xxxxx" -def test_json_none_response(test_client_factory: TestClientFactory) -> None: +def test_json_none_response() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse(None) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.json() is None assert response.content == b"null" -def test_redirect_response(test_client_factory: TestClientFactory) -> None: +def test_redirect_response() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = Response("hello, world", media_type="text/plain") @@ -59,13 +58,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = RedirectResponse("/") await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/" -def test_quoting_redirect_response(test_client_factory: TestClientFactory) -> None: +def test_quoting_redirect_response() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/I ♥ Starlette/": response = Response("hello, world", media_type="text/plain") @@ -73,15 +72,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = RedirectResponse("/I ♥ Starlette/") await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/" -def test_redirect_response_content_length_header( - test_client_factory: TestClientFactory, -) -> None: +def test_redirect_response_content_length_header() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = Response("hello", media_type="text/plain") # pragma: no cover @@ -89,13 +86,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = RedirectResponse("/") await response(scope, receive, send) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.request("GET", "/redirect", follow_redirects=False) assert response.url == "http://testserver/redirect" assert response.headers["content-length"] == "0" -def test_streaming_response(test_client_factory: TestClientFactory) -> None: +def test_streaming_response() -> None: filled_by_bg_task = "" async def app(scope: Scope, receive: Receive, send: Send) -> None: @@ -117,15 +114,13 @@ async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: await response(scope, receive, send) assert filled_by_bg_task == "" - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" assert filled_by_bg_task == "6, 7, 8, 9" -def test_streaming_response_custom_iterator( - test_client_factory: TestClientFactory, -) -> None: +def test_streaming_response_custom_iterator() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: class CustomAsyncIterator: def __init__(self) -> None: @@ -143,14 +138,12 @@ async def __anext__(self) -> str: response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "12345" -def test_streaming_response_custom_iterable( - test_client_factory: TestClientFactory, -) -> None: +def test_streaming_response_custom_iterable() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: class CustomAsyncIterable: async def __aiter__(self) -> AsyncIterator[str | bytes]: @@ -160,12 +153,12 @@ async def __aiter__(self) -> AsyncIterator[str | bytes]: response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain") await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "12345" -def test_sync_streaming_response(test_client_factory: TestClientFactory) -> None: +def test_sync_streaming_response() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: def numbers(minimum: int, maximum: int) -> Iterator[str]: for i in range(minimum, maximum + 1): @@ -177,37 +170,39 @@ def numbers(minimum: int, maximum: int) -> Iterator[str]: response = StreamingResponse(generator, media_type="text/plain") await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" -def test_response_headers(test_client_factory: TestClientFactory) -> None: +def test_response_headers() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: headers = {"x-header-1": "123", "x-header-2": "456"} response = Response("hello, world", media_type="text/plain", headers=headers) response.headers["x-header-2"] = "789" await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.headers["x-header-1"] == "123" assert response.headers["x-header-2"] == "789" -def test_response_phrase(test_client_factory: TestClientFactory) -> None: +def test_response_phrase() -> None: app = Response(status_code=204) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.reason_phrase == "No Content" app = Response(b"", status_code=123) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.reason_phrase == "" -def test_file_response(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_file_response( + tmp_path: Path, +) -> None: path = tmp_path / "xyz" content = b"" * 1000 path.write_bytes(content) @@ -233,7 +228,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) assert filled_by_bg_task == "" - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") expected_disposition = 'attachment; filename="example.png"' assert response.status_code == status.HTTP_200_OK @@ -275,42 +270,50 @@ async def send(message: Message) -> None: await app({"type": "http", "method": "head", "headers": [(b"key", b"value")]}, receive, send) -def test_file_response_set_media_type(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_file_response_set_media_type( + tmp_path: Path, +) -> None: path = tmp_path / "xyz" path.write_bytes(b"") # By default, FileResponse will determine the `content-type` based on # the filename or path, unless a specific `media_type` is provided. app = FileResponse(path=path, filename="example.png", media_type="image/jpeg") - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.get("/") assert response.headers["content-type"] == "image/jpeg" -def test_file_response_with_directory_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_file_response_with_directory_raises_error( + tmp_path: Path, +) -> None: app = FileResponse(path=tmp_path, filename="example.png") - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "is not a file" in str(exc_info.value) -def test_file_response_with_missing_file_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_file_response_with_missing_file_raises_error( + tmp_path: Path, +) -> None: path = tmp_path / "404.txt" app = FileResponse(path=path, filename="404.txt") - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "does not exist" in str(exc_info.value) -def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_file_response_with_chinese_filename( + tmp_path: Path, +) -> None: content = b"file content" filename = "你好.txt" # probably "Hello.txt" in Chinese path = tmp_path / filename path.write_bytes(content) app = FileResponse(path=path, filename=filename) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") expected_disposition = "attachment; filename*=utf-8''%E4%BD%A0%E5%A5%BD.txt" assert response.status_code == status.HTTP_200_OK @@ -318,13 +321,15 @@ def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory assert response.headers["content-disposition"] == expected_disposition -def test_file_response_with_inline_disposition(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_file_response_with_inline_disposition( + tmp_path: Path, +) -> None: content = b"file content" filename = "hello.txt" path = tmp_path / filename path.write_bytes(content) app = FileResponse(path=path, filename=filename, content_disposition_type="inline") - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") expected_disposition = 'inline; filename="hello.txt"' assert response.status_code == status.HTTP_200_OK @@ -337,14 +342,16 @@ def test_file_response_with_method_warns(tmp_path: Path) -> None: FileResponse(path=tmp_path, filename="example.png", method="GET") -def test_file_response_with_range_header(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_file_response_with_range_header( + tmp_path: Path, +) -> None: content = b"file content" filename = "hello.txt" path = tmp_path / filename path.write_bytes(content) etag = '"a_non_autogenerated_etag"' app = FileResponse(path=path, filename=filename, headers={"etag": etag}) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers={"range": "bytes=0-4", "if-range": etag}) assert response.status_code == status.HTTP_206_PARTIAL_CONTENT assert response.content == content[:5] @@ -353,7 +360,7 @@ def test_file_response_with_range_header(tmp_path: Path, test_client_factory: Te assert response.headers["content-range"] == f"bytes 0-4/{len(content)}" -def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None: +def test_set_cookie(monkeypatch: pytest.MonkeyPatch) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) @@ -373,7 +380,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "Hello, world!" assert ( @@ -382,25 +389,25 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ) -def test_set_cookie_path_none(test_client_factory: TestClientFactory) -> None: +def test_set_cookie_path_none() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie("mycookie", "myvalue", path=None) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "Hello, world!" assert response.headers["set-cookie"] == "mycookie=myvalue; SameSite=lax" -def test_set_cookie_samesite_none(test_client_factory: TestClientFactory) -> None: +def test_set_cookie_samesite_none() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie("mycookie", "myvalue", samesite=None) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "Hello, world!" assert response.headers["set-cookie"] == "mycookie=myvalue; Path=/" @@ -414,11 +421,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: pytest.param(10, id="int"), ], ) -def test_expires_on_set_cookie( - test_client_factory: TestClientFactory, - monkeypatch: pytest.MonkeyPatch, - expires: str, -) -> None: +def test_expires_on_set_cookie(monkeypatch: pytest.MonkeyPatch, expires: str) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) @@ -428,13 +431,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response.set_cookie("mycookie", "myvalue", expires=expires) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") cookie = SimpleCookie(response.headers.get("set-cookie")) assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT" -def test_delete_cookie(test_client_factory: TestClientFactory) -> None: +def test_delete_cookie() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) response = Response("Hello, world!", media_type="text/plain") @@ -444,98 +447,96 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response.set_cookie("mycookie", "myvalue") await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.cookies["mycookie"] response = client.get("/") assert not response.cookies.get("mycookie") -def test_populate_headers(test_client_factory: TestClientFactory) -> None: +def test_populate_headers() -> None: app = Response(content="hi", headers={}, media_type="text/html") - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "hi" assert response.headers["content-length"] == "2" assert response.headers["content-type"] == "text/html; charset=utf-8" -def test_head_method(test_client_factory: TestClientFactory) -> None: +def test_head_method() -> None: app = Response("hello, world", media_type="text/plain") - client = test_client_factory(app) + client = TestClient(app) response = client.head("/") assert response.text == "" -def test_empty_response(test_client_factory: TestClientFactory) -> None: +def test_empty_response() -> None: app = Response() - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.get("/") assert response.content == b"" assert response.headers["content-length"] == "0" assert "content-type" not in response.headers -def test_empty_204_response(test_client_factory: TestClientFactory) -> None: +def test_empty_204_response() -> None: app = Response(status_code=204) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.get("/") assert "content-length" not in response.headers -def test_non_empty_response(test_client_factory: TestClientFactory) -> None: +def test_non_empty_response() -> None: app = Response(content="hi") - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.get("/") assert response.headers["content-length"] == "2" -def test_response_do_not_add_redundant_charset( - test_client_factory: TestClientFactory, -) -> None: +def test_response_do_not_add_redundant_charset() -> None: app = Response(media_type="text/plain; charset=utf-8") - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.headers["content-type"] == "text/plain; charset=utf-8" -def test_file_response_known_size(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_file_response_known_size( + tmp_path: Path, +) -> None: path = tmp_path / "xyz" content = b"" * 1000 path.write_bytes(content) app = FileResponse(path=path, filename="example.png") - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.get("/") assert response.headers["content-length"] == str(len(content)) -def test_streaming_response_unknown_size( - test_client_factory: TestClientFactory, -) -> None: +def test_streaming_response_unknown_size() -> None: app = StreamingResponse(content=iter(["hello", "world"])) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.get("/") assert "content-length" not in response.headers -def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None: +def test_streaming_response_known_size() -> None: app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"}) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.get("/") assert response.headers["content-length"] == "10" -def test_response_memoryview(test_client_factory: TestClientFactory) -> None: +def test_response_memoryview() -> None: app = Response(content=memoryview(b"\xc0")) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.get("/") assert response.content == b"\xc0" -def test_streaming_response_memoryview(test_client_factory: TestClientFactory) -> None: +def test_streaming_response_memoryview() -> None: app = StreamingResponse(content=iter([memoryview(b"\xc0"), memoryview(b"\xf5")])) - client: TestClient = test_client_factory(app) + client: TestClient = TestClient(app) response = client.get("/") assert response.content == b"\xc0\xf5" @@ -623,8 +624,10 @@ def readme_file(tmp_path: Path) -> Path: @pytest.fixture -def file_response_client(readme_file: Path, test_client_factory: TestClientFactory) -> TestClient: - return test_client_factory(app=FileResponse(str(readme_file))) +def file_response_client( + readme_file: Path, +) -> TestClient: + return TestClient(app=FileResponse(str(readme_file))) def test_file_response_without_range(file_response_client: TestClient) -> None: diff --git a/tests/test_routing.py b/tests/test_routing.py index 933fe7c31..9e35b2077 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -17,7 +17,6 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect -from tests.types import TestClientFactory def homepage(request: Request) -> Response: @@ -163,10 +162,8 @@ async def websocket_params(session: WebSocket) -> None: @pytest.fixture -def client( - test_client_factory: TestClientFactory, -) -> typing.Generator[TestClient, None, None]: - with test_client_factory(app) as client: +def client() -> typing.Generator[TestClient, None, None]: + with TestClient(app) as client: yield client @@ -312,7 +309,7 @@ def test_router_add_websocket_route(client: TestClient) -> None: assert text == "Hello, test!" -def test_router_middleware(test_client_factory: TestClientFactory) -> None: +def test_router_middleware() -> None: class CustomMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app @@ -326,7 +323,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: middleware=[Middleware(CustomMiddleware)], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.status_code == 200 assert response.text == "OK" @@ -353,8 +350,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ) -def test_protocol_switch(test_client_factory: TestClientFactory) -> None: - client = test_client_factory(mixed_protocol_app) +def test_protocol_switch() -> None: + client = TestClient(mixed_protocol_app) response = client.get("/") assert response.status_code == 200 @@ -371,9 +368,9 @@ def test_protocol_switch(test_client_factory: TestClientFactory) -> None: ok = PlainTextResponse("OK") -def test_mount_urls(test_client_factory: TestClientFactory) -> None: +def test_mount_urls() -> None: mounted = Router([Mount("/users", ok, name="users")]) - client = test_client_factory(mounted) + client = TestClient(mounted) assert client.get("/users").status_code == 200 assert client.get("/users").url == "http://testserver/users/" assert client.get("/users/").status_code == 200 @@ -399,9 +396,9 @@ def test_reverse_mount_urls() -> None: mounted.url_path_for("users") -def test_mount_at_root(test_client_factory: TestClientFactory) -> None: +def test_mount_at_root() -> None: mounted = Router([Mount("/", ok, name="users")]) - client = test_client_factory(mounted) + client = TestClient(mounted) assert client.get("/").status_code == 200 @@ -434,8 +431,8 @@ def users_api(request: Request) -> JSONResponse: ) -def test_host_routing(test_client_factory: TestClientFactory) -> None: - client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/") +def test_host_routing() -> None: + client = TestClient(mixed_hosts_app, base_url="https://api.example.org/") response = client.get("/users") assert response.status_code == 200 @@ -444,7 +441,7 @@ def test_host_routing(test_client_factory: TestClientFactory) -> None: response = client.get("/") assert response.status_code == 404 - client = test_client_factory(mixed_hosts_app, base_url="https://www.example.org/") + client = TestClient(mixed_hosts_app, base_url="https://www.example.org/") response = client.get("/users") assert response.status_code == 200 @@ -453,7 +450,7 @@ def test_host_routing(test_client_factory: TestClientFactory) -> None: response = client.get("/") assert response.status_code == 200 - client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:3600/") + client = TestClient(mixed_hosts_app, base_url="https://port.example.org:3600/") response = client.get("/users") assert response.status_code == 404 @@ -463,12 +460,12 @@ def test_host_routing(test_client_factory: TestClientFactory) -> None: # Port in requested Host is irrelevant. - client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/") + client = TestClient(mixed_hosts_app, base_url="https://port.example.org/") response = client.get("/") assert response.status_code == 200 - client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:5600/") + client = TestClient(mixed_hosts_app, base_url="https://port.example.org:5600/") response = client.get("/") assert response.status_code == 200 @@ -499,8 +496,8 @@ async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None: subdomain_router = Router(routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]) -def test_subdomain_routing(test_client_factory: TestClientFactory) -> None: - client = test_client_factory(subdomain_router, base_url="https://foo.example.org/") +def test_subdomain_routing() -> None: + client = TestClient(subdomain_router, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 @@ -535,9 +532,9 @@ async def echo_urls(request: Request) -> JSONResponse: ] -def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None: +def test_url_for_with_root_path() -> None: app = Starlette(routes=echo_url_routes) - client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path") + client = TestClient(app, base_url="https://www.example.org/", root_path="/sub_path") response = client.get("/sub_path/") assert response.json() == { "index": "https://www.example.org/sub_path/", @@ -565,31 +562,27 @@ def test_url_for_with_double_mount() -> None: assert url == "/mount/static/123" -def test_url_for_with_root_path_ending_with_slash(test_client_factory: TestClientFactory) -> None: +def test_url_for_with_root_path_ending_with_slash() -> None: def homepage(request: Request) -> JSONResponse: return JSONResponse({"index": str(request.url_for("homepage"))}) app = Starlette(routes=[Route("/", homepage, name="homepage")]) - client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path/") + client = TestClient(app, base_url="https://www.example.org/", root_path="/sub_path/") response = client.get("/sub_path/") assert response.json() == {"index": "https://www.example.org/sub_path/"} -def test_standalone_route_matches( - test_client_factory: TestClientFactory, -) -> None: +def test_standalone_route_matches() -> None: app = Route("/", PlainTextResponse("Hello, World!")) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, World!" -def test_standalone_route_does_not_match( - test_client_factory: typing.Callable[..., TestClient], -) -> None: +def test_standalone_route_does_not_match() -> None: app = Route("/", PlainTextResponse("Hello, World!")) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/invalid") assert response.status_code == 404 assert response.text == "Not Found" @@ -601,27 +594,23 @@ async def ws_helloworld(websocket: WebSocket) -> None: await websocket.close() -def test_standalone_ws_route_matches( - test_client_factory: TestClientFactory, -) -> None: +def test_standalone_ws_route_matches() -> None: app = WebSocketRoute("/", ws_helloworld) - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: text = websocket.receive_text() assert text == "Hello, world!" -def test_standalone_ws_route_does_not_match( - test_client_factory: TestClientFactory, -) -> None: +def test_standalone_ws_route_does_not_match() -> None: app = WebSocketRoute("/", ws_helloworld) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/invalid"): pass # pragma: no cover -def test_lifespan_async(test_client_factory: TestClientFactory) -> None: +def test_lifespan_async() -> None: startup_complete = False shutdown_complete = False @@ -645,7 +634,7 @@ async def run_shutdown() -> None: assert not startup_complete assert not shutdown_complete - with test_client_factory(app) as client: + with TestClient(app) as client: assert startup_complete assert not shutdown_complete client.get("/") @@ -653,7 +642,7 @@ async def run_shutdown() -> None: assert shutdown_complete -def test_lifespan_with_on_events(test_client_factory: TestClientFactory) -> None: +def test_lifespan_with_on_events() -> None: lifespan_called = False startup_called = False shutdown_called = False @@ -685,7 +674,7 @@ def run_shutdown() -> None: # pragma: no cover assert not shutdown_called # Triggers the lifespan events - with test_client_factory(app): + with TestClient(app): ... assert lifespan_called @@ -693,7 +682,7 @@ def run_shutdown() -> None: # pragma: no cover assert not shutdown_called -def test_lifespan_sync(test_client_factory: TestClientFactory) -> None: +def test_lifespan_sync() -> None: startup_complete = False shutdown_complete = False @@ -717,7 +706,7 @@ def run_shutdown() -> None: assert not startup_complete assert not shutdown_complete - with test_client_factory(app) as client: + with TestClient(app) as client: assert startup_complete assert not shutdown_complete client.get("/") @@ -725,9 +714,7 @@ def run_shutdown() -> None: assert shutdown_complete -def test_lifespan_state_unsupported( - test_client_factory: TestClientFactory, -) -> None: +def test_lifespan_state_unsupported() -> None: @contextlib.asynccontextmanager async def lifespan( app: ASGIApp, @@ -744,11 +731,11 @@ async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None: await app(scope, receive, send) with pytest.raises(RuntimeError, match='The server does not support "state" in the lifespan scope'): - with test_client_factory(no_state_wrapper): + with TestClient(no_state_wrapper): raise AssertionError("Should not be called") # pragma: no cover -def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None: +def test_lifespan_state_async_cm() -> None: startup_complete = False shutdown_complete = False @@ -786,7 +773,7 @@ async def lifespan(app: Starlette) -> typing.AsyncIterator[State]: assert not startup_complete assert not shutdown_complete - with test_client_factory(app) as client: + with TestClient(app) as client: assert startup_complete assert not shutdown_complete client.get("/") @@ -796,7 +783,7 @@ async def lifespan(app: Starlette) -> typing.AsyncIterator[State]: assert shutdown_complete -def test_raise_on_startup(test_client_factory: TestClientFactory) -> None: +def test_raise_on_startup() -> None: def run_startup() -> None: raise RuntimeError() @@ -814,12 +801,12 @@ async def _send(message: Message) -> None: await router(scope, receive, _send) with pytest.raises(RuntimeError): - with test_client_factory(app): + with TestClient(app): pass # pragma: no cover assert startup_failed -def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None: +def test_raise_on_shutdown() -> None: def run_shutdown() -> None: raise RuntimeError() @@ -827,12 +814,12 @@ def run_shutdown() -> None: app = Router(on_shutdown=[run_shutdown]) with pytest.raises(RuntimeError): - with test_client_factory(app): + with TestClient(app): pass # pragma: no cover -def test_partial_async_endpoint(test_client_factory: TestClientFactory) -> None: - test_client = test_client_factory(app) +def test_partial_async_endpoint() -> None: + test_client = TestClient(app) response = test_client.get("/partial") assert response.status_code == 200 assert response.json() == {"arg": "foo"} @@ -842,10 +829,8 @@ def test_partial_async_endpoint(test_client_factory: TestClientFactory) -> None: assert cls_method_response.json() == {"arg": "foo"} -def test_partial_async_ws_endpoint( - test_client_factory: TestClientFactory, -) -> None: - test_client = test_client_factory(app) +def test_partial_async_ws_endpoint() -> None: + test_client = TestClient(app) with test_client.websocket_connect("/partial/ws") as websocket: data = websocket.receive_json() assert data == {"url": "ws://testserver/partial/ws"} @@ -977,10 +962,9 @@ def assert_middleware_header_route(request: Request) -> Response: ], ) def test_base_route_middleware( - test_client_factory: TestClientFactory, app: Starlette, ) -> None: - test_client = test_client_factory(app) + test_client = TestClient(app) response = test_client.get("/home") assert response.status_code == 200 @@ -1004,9 +988,7 @@ def test_mount_asgi_app_with_middleware_url_path_for() -> None: mounted_app_with_middleware.url_path_for("route") -def test_add_route_to_app_after_mount( - test_client_factory: typing.Callable[..., TestClient], -) -> None: +def test_add_route_to_app_after_mount() -> None: """Checks that Mount will pick up routes added to the underlying app after it is mounted """ @@ -1017,29 +999,25 @@ def test_add_route_to_app_after_mount( endpoint=homepage, methods=["GET"], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/http/inner") assert response.status_code == 200 -def test_exception_on_mounted_apps( - test_client_factory: TestClientFactory, -) -> None: +def test_exception_on_mounted_apps() -> None: def exc(request: Request) -> None: raise Exception("Exc") sub_app = Starlette(routes=[Route("/", exc)]) app = Starlette(routes=[Mount("/sub", app=sub_app)]) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(Exception) as ctx: client.get("/sub/") assert str(ctx.value) == "Exc" -def test_mounted_middleware_does_not_catch_exception( - test_client_factory: typing.Callable[..., TestClient], -) -> None: +def test_mounted_middleware_does_not_catch_exception() -> None: # https://github.com/encode/starlette/pull/1649#discussion_r960236107 def exc(request: Request) -> Response: raise HTTPException(status_code=403, detail="auth") @@ -1073,7 +1051,7 @@ async def modified_send(msg: Message) -> None: middleware=[Middleware(NamedMiddleware, name="Outer")], ) - client = test_client_factory(app) + client = TestClient(app) resp = client.get("/home") assert resp.status_code == 200, resp.content @@ -1092,9 +1070,7 @@ async def modified_send(msg: Message) -> None: assert "X-Mounted" in resp.headers -def test_websocket_route_middleware( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_route_middleware() -> None: async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") @@ -1122,7 +1098,7 @@ async def modified_send(msg: Message) -> None: ] ) - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/ws") as websocket: text = websocket.receive_text() @@ -1263,9 +1239,9 @@ async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: ] -def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None: +def test_paths_with_root_path() -> None: app = Starlette(routes=echo_paths_routes) - client = test_client_factory(app, base_url="https://www.example.org/", root_path="/root") + client = TestClient(app, base_url="https://www.example.org/", root_path="/root") response = client.get("/root/path") assert response.status_code == 200 assert response.json() == { diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 3b321ca0b..0c9a126fe 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -4,8 +4,8 @@ from starlette.responses import Response from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.schemas import SchemaGenerator +from starlette.testclient import TestClient from starlette.websockets import WebSocket -from tests.types import TestClientFactory schemas = SchemaGenerator({"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}) @@ -247,8 +247,8 @@ def test_schema_generation() -> None: """ -def test_schema_endpoint(test_client_factory: TestClientFactory) -> None: - client = test_client_factory(app) +def test_schema_endpoint() -> None: + client = TestClient(app) response = client.get("/schema") assert response.headers["Content-Type"] == "application/vnd.oai.openapi" assert response.text.strip() == EXPECTED_SCHEMA.strip() diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index b4f131719..b20a13e8a 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -16,34 +16,34 @@ from starlette.responses import Response from starlette.routing import Mount from starlette.staticfiles import StaticFiles -from tests.types import TestClientFactory +from starlette.testclient import TestClient -def test_staticfiles(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "" -def test_staticfiles_with_pathlib(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_with_pathlib(tmp_path: Path) -> None: path = tmp_path / "example.txt" with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmp_path) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "" -def test_staticfiles_head_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_head_with_middleware(tmpdir: Path) -> None: """ see https://github.com/encode/starlette/pull/935 """ @@ -59,62 +59,62 @@ async def does_nothing_middleware(request: Request, call_next: RequestResponseEn middleware = [Middleware(BaseHTTPMiddleware, dispatch=does_nothing_middleware)] app = Starlette(routes=routes, middleware=middleware) - client = test_client_factory(app) + client = TestClient(app) response = client.head("/static/example.txt") assert response.status_code == 200 assert response.headers.get("content-length") == "100" -def test_staticfiles_with_package(test_client_factory: TestClientFactory) -> None: +def test_staticfiles_with_package() -> None: app = StaticFiles(packages=["tests"]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "123\n" app = StaticFiles(packages=[("tests", "statics")]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "123\n" -def test_staticfiles_post(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_post(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) - client = test_client_factory(app) + client = TestClient(app) response = client.post("/example.txt") assert response.status_code == 405 assert response.text == "Method Not Allowed" -def test_staticfiles_with_directory_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_with_directory_returns_404(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.status_code == 404 assert response.text == "Not Found" -def test_staticfiles_with_missing_file_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_with_missing_file_returns_404(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/404.txt") assert response.status_code == 404 @@ -128,32 +128,30 @@ def test_staticfiles_instantiated_with_missing_directory(tmpdir: Path) -> None: assert "does not exist" in str(exc_info.value) -def test_staticfiles_configured_with_missing_directory(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_configured_with_missing_directory(tmpdir: Path) -> None: path = os.path.join(tmpdir, "no_such_directory") app = StaticFiles(directory=path, check_dir=False) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") assert "does not exist" in str(exc_info.value) -def test_staticfiles_configured_with_file_instead_of_directory( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_configured_with_file_instead_of_directory(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=path, check_dir=False) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") assert "is not a directory" in str(exc_info.value) -def test_staticfiles_config_check_occurs_only_once(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_config_check_occurs_only_once(tmpdir: Path) -> None: app = StaticFiles(directory=tmpdir) - client = test_client_factory(app) + client = TestClient(app) assert not app.config_checked with pytest.raises(HTTPException): @@ -185,26 +183,26 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir: Path) -> None: assert exc_info.value.detail == "Not Found" -def test_staticfiles_never_read_file_for_head_method(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_never_read_file_for_head_method(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = test_client_factory(app) + client = TestClient(app) response = client.head("/example.txt") assert response.status_code == 200 assert response.content == b"" assert response.headers["content-length"] == "14" -def test_staticfiles_304_with_etag_match(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_304_with_etag_match(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = test_client_factory(app) + client = TestClient(app) first_resp = client.get("/example.txt") assert first_resp.status_code == 200 last_etag = first_resp.headers["etag"] @@ -216,13 +214,13 @@ def test_staticfiles_304_with_etag_match(tmpdir: Path, test_client_factory: Test assert second_resp.content == b"" -def test_staticfiles_200_with_etag_mismatch(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_200_with_etag_mismatch(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = test_client_factory(app) + client = TestClient(app) first_resp = client.get("/example.txt") assert first_resp.status_code == 200 assert first_resp.headers["etag"] != '"123"' @@ -231,9 +229,7 @@ def test_staticfiles_200_with_etag_mismatch(tmpdir: Path, test_client_factory: T assert second_resp.content == b"" -def test_staticfiles_304_with_last_modified_compare_last_req( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") file_last_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")) with open(path, "w") as file: @@ -241,7 +237,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req( os.utime(path, (file_last_modified_time, file_last_modified_time)) app = StaticFiles(directory=tmpdir) - client = test_client_factory(app) + client = TestClient(app) # last modified less than last request, 304 response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"}) assert response.status_code == 304 @@ -252,7 +248,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req( assert response.content == b"" -def test_staticfiles_html_normal(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_html_normal(tmpdir: Path) -> None: path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

Custom not found page

") @@ -263,7 +259,7 @@ def test_staticfiles_html_normal(tmpdir: Path, test_client_factory: TestClientFa file.write("

Hello

") app = StaticFiles(directory=tmpdir, html=True) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/dir/") assert response.url == "http://testserver/dir/" @@ -285,7 +281,7 @@ def test_staticfiles_html_normal(tmpdir: Path, test_client_factory: TestClientFa assert response.text == "

Custom not found page

" -def test_staticfiles_html_without_index(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_html_without_index(tmpdir: Path) -> None: path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

Custom not found page

") @@ -293,7 +289,7 @@ def test_staticfiles_html_without_index(tmpdir: Path, test_client_factory: TestC os.mkdir(path) app = StaticFiles(directory=tmpdir, html=True) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/dir/") assert response.url == "http://testserver/dir/" @@ -310,7 +306,7 @@ def test_staticfiles_html_without_index(tmpdir: Path, test_client_factory: TestC assert response.text == "

Custom not found page

" -def test_staticfiles_html_without_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_html_without_404(tmpdir: Path) -> None: path = os.path.join(tmpdir, "dir") os.mkdir(path) path = os.path.join(path, "index.html") @@ -318,7 +314,7 @@ def test_staticfiles_html_without_404(tmpdir: Path, test_client_factory: TestCli file.write("

Hello

") app = StaticFiles(directory=tmpdir, html=True) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/dir/") assert response.url == "http://testserver/dir/" @@ -335,13 +331,13 @@ def test_staticfiles_html_without_404(tmpdir: Path, test_client_factory: TestCli assert exc_info.value.status_code == 404 -def test_staticfiles_html_only_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_html_only_files(tmpdir: Path) -> None: path = os.path.join(tmpdir, "hello.html") with open(path, "w") as file: file.write("

Hello

") app = StaticFiles(directory=tmpdir, html=True) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(HTTPException) as exc_info: response = client.get("/") @@ -352,9 +348,7 @@ def test_staticfiles_html_only_files(tmpdir: Path, test_client_factory: TestClie assert response.text == "

Hello

" -def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(tmpdir: Path) -> None: path_404 = os.path.join(tmpdir, "404.html") with open(path_404, "w") as file: file.write("

404 file

") @@ -367,7 +361,7 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( os.utime(path_some, (common_modified_time, common_modified_time)) app = StaticFiles(directory=tmpdir, html=True) - client = test_client_factory(app) + client = TestClient(app) resp_exists = client.get("/some.html") assert resp_exists.status_code == 200 @@ -389,9 +383,7 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( assert resp_deleted.text == "

404 file

" -def test_staticfiles_with_invalid_dir_permissions_returns_401( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_with_invalid_dir_permissions_returns_401(tmp_path: Path) -> None: (tmp_path / "example.txt").write_bytes(b"") original_mode = tmp_path.stat().st_mode @@ -405,7 +397,7 @@ def test_staticfiles_with_invalid_dir_permissions_returns_401( ) ] app = Starlette(routes=routes) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/example.txt") assert response.status_code == 401 @@ -414,38 +406,38 @@ def test_staticfiles_with_invalid_dir_permissions_returns_401( tmp_path.chmod(original_mode) -def test_staticfiles_with_missing_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_with_missing_dir_returns_404(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/foo/example.txt") assert response.status_code == 404 assert response.text == "Not Found" -def test_staticfiles_access_file_as_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_access_file_as_dir_returns_404(tmpdir: Path) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/example.txt/foo") assert response.status_code == 404 assert response.text == "Not Found" -def test_staticfiles_filename_too_long(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_filename_too_long(tmpdir: Path) -> None: routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) - client = test_client_factory(app) + client = TestClient(app) path_max_size = os.pathconf("/", "PC_PATH_MAX") response = client.get(f"/{'a' * path_max_size}.txt") @@ -453,11 +445,7 @@ def test_staticfiles_filename_too_long(tmpdir: Path, test_client_factory: TestCl assert response.text == "Not Found" -def test_staticfiles_unhandled_os_error_returns_500( - tmpdir: Path, - test_client_factory: TestClientFactory, - monkeypatch: pytest.MonkeyPatch, -) -> None: +def test_staticfiles_unhandled_os_error_returns_500(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None: def mock_timeout(*args: typing.Any, **kwargs: typing.Any) -> None: raise TimeoutError @@ -467,7 +455,7 @@ def mock_timeout(*args: typing.Any, **kwargs: typing.Any) -> None: routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) - client = test_client_factory(app, raise_server_exceptions=False) + client = TestClient(app, raise_server_exceptions=False) monkeypatch.setattr("starlette.staticfiles.StaticFiles.lookup_path", mock_timeout) @@ -476,7 +464,7 @@ def mock_timeout(*args: typing.Any, **kwargs: typing.Any) -> None: assert response.text == "Internal Server Error" -def test_staticfiles_follows_symlinks(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_follows_symlinks(tmpdir: Path) -> None: statics_path = os.path.join(tmpdir, "statics") os.mkdir(statics_path) @@ -489,7 +477,7 @@ def test_staticfiles_follows_symlinks(tmpdir: Path, test_client_factory: TestCli os.symlink(source_file_path, statics_file_path) app = StaticFiles(directory=statics_path, follow_symlink=True) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/index.html") assert response.url == "http://testserver/index.html" @@ -497,7 +485,7 @@ def test_staticfiles_follows_symlinks(tmpdir: Path, test_client_factory: TestCli assert response.text == "

Hello

" -def test_staticfiles_follows_symlink_directories(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_follows_symlink_directories(tmpdir: Path) -> None: statics_path = os.path.join(tmpdir, "statics") statics_html_path = os.path.join(statics_path, "html") os.mkdir(statics_path) @@ -510,7 +498,7 @@ def test_staticfiles_follows_symlink_directories(tmpdir: Path, test_client_facto os.symlink(source_path, statics_html_path) app = StaticFiles(directory=statics_path, follow_symlink=True) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/html/page.html") assert response.url == "http://testserver/html/page.html" @@ -576,7 +564,7 @@ def test_staticfiles_avoids_path_traversal(tmp_path: Path) -> None: assert exc_info.value.detail == "Not Found" -def test_staticfiles_self_symlinks(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_staticfiles_self_symlinks(tmpdir: Path) -> None: statics_path = os.path.join(tmpdir, "statics") os.mkdir(statics_path) @@ -588,7 +576,7 @@ def test_staticfiles_self_symlinks(tmpdir: Path, test_client_factory: TestClient os.symlink(statics_path, statics_symlink_path) app = StaticFiles(directory=statics_symlink_path, follow_symlink=True) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/index.html") assert response.url == "http://testserver/index.html" diff --git a/tests/test_templates.py b/tests/test_templates.py index 6b2080c17..473601d50 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -15,10 +15,12 @@ from starlette.responses import Response from starlette.routing import Route from starlette.templating import Jinja2Templates -from tests.types import TestClientFactory +from starlette.testclient import TestClient -def test_templates(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_templates( + tmpdir: Path, +) -> None: path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") @@ -29,14 +31,16 @@ async def homepage(request: Request) -> Response: app = Starlette(debug=True, routes=[Route("/", endpoint=homepage)]) templates = Jinja2Templates(directory=str(tmpdir)) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "Hello, world" assert response.template.name == "index.html" # type: ignore assert set(response.context.keys()) == {"request"} # type: ignore -def test_calls_context_processors(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_calls_context_processors( + tmp_path: Path, +) -> None: path = tmp_path / "index.html" path.write_text("Hello {{ username }}") @@ -57,14 +61,16 @@ def hello_world_processor(request: Request) -> dict[str, str]: ], ) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "Hello World" assert response.template.name == "index.html" # type: ignore assert set(response.context.keys()) == {"request", "username"} # type: ignore -def test_template_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_template_with_middleware( + tmpdir: Path, +) -> None: path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") @@ -83,14 +89,16 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - ) templates = Jinja2Templates(directory=str(tmpdir)) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "Hello, world" assert response.template.name == "index.html" # type: ignore assert set(response.context.keys()) == {"request"} # type: ignore -def test_templates_with_directories(tmp_path: Path, test_client_factory: TestClientFactory) -> None: +def test_templates_with_directories( + tmp_path: Path, +) -> None: dir_a = tmp_path.resolve() / "a" dir_a.mkdir() template_a = dir_a / "template_a.html" @@ -113,7 +121,7 @@ async def page_b(request: Request) -> Response: ) templates = Jinja2Templates(directory=[dir_a, dir_b]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/a") assert response.text == " a" assert response.template.name == "template_a.html" # type: ignore @@ -145,7 +153,9 @@ def test_templates_with_directory(tmpdir: Path) -> None: assert template.render({}) == "Hello" -def test_templates_with_environment(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_templates_with_environment( + tmpdir: Path, +) -> None: path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") @@ -159,7 +169,7 @@ async def homepage(request: Request) -> Response: routes=[Route("/", endpoint=homepage)], ) templates = Jinja2Templates(env=env) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "Hello, world" assert response.template.name == "index.html" # type: ignore @@ -171,7 +181,9 @@ def test_templates_with_environment_options_emit_warning(tmpdir: Path) -> None: Jinja2Templates(str(tmpdir), autoescape=True) -def test_templates_with_kwargs_only(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_templates_with_kwargs_only( + tmpdir: Path, +) -> None: # MAINTAINERS: remove after 1.0 path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: @@ -192,7 +204,7 @@ def page(request: Request) -> Response: ) app = Starlette(routes=[Route("/", page)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "value: b" # context was rendered @@ -215,7 +227,7 @@ def test_templates_with_kwargs_only_requires_request_in_context(tmpdir: Path) -> def test_templates_with_kwargs_only_warns_when_no_request_keyword( - tmpdir: Path, test_client_factory: TestClientFactory + tmpdir: Path, ) -> None: # MAINTAINERS: remove after 1.0 @@ -229,7 +241,7 @@ def page(request: Request) -> Response: return templates.TemplateResponse(name="index.html", context={"request": request}) app = Starlette(routes=[Route("/", page)]) - client = test_client_factory(app) + client = TestClient(app) with pytest.warns( DeprecationWarning, @@ -247,7 +259,7 @@ def test_templates_with_requires_request_in_context(tmpdir: Path) -> None: def test_templates_warns_when_first_argument_isnot_request( - tmpdir: Path, test_client_factory: TestClientFactory + tmpdir: Path, ) -> None: # MAINTAINERS: remove after 1.0 path = os.path.join(tmpdir, "index.html") @@ -268,7 +280,7 @@ def page(request: Request) -> Response: ) app = Starlette(routes=[Route("/", page)]) - client = test_client_factory(app) + client = TestClient(app) with pytest.warns(DeprecationWarning): response = client.get("/") @@ -279,7 +291,9 @@ def page(request: Request) -> Response: spy.assert_called() -def test_templates_when_first_argument_is_request(tmpdir: Path, test_client_factory: TestClientFactory) -> None: +def test_templates_when_first_argument_is_request( + tmpdir: Path, +) -> None: # MAINTAINERS: remove after 1.0 path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: @@ -300,7 +314,7 @@ def page(request: Request) -> Response: ) app = Starlette(routes=[Route("/", page)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "value: b" # context was rendered diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 478dbca46..83f4cbaf5 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -2,16 +2,13 @@ import itertools import sys -from asyncio import Task, current_task as asyncio_current_task +from asyncio import current_task from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import Any import anyio import anyio.lowlevel import pytest -import sniffio -import trio.lowlevel from starlette.applications import Starlette from starlette.middleware import Middleware @@ -21,7 +18,6 @@ from starlette.testclient import ASGIInstance, TestClient from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect -from tests.types import TestClientFactory def mock_service_endpoint(request: Request) -> JSONResponse: @@ -31,26 +27,11 @@ def mock_service_endpoint(request: Request) -> JSONResponse: mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)]) -def current_task() -> Task[Any] | trio.lowlevel.Task: - # anyio's TaskInfo comparisons are invalid after their associated native - # task object is GC'd https://github.com/agronholm/anyio/issues/324 - asynclib_name = sniffio.current_async_library() - if asynclib_name == "trio": - return trio.lowlevel.current_task() - - if asynclib_name == "asyncio": - task = asyncio_current_task() - if task is None: - raise RuntimeError("must be called from a running task") # pragma: no cover - return task - raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover - - def startup() -> None: raise RuntimeError() -def test_use_testclient_in_endpoint(test_client_factory: TestClientFactory) -> None: +def test_use_testclient_in_endpoint() -> None: """ We should be able to use the test client within applications. @@ -59,13 +40,13 @@ def test_use_testclient_in_endpoint(test_client_factory: TestClientFactory) -> N """ def homepage(request: Request) -> JSONResponse: - client = test_client_factory(mock_service) + client = TestClient(mock_service) response = client.get("/") return JSONResponse(response.json()) app = Starlette(routes=[Route("/", endpoint=homepage)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.json() == {"mock": "example"} @@ -89,7 +70,7 @@ def test_testclient_headers_behavior() -> None: assert client.headers.get("Authentication") == "Bearer 123" -def test_use_testclient_as_contextmanager(test_client_factory: TestClientFactory, anyio_backend_name: str) -> None: +def test_use_testclient_as_contextmanager(anyio_backend_name: str) -> None: """ This test asserts a number of properties that are important for an app level task_group @@ -129,7 +110,7 @@ async def loop_id(request: Request) -> JSONResponse: routes=[Route("/loop_id", endpoint=loop_id)], ) - client = test_client_factory(app) + client = TestClient(app) with client: # within a TestClient context every async request runs in the same thread @@ -167,16 +148,16 @@ async def loop_id(request: Request) -> JSONResponse: assert first_task is not startup_task -def test_error_on_startup(test_client_factory: TestClientFactory) -> None: +def test_error_on_startup() -> None: with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): startup_error_app = Starlette(on_startup=[startup]) with pytest.raises(RuntimeError): - with test_client_factory(startup_error_app): + with TestClient(startup_error_app): pass # pragma: no cover -def test_exception_in_middleware(test_client_factory: TestClientFactory) -> None: +def test_exception_in_middleware() -> None: class MiddlewareException(Exception): pass @@ -190,11 +171,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)]) with pytest.raises(MiddlewareException): - with test_client_factory(broken_middleware): + with TestClient(broken_middleware): pass # pragma: no cover -def test_testclient_asgi2(test_client_factory: TestClientFactory) -> None: +def test_testclient_asgi2() -> None: def app(scope: Scope) -> ASGIInstance: async def inner(receive: Receive, send: Send) -> None: await send( @@ -208,12 +189,12 @@ async def inner(receive: Receive, send: Send) -> None: return inner - client = test_client_factory(app) # type: ignore + client = TestClient(app) # type: ignore response = client.get("/") assert response.text == "Hello, world!" -def test_testclient_asgi3(test_client_factory: TestClientFactory) -> None: +def test_testclient_asgi3() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: await send( { @@ -224,12 +205,12 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ) await send({"type": "http.response.body", "body": b"Hello, world!"}) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.text == "Hello, world!" -def test_websocket_blocking_receive(test_client_factory: TestClientFactory) -> None: +def test_websocket_blocking_receive() -> None: def app(scope: Scope) -> ASGIInstance: async def respond(websocket: WebSocket) -> None: await websocket.send_json({"message": "test"}) @@ -248,13 +229,13 @@ async def asgi(receive: Receive, send: Send) -> None: return asgi - client = test_client_factory(app) # type: ignore + client = TestClient(app) # type: ignore with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} -def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None: +def test_websocket_not_block_on_close() -> None: cancelled = False def app(scope: Scope) -> ASGIInstance: @@ -270,13 +251,13 @@ async def asgi(receive: Receive, send: Send) -> None: return asgi - client = test_client_factory(app) # type: ignore + client = TestClient(app) # type: ignore with client.websocket_connect("/"): ... assert cancelled -def test_client(test_client_factory: TestClientFactory) -> None: +def test_client() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: client = scope.get("client") assert client is not None @@ -284,12 +265,12 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"host": host, "port": port}) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") assert response.json() == {"host": "testclient", "port": 50000} -def test_client_custom_client(test_client_factory: TestClientFactory) -> None: +def test_client_custom_client() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: client = scope.get("client") assert client is not None @@ -297,18 +278,18 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"host": host, "port": port}) await response(scope, receive, send) - client = test_client_factory(app, client=("192.168.0.1", 3000)) + client = TestClient(app, client=("192.168.0.1", 3000)) response = client.get("/") assert response.json() == {"host": "192.168.0.1", "port": 3000} @pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà")) -def test_query_params(test_client_factory: TestClientFactory, param: str) -> None: +def test_query_params(param: str) -> None: def homepage(request: Request) -> Response: return Response(request.query_params["param"]) app = Starlette(routes=[Route("/", endpoint=homepage)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", params={"param": param}) assert response.text == param @@ -331,7 +312,7 @@ def homepage(request: Request) -> Response: ("example.com", False), ], ) -def test_domain_restricted_cookies(test_client_factory: TestClientFactory, domain: str, ok: bool) -> None: +def test_domain_restricted_cookies(domain: str, ok: bool) -> None: """ Test that test client discards domain restricted cookies which do not match the base_url of the testclient (`http://testserver` by default). @@ -351,13 +332,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/") cookie_set = len(response.cookies) == 1 assert cookie_set == ok -def test_forward_follow_redirects(test_client_factory: TestClientFactory) -> None: +def test_forward_follow_redirects() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if "/ok" in scope["path"]: response = Response("ok") @@ -365,52 +346,52 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = RedirectResponse("/ok") await response(scope, receive, send) - client = test_client_factory(app, follow_redirects=True) + client = TestClient(app, follow_redirects=True) response = client.get("/") assert response.status_code == 200 -def test_forward_nofollow_redirects(test_client_factory: TestClientFactory) -> None: +def test_forward_nofollow_redirects() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = RedirectResponse("/ok") await response(scope, receive, send) - client = test_client_factory(app, follow_redirects=False) + client = TestClient(app, follow_redirects=False) response = client.get("/") assert response.status_code == 307 -def test_with_duplicate_headers(test_client_factory: TestClientFactory) -> None: +def test_with_duplicate_headers() -> None: def homepage(request: Request) -> JSONResponse: return JSONResponse({"x-token": request.headers.getlist("x-token")}) app = Starlette(routes=[Route("/", endpoint=homepage)]) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/", headers=[("x-token", "foo"), ("x-token", "bar")]) assert response.json() == {"x-token": ["foo", "bar"]} -def test_merge_url(test_client_factory: TestClientFactory) -> None: +def test_merge_url() -> None: def homepage(request: Request) -> Response: return Response(request.url.path) app = Starlette(routes=[Route("/api/v1/bar", endpoint=homepage)]) - client = test_client_factory(app, base_url="http://testserver/api/v1/") + client = TestClient(app, base_url="http://testserver/api/v1/") response = client.get("/bar") assert response.text == "/api/v1/bar" -def test_raw_path_with_querystring(test_client_factory: TestClientFactory) -> None: +def test_raw_path_with_querystring() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(scope.get("raw_path")) await response(scope, receive, send) - client = test_client_factory(app) + client = TestClient(app) response = client.get("/hello-world", params={"foo": "bar"}) assert response.content == b"/hello-world" -def test_websocket_raw_path_without_params(test_client_factory: TestClientFactory) -> None: +def test_websocket_raw_path_without_params() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -418,7 +399,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert raw_path is not None await websocket.send_bytes(raw_path) - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/hello-world", params={"foo": "bar"}) as websocket: data = websocket.receive_bytes() assert data == b"/hello-world" diff --git a/tests/test_websockets.py b/tests/test_websockets.py index e76d8f29b..18cb71707 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -8,26 +8,25 @@ from starlette import status from starlette.responses import Response -from starlette.testclient import WebSocketDenialResponse +from starlette.testclient import TestClient, WebSocketDenialResponse from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState -from tests.types import TestClientFactory -def test_websocket_url(test_client_factory: TestClientFactory) -> None: +def test_websocket_url() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.send_json({"url": str(websocket.url)}) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"url": "ws://testserver/123?a=abc"} -def test_websocket_binary_json(test_client_factory: TestClientFactory) -> None: +def test_websocket_binary_json() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -35,16 +34,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.send_json(message, mode="binary") await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/123?a=abc") as websocket: websocket.send_json({"test": "data"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"test": "data"} -def test_websocket_ensure_unicode_on_send_json( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_ensure_unicode_on_send_json() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) @@ -53,14 +50,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.send_json(message, mode="text") await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/123?a=abc") as websocket: websocket.send_json({"test": "数据"}, mode="text") data = websocket.receive_text() assert data == '{"test":"数据"}' -def test_websocket_query_params(test_client_factory: TestClientFactory) -> None: +def test_websocket_query_params() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) query_params = dict(websocket.query_params) @@ -68,7 +65,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.send_json({"params": query_params}) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/?a=abc&b=456") as websocket: data = websocket.receive_json() assert data == {"params": {"a": "abc", "b": "456"}} @@ -78,7 +75,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: any(module in sys.modules for module in ("brotli", "brotlicffi")), reason='urllib3 includes "br" to the "accept-encoding" headers.', ) -def test_websocket_headers(test_client_factory: TestClientFactory) -> None: +def test_websocket_headers() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) headers = dict(websocket.headers) @@ -86,7 +83,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.send_json({"headers": headers}) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: expected_headers = { "accept": "*/*", @@ -101,22 +98,20 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert data == {"headers": expected_headers} -def test_websocket_port(test_client_factory: TestClientFactory) -> None: +def test_websocket_port() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.send_json({"port": websocket.url.port}) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("ws://example.com:123/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"port": 123} -def test_websocket_send_and_receive_text( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_send_and_receive_text() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -124,16 +119,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.send_text("Message was: " + data) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: websocket.send_text("Hello, world!") data = websocket.receive_text() assert data == "Message was: Hello, world!" -def test_websocket_send_and_receive_bytes( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_send_and_receive_bytes() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -141,16 +134,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.send_bytes(b"Message was: " + data) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: websocket.send_bytes(b"Hello, world!") data = websocket.receive_bytes() assert data == b"Message was: Hello, world!" -def test_websocket_send_and_receive_json( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_send_and_receive_json() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -158,56 +149,56 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.send_json({"message": data}) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"message": {"hello": "world"}} -def test_websocket_iter_text(test_client_factory: TestClientFactory) -> None: +def test_websocket_iter_text() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() async for data in websocket.iter_text(): await websocket.send_text("Message was: " + data) - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: websocket.send_text("Hello, world!") data = websocket.receive_text() assert data == "Message was: Hello, world!" -def test_websocket_iter_bytes(test_client_factory: TestClientFactory) -> None: +def test_websocket_iter_bytes() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() async for data in websocket.iter_bytes(): await websocket.send_bytes(b"Message was: " + data) - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: websocket.send_bytes(b"Hello, world!") data = websocket.receive_bytes() assert data == b"Message was: Hello, world!" -def test_websocket_iter_json(test_client_factory: TestClientFactory) -> None: +def test_websocket_iter_json() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() async for data in websocket.iter_json(): await websocket.send_json({"message": data}) - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"message": {"hello": "world"}} -def test_websocket_concurrency_pattern(test_client_factory: TestClientFactory) -> None: +def test_websocket_concurrency_pattern() -> None: stream_send: ObjectSendStream[MutableMapping[str, Any]] stream_receive: ObjectReceiveStream[MutableMapping[str, Any]] stream_send, stream_receive = anyio.create_memory_object_stream() @@ -230,14 +221,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await writer(websocket) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"hello": "world"} -def test_client_close(test_client_factory: TestClientFactory) -> None: +def test_client_close() -> None: close_code = None close_reason = None @@ -251,7 +242,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: close_code = exc.code close_reason = exc.reason - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away") assert close_code == status.WS_1001_GOING_AWAY @@ -279,34 +270,34 @@ async def send(message: Message) -> None: assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE -def test_application_close(test_client_factory: TestClientFactory) -> None: +def test_application_close() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.close(status.WS_1001_GOING_AWAY) - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: with pytest.raises(WebSocketDisconnect) as exc: websocket.receive_text() assert exc.value.code == status.WS_1001_GOING_AWAY -def test_rejected_connection(test_client_factory: TestClientFactory) -> None: +def test_rejected_connection() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) msg = await websocket.receive() assert msg == {"type": "websocket.connect"} await websocket.close(status.WS_1001_GOING_AWAY) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(WebSocketDisconnect) as exc: with client.websocket_connect("/"): pass # pragma: no cover assert exc.value.code == status.WS_1001_GOING_AWAY -def test_send_denial_response(test_client_factory: TestClientFactory) -> None: +def test_send_denial_response() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) msg = await websocket.receive() @@ -314,7 +305,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(status_code=404, content="foo") await websocket.send_denial_response(response) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(WebSocketDenialResponse) as exc: with client.websocket_connect("/"): pass # pragma: no cover @@ -322,7 +313,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert exc.value.content == b"foo" -def test_send_response_multi(test_client_factory: TestClientFactory) -> None: +def test_send_response_multi() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) msg = await websocket.receive() @@ -337,7 +328,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.send({"type": "websocket.http.response.body", "body": b"hard", "more_body": True}) await websocket.send({"type": "websocket.http.response.body", "body": b"body"}) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(WebSocketDenialResponse) as exc: with client.websocket_connect("/"): pass # pragma: no cover @@ -346,7 +337,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert exc.value.headers["foo"] == "bar" -def test_send_response_unsupported(test_client_factory: TestClientFactory) -> None: +def test_send_response_unsupported() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: del scope["extensions"]["websocket.http.response"] websocket = WebSocket(scope, receive=receive, send=send) @@ -360,14 +351,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.send_denial_response(response) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(WebSocketDisconnect) as exc: with client.websocket_connect("/"): pass # pragma: no cover assert exc.value.code == status.WS_1000_NORMAL_CLOSURE -def test_send_response_duplicate_start(test_client_factory: TestClientFactory) -> None: +def test_send_response_duplicate_start() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) msg = await websocket.receive() @@ -388,7 +379,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: } ) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises( RuntimeError, match=("Expected ASGI message \"websocket.http.response.body\", but got 'websocket.http.response.start'"), @@ -397,64 +388,64 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: pass # pragma: no cover -def test_subprotocol(test_client_factory: TestClientFactory) -> None: +def test_subprotocol() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) assert websocket["subprotocols"] == ["soap", "wamp"] await websocket.accept(subprotocol="wamp") await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket: assert websocket.accepted_subprotocol == "wamp" -def test_additional_headers(test_client_factory: TestClientFactory) -> None: +def test_additional_headers() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept(headers=[(b"additional", b"header")]) await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: assert websocket.extra_headers == [(b"additional", b"header")] -def test_no_additional_headers(test_client_factory: TestClientFactory) -> None: +def test_no_additional_headers() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: assert websocket.extra_headers == [] -def test_websocket_exception(test_client_factory: TestClientFactory) -> None: +def test_websocket_exception() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: assert False - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(AssertionError): with client.websocket_connect("/123?a=abc"): pass # pragma: no cover -def test_duplicate_close(test_client_factory: TestClientFactory) -> None: +def test_duplicate_close() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.close() await websocket.close() - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: no cover -def test_duplicate_disconnect(test_client_factory: TestClientFactory) -> None: +def test_duplicate_disconnect() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -462,7 +453,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert message["type"] == "websocket.disconnect" message = await websocket.receive() - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/") as websocket: websocket.close() @@ -495,13 +486,13 @@ async def mock_send(message: Message) -> None: ... # pragma: no cover assert {websocket} == {websocket} -def test_websocket_close_reason(test_client_factory: TestClientFactory) -> None: +def test_websocket_close_reason() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away") - client = test_client_factory(app) + client = TestClient(app) with client.websocket_connect("/") as websocket: with pytest.raises(WebSocketDisconnect) as exc: websocket.receive_text() @@ -509,106 +500,106 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert exc.value.reason == "Going Away" -def test_send_json_invalid_mode(test_client_factory: TestClientFactory) -> None: +def test_send_json_invalid_mode() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.send_json({}, mode="invalid") - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: no cover -def test_receive_json_invalid_mode(test_client_factory: TestClientFactory) -> None: +def test_receive_json_invalid_mode() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.receive_json(mode="invalid") - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: no cover -def test_receive_text_before_accept(test_client_factory: TestClientFactory) -> None: +def test_receive_text_before_accept() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_text() - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: no cover -def test_receive_bytes_before_accept(test_client_factory: TestClientFactory) -> None: +def test_receive_bytes_before_accept() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_bytes() - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: no cover -def test_receive_json_before_accept(test_client_factory: TestClientFactory) -> None: +def test_receive_json_before_accept() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_json() - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: no cover -def test_send_before_accept(test_client_factory: TestClientFactory) -> None: +def test_send_before_accept() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.send({"type": "websocket.send"}) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: no cover -def test_send_wrong_message_type(test_client_factory: TestClientFactory) -> None: +def test_send_wrong_message_type() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.send({"type": "websocket.accept"}) await websocket.send({"type": "websocket.accept"}) - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: no cover -def test_receive_before_accept(test_client_factory: TestClientFactory) -> None: +def test_receive_before_accept() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() websocket.client_state = WebSocketState.CONNECTING await websocket.receive() - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/") as websocket: websocket.send({"type": "websocket.send"}) -def test_receive_wrong_message_type(test_client_factory: TestClientFactory) -> None: +def test_receive_wrong_message_type() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.receive() - client = test_client_factory(app) + client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/") as websocket: websocket.send({"type": "websocket.connect"}) diff --git a/tests/types.py b/tests/types.py deleted file mode 100644 index e4769d308..000000000 --- a/tests/types.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Protocol - -import httpx - -from starlette.testclient import TestClient -from starlette.types import ASGIApp - -if TYPE_CHECKING: - - class TestClientFactory(Protocol): # pragma: no cover - def __call__( - self, - app: ASGIApp, - base_url: str = "http://testserver", - raise_server_exceptions: bool = True, - root_path: str = "", - cookies: httpx._types.CookieTypes | None = None, - headers: dict[str, str] | None = None, - follow_redirects: bool = True, - client: tuple[str, int] = ("testclient", 50000), - ) -> TestClient: ... -else: # pragma: no cover - - class TestClientFactory: - __test__ = False