Skip to content

Commit

Permalink
Restore zero copy writes on Python 3.12.9+/3.13.2+ (#10137)
Browse files Browse the repository at this point in the history
Co-authored-by: 🇺🇦 Sviatoslav Sydorenko (Святослав Сидоренко) <[email protected]>
(cherry picked from commit 25c7f23)
  • Loading branch information
bdraco committed Feb 5, 2025
1 parent 6709f53 commit 0bb7b5b
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,11 @@ jobs:
uses: actions/checkout@v4
with:
submodules: true
- name: Setup Python 3.13
- name: Setup Python 3.13.2
id: python-install
uses: actions/setup-python@v5
with:
python-version: 3.13
python-version: 3.13.2
cache: pip
cache-dependency-path: requirements/*.txt
- name: Update pip, wheel, setuptools, build, twine
Expand Down
3 changes: 3 additions & 0 deletions CHANGES/10137.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Restored support for zero copy writes when using Python 3.12 versions 3.12.9 and later or Python 3.13.2+ -- by :user:`bdraco`.

Zero copy writes were previously disabled due to :cve:`2024-12254` which is resolved in these Python versions.
17 changes: 16 additions & 1 deletion aiohttp/http_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Http related parsers and protocol."""

import asyncio
import sys
import zlib
from typing import ( # noqa
Any,
Expand All @@ -24,6 +25,17 @@
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")


MIN_PAYLOAD_FOR_WRITELINES = 2048
IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
# writelines is not safe for use
# on Python 3.12+ until 3.12.9
# on Python 3.13+ until 3.13.2
# and on older versions it not any faster than write
# CVE-2024-12254: https://github.com/python/cpython/pull/127656


class HttpVersion(NamedTuple):
major: int
minor: int
Expand Down Expand Up @@ -90,7 +102,10 @@ def _writelines(self, chunks: Iterable[bytes]) -> None:
transport = self._protocol.transport
if transport is None or transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(b"".join(chunks))
if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
transport.write(b"".join(chunks))
else:
transport.writelines(chunks)

async def write(
self,
Expand Down
111 changes: 109 additions & 2 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import array
import asyncio
import zlib
from typing import Iterable
from typing import Generator, Iterable
from unittest import mock

import pytest
Expand All @@ -14,7 +14,19 @@


@pytest.fixture
def buf():
def enable_writelines() -> Generator[None, None, None]:
with mock.patch("aiohttp.http_writer.SKIP_WRITELINES", False):
yield


@pytest.fixture
def force_writelines_small_payloads() -> Generator[None, None, None]:
with mock.patch("aiohttp.http_writer.MIN_PAYLOAD_FOR_WRITELINES", 1):
yield


@pytest.fixture
def buf() -> bytearray:
return bytearray()


Expand Down Expand Up @@ -117,6 +129,33 @@ async def test_write_large_payload_deflate_compression_data_in_eof(
assert zlib.decompress(content) == (b"data" * 4096) + payload


@pytest.mark.usefixtures("enable_writelines")
async def test_write_large_payload_deflate_compression_data_in_eof_writelines(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")

await msg.write(b"data" * 4096)
assert transport.write.called # type: ignore[attr-defined]
chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined]
transport.write.reset_mock() # type: ignore[attr-defined]
assert not transport.writelines.called # type: ignore[attr-defined]

# This payload compresses to 20447 bytes
payload = b"".join(
[bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)]
)
await msg.write_eof(payload)
assert not transport.write.called # type: ignore[attr-defined]
assert transport.writelines.called # type: ignore[attr-defined]
chunks.extend(transport.writelines.mock_calls[0][1][0]) # type: ignore[attr-defined]
content = b"".join(chunks)
assert zlib.decompress(content) == (b"data" * 4096) + payload


async def test_write_payload_chunked_filter(
protocol: BaseProtocol,
transport: asyncio.Transport,
Expand Down Expand Up @@ -185,6 +224,26 @@ async def test_write_payload_deflate_compression_chunked(
assert content == expected


@pytest.mark.usefixtures("enable_writelines")
@pytest.mark.usefixtures("force_writelines_small_payloads")
async def test_write_payload_deflate_compression_chunked_writelines(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n"
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")
msg.enable_chunking()
await msg.write(b"data")
await msg.write_eof()

chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined]
assert all(chunks)
content = b"".join(chunks)
assert content == expected


async def test_write_payload_deflate_and_chunked(
buf: bytearray,
protocol: BaseProtocol,
Expand Down Expand Up @@ -221,6 +280,26 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof(
assert content == expected


@pytest.mark.usefixtures("enable_writelines")
@pytest.mark.usefixtures("force_writelines_small_payloads")
async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n"
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")
msg.enable_chunking()
await msg.write(b"data")
await msg.write_eof(b"end")

chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined]
assert all(chunks)
content = b"".join(chunks)
assert content == expected


async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
protocol: BaseProtocol,
transport: asyncio.Transport,
Expand All @@ -247,6 +326,34 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
assert zlib.decompress(content) == (b"data" * 4096) + payload


@pytest.mark.usefixtures("enable_writelines")
@pytest.mark.usefixtures("force_writelines_small_payloads")
async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")
msg.enable_chunking()

await msg.write(b"data" * 4096)
# This payload compresses to 1111 bytes
payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)])
await msg.write_eof(payload)
assert not transport.write.called # type: ignore[attr-defined]

chunks = []
for write_lines_call in transport.writelines.mock_calls: # type: ignore[attr-defined]
chunked_payload = list(write_lines_call[1][0])[1:]
chunked_payload.pop()
chunks.extend(chunked_payload)

assert all(chunks)
content = b"".join(chunks)
assert zlib.decompress(content) == (b"data" * 4096) + payload


async def test_write_payload_deflate_compression_chunked_connection_lost(
protocol: BaseProtocol,
transport: asyncio.Transport,
Expand Down

0 comments on commit 0bb7b5b

Please sign in to comment.