diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 393deb4..68562d3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,10 +15,11 @@ jobs: fail-fast: false matrix: python: - - '3.7' - '3.8' - '3.9' - '3.10' + - '3.11' + - '3.12' check_formatting: ['0'] check_typing: ['0'] runtime_only: ['0'] diff --git a/README.rst b/README.rst index 0dac61a..f3b7ec7 100644 --- a/README.rst +++ b/README.rst @@ -70,7 +70,7 @@ Install trio-typing with mypy extras:: pip install trio-typing[mypy] -Note that due to recent plugin API changes, trio-typing 0.7.0+ requires mypy 0.920+. +Note that due to recent plugin API changes, trio-typing 0.10.0+ requires mypy 1.0+. Enable the plugin in your ``mypy.ini``:: @@ -129,10 +129,6 @@ The ``trio_typing`` package provides: The ``trio_typing.plugin`` mypy plugin provides: -* Argument type checking for functions decorated with - ``@asynccontextmanager`` (either the one in ``async_generator`` or the - one in 3.7+ ``contextlib``) and ``@async_generator`` - * Inference of more specific ``trio.open_file()`` and ``trio.Path.open()`` return types based on constant ``mode`` and ``buffering`` arguments, so ``await trio.open_file("foo", "rb", 0)`` returns an unbuffered async diff --git a/allowlist.txt b/allowlist.txt index 4b7602e..cac67b9 100644 --- a/allowlist.txt +++ b/allowlist.txt @@ -173,6 +173,8 @@ trio.Process.__aenter__ .*_AttrsAttributes__ .*__attrs_own_setattr__ .*__attrs_post_init__ +.*_AT +.*__slots__ # Probably invalid __match_args__ trio.MemoryReceiveChannel.__match_args__ diff --git a/async_generator-stubs/__init__.pyi b/async_generator-stubs/__init__.pyi index 80ed00e..f7d1566 100644 --- a/async_generator-stubs/__init__.pyi +++ b/async_generator-stubs/__init__.pyi @@ -14,16 +14,16 @@ from typing import ( overload, ) from trio_typing import AsyncGenerator, CompatAsyncGenerator, YieldType, SendType -from typing_extensions import Protocol +from typing_extensions import Protocol, ParamSpec _T = TypeVar("_T") +_P = ParamSpec("_P") -# The returned async generator's YieldType and SendType and the -# argument types of the decorated function get inferred by +# The returned async generator's YieldType and SendType get inferred by # trio_typing.plugin def async_generator( - __fn: Callable[..., Awaitable[_T]] -) -> Callable[..., CompatAsyncGenerator[Any, Any, _T]]: ... + __fn: Callable[_P, Awaitable[_T]] +) -> Callable[_P, CompatAsyncGenerator[Any, Any, _T]]: ... # The return type and a more specific argument type can be # inferred by trio_typing.plugin, based on the enclosing @@ -40,12 +40,9 @@ async def yield_from_(agen: AsyncGenerator[Any, Any]) -> None: ... async def yield_from_(agen: AsyncIterable[Any]) -> None: ... def isasyncgen(obj: object) -> bool: ... def isasyncgenfunction(obj: object) -> bool: ... - -# Argument types of the decorated function get inferred by -# trio_typing.plugin def asynccontextmanager( - fn: Callable[..., AsyncIterator[_T]] -) -> Callable[..., AsyncContextManager[_T]]: ... + fn: Callable[_P, AsyncIterator[_T]] +) -> Callable[_P, AsyncContextManager[_T]]: ... class _AsyncCloseable(Protocol): def aclose(self) -> Awaitable[None]: ... diff --git a/ci.sh b/ci.sh index c055193..62a882d 100755 --- a/ci.sh +++ b/ci.sh @@ -3,7 +3,7 @@ set -ex -o pipefail BLACK_VERSION=22.3 -MYPY_VERSION=1.4 +MYPY_VERSION=1.7 pip install -U pip setuptools wheel diff --git a/outcome-stubs/__init__.pyi b/outcome-stubs/__init__.pyi index 54049ef..5642023 100644 --- a/outcome-stubs/__init__.pyi +++ b/outcome-stubs/__init__.pyi @@ -11,12 +11,13 @@ from typing import ( Union, ) from types import TracebackType -from typing_extensions import Protocol +from typing_extensions import Protocol, ParamSpec T = TypeVar("T") U = TypeVar("U") T_co = TypeVar("T_co", covariant=True) T_contra = TypeVar("T_contra", contravariant=True) +P = ParamSpec("P") # Can't use AsyncGenerator as it creates a dependency cycle # (outcome stubs -> trio_typing stubs -> trio.hazmat stubs -> outcome) @@ -47,9 +48,9 @@ class Error: Outcome = Union[Value[T], Error] -# TODO: narrower typing for these (the args and kwargs should -# be acceptable to the callable) -def capture(sync_fn: Callable[..., T], *args: Any, **kwargs: Any) -> Outcome[T]: ... +def capture( + sync_fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs +) -> Outcome[T]: ... async def acapture( - async_fn: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any + async_fn: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs ) -> Outcome[T]: ... diff --git a/setup.py b/setup.py index d636075..70a6c9d 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ ], extras_require={ "mypy": [ # can't be installed on PyPy due to its dependency on typed-ast - "mypy >= 0.920", + "mypy >= 1.0", ], }, keywords=["async", "trio", "mypy"], @@ -42,8 +42,11 @@ "Operating System :: POSIX :: BSD", "Operating System :: Microsoft :: Windows", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Development Status :: 3 - Alpha", diff --git a/trio-stubs/__init__.pyi b/trio-stubs/__init__.pyi index ecc04e9..b2c66ab 100644 --- a/trio-stubs/__init__.pyi +++ b/trio-stubs/__init__.pyi @@ -30,7 +30,7 @@ from types import TracebackType from _typeshed import StrOrBytesPath from _typeshed import OpenBinaryMode, OpenTextMode, ReadableBuffer, WriteableBuffer from trio_typing import TaskStatus, takes_callable_and_args -from typing_extensions import Protocol, Literal +from typing_extensions import Protocol, Literal, Buffer from mypy_extensions import NamedArg, VarArg import signal import io @@ -48,9 +48,6 @@ _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) _T_contra = TypeVar("_T_contra", contravariant=True) -class _Statistics: - def __getattr__(self, name: str) -> Any: ... - # Inheriting from this (even outside of stubs) produces a class that # mypy thinks is abstract, but the interpreter thinks is concrete. class _NotConstructible(Protocol): @@ -208,13 +205,24 @@ class TooSlowError(Exception): pass # _sync +@attr.s(frozen=True, slots=True) +class EventStatistics: + tasks_waiting: int = attr.ib() + @final @attr.s(eq=False, repr=False, slots=True) class Event(metaclass=ABCMeta): def is_set(self) -> bool: ... def set(self) -> None: ... async def wait(self) -> None: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> EventStatistics: ... + +@attr.s(frozen=True, slots=True) +class CapacityLimiterStatistics: + borrowed_tokens: int = attr.ib() + total_tokens: int | float = attr.ib() + borrowers: list[trio.lowlevel.Task | object] = attr.ib() + tasks_waiting: int = attr.ib() @final class CapacityLimiter(metaclass=ABCMeta): @@ -232,9 +240,14 @@ class CapacityLimiter(metaclass=ABCMeta): async def acquire_on_behalf_of(self, borrower: object) -> None: ... def release(self) -> None: ... def release_on_behalf_of(self, borrower: object) -> None: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> CapacityLimiterStatistics: ... async def __aenter__(self) -> None: ... - async def __aexit__(self, *exc: object) -> None: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: ... @final class Semaphore(metaclass=ABCMeta): @@ -246,9 +259,20 @@ class Semaphore(metaclass=ABCMeta): def acquire_nowait(self) -> None: ... async def acquire(self) -> None: ... def release(self) -> None: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> lowlevel.ParkingLotStatistics: ... async def __aenter__(self) -> None: ... - async def __aexit__(self, *exc: object) -> None: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: ... + +@attr.s(frozen=True, slots=True) +class LockStatistics: + locked: bool = attr.ib() + owner: trio.lowlevel.Task | None = attr.ib() + tasks_waiting: int = attr.ib() @final class Lock(metaclass=ABCMeta): @@ -256,9 +280,14 @@ class Lock(metaclass=ABCMeta): def acquire_nowait(self) -> None: ... async def acquire(self) -> None: ... def release(self) -> None: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> LockStatistics: ... async def __aenter__(self) -> None: ... - async def __aexit__(self, *exc: object) -> None: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: ... @final class StrictFIFOLock(metaclass=ABCMeta): @@ -266,9 +295,19 @@ class StrictFIFOLock(metaclass=ABCMeta): def acquire_nowait(self) -> None: ... async def acquire(self) -> None: ... def release(self) -> None: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> LockStatistics: ... async def __aenter__(self) -> None: ... - async def __aexit__(self, *exc: object) -> None: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: ... + +@attr.s(frozen=True, slots=True) +class ConditionStatistics: + tasks_waiting: int = attr.ib() + lock_statistics: LockStatistics = attr.ib() @final class Condition(metaclass=ABCMeta): @@ -280,9 +319,14 @@ class Condition(metaclass=ABCMeta): async def wait(self) -> None: ... def notify(self, n: int = 1) -> None: ... def notify_all(self) -> None: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> ConditionStatistics: ... async def __aenter__(self) -> None: ... - async def __aexit__(self, *exc: object) -> None: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: ... # _highlevel_generic async def aclose_forcefully(resource: trio.abc.AsyncResource) -> None: ... @@ -298,6 +342,15 @@ class StapledStream(trio.abc.HalfCloseableStream): async def send_eof(self) -> None: ... # _channel +@attr.s(frozen=True, slots=True) +class _MemoryChannelStats: + current_buffer_used: int = attr.ib() + max_buffer_size: int | float = attr.ib() + open_send_channels: int = attr.ib() + open_receive_channels: int = attr.ib() + tasks_waiting_send: int = attr.ib() + tasks_waiting_receive: int = attr.ib() + @final @attr.s(eq=False, repr=False) class MemorySendChannel(trio.abc.SendChannel[_T_contra]): @@ -305,7 +358,7 @@ class MemorySendChannel(trio.abc.SendChannel[_T_contra]): async def send(self, value: _T_contra) -> None: ... def clone(self: _T) -> _T: ... async def aclose(self) -> None: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> _MemoryChannelStats: ... def close(self) -> None: ... def __enter__(self) -> MemorySendChannel[_T_contra]: ... def __exit__( @@ -322,7 +375,7 @@ class MemoryReceiveChannel(trio.abc.ReceiveChannel[_T_co]): async def receive(self) -> _T_co: ... def clone(self: _T) -> _T: ... async def aclose(self) -> None: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> _MemoryChannelStats: ... def close(self) -> None: ... def __enter__(self) -> MemoryReceiveChannel[_T_co]: ... def __exit__( @@ -349,7 +402,12 @@ def open_signal_receiver( class SocketStream(trio.abc.HalfCloseableStream): socket: trio.socket.SocketType def __init__(self, socket: trio.socket.SocketType) -> None: ... - def setsockopt(self, level: int, option: int, value: Union[int, bytes]) -> None: ... + @overload + def setsockopt( + self, level: int, option: int, value: int | Buffer, length: None = None + ) -> None: ... + @overload + def setsockopt(self, level: int, option: int, value: None, length: int) -> None: ... @overload def getsockopt(self, level: int, option: int) -> int: ... @overload @@ -400,6 +458,10 @@ class DTLSEndpoint(metaclass=ABCMeta): exc_tb: TracebackType | None, ) -> None: ... +@attr.frozen +class DTLSChannelStatistics: + incoming_packets_dropped_in_trio: int + @final class DTLSChannel(_NotConstructible, trio.abc.Channel[bytes], metaclass=ABCMeta): endpoint: DTLSEndpoint @@ -411,7 +473,7 @@ class DTLSChannel(_NotConstructible, trio.abc.Channel[bytes], metaclass=ABCMeta) async def receive(self) -> bytes: ... def set_ciphertext_mtu(self, new_mtu: int) -> None: ... def get_cleartext_mtu(self) -> int: ... - def statistics(self) -> Any: ... + def statistics(self) -> DTLSChannelStatistics: ... async def aclose(self) -> None: ... def close(self) -> None: ... def __enter__(self) -> DTLSChannel: ... @@ -452,7 +514,12 @@ class AsyncIO(AsyncIterator[AnyStr], Generic[AnyStr], trio.abc.AsyncResource): async def __anext__(self) -> AnyStr: ... def __aiter__(self) -> AsyncIterator[AnyStr]: ... async def __aenter__(self: _T) -> _T: ... - async def __aexit__(self, *exc: object) -> None: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: ... class AsyncBinaryIO(AsyncIO[bytes]): pass diff --git a/trio-stubs/abc.pyi b/trio-stubs/abc.pyi index 10d2eae..8e8034e 100644 --- a/trio-stubs/abc.pyi +++ b/trio-stubs/abc.pyi @@ -1,6 +1,8 @@ +import socket import trio from abc import ABCMeta, abstractmethod from typing import List, Tuple, Union, Any, Optional, Generic, TypeVar, AsyncIterator +from types import TracebackType _T = TypeVar("_T") @@ -43,16 +45,21 @@ class SocketFactory(metaclass=ABCMeta): @abstractmethod def socket( self, - family: Optional[int] = None, - type: Optional[int] = None, - proto: Optional[int] = None, + family: socket.AddressFamily | int = ..., + type: socket.SocketKind | int = ..., + proto: int = ..., ) -> trio.socket.SocketType: ... class AsyncResource(metaclass=ABCMeta): @abstractmethod async def aclose(self) -> None: ... async def __aenter__(self: _T) -> _T: ... - async def __aexit__(self, *exc: object) -> None: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: ... class SendStream(AsyncResource): @abstractmethod diff --git a/trio-stubs/from_thread.pyi b/trio-stubs/from_thread.pyi index 4995b18..b8923d0 100644 --- a/trio-stubs/from_thread.pyi +++ b/trio-stubs/from_thread.pyi @@ -17,3 +17,4 @@ def run_sync( *args: Any, trio_token: Optional[trio.lowlevel.TrioToken] = ..., ) -> _T: ... +def check_cancelled() -> None: ... diff --git a/trio-stubs/lowlevel.pyi b/trio-stubs/lowlevel.pyi index b6d8955..2e45d76 100644 --- a/trio-stubs/lowlevel.pyi +++ b/trio-stubs/lowlevel.pyi @@ -37,9 +37,6 @@ import sys _T = TypeVar("_T") _F = TypeVar("_F", bound=Callable[..., Any]) -class _Statistics: - def __getattr__(self, name: str) -> Any: ... - # _core._ki def enable_ki_protection(fn: _F) -> _F: ... def disable_ki_protection(fn: _F) -> _F: ... @@ -58,6 +55,11 @@ class TrioToken(metaclass=ABCMeta): ) -> None: ... # _core._unbounded_queue +@attr.s(slots=True, frozen=True) +class UnboundedQueueStatistics: + qsize: int = attr.ib() + tasks_waiting: int = attr.ib() + @final class UnboundedQueue(Generic[_T], metaclass=ABCMeta): def __init__(self) -> None: ... @@ -66,11 +68,42 @@ class UnboundedQueue(Generic[_T], metaclass=ABCMeta): def put_nowait(self, obj: _T) -> None: ... def get_batch_nowait(self) -> Sequence[_T]: ... async def get_batch(self) -> Sequence[_T]: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> UnboundedQueueStatistics: ... def __aiter__(self) -> AsyncIterator[Sequence[_T]]: ... async def __anext__(self) -> Sequence[_T]: ... # _core._run +if sys.platform == "win32": + @attr.frozen + class _IOStatistics: + tasks_waiting_read: int = attr.ib() + tasks_waiting_write: int = attr.ib() + tasks_waiting_overlapped: int = attr.ib() + completion_key_monitors: int = attr.ib() + backend: Literal["windows"] = attr.ib(init=False, default="windows") + +elif sys.platform == "linux": + @attr.frozen + class _IOStatistics: + tasks_waiting_read: int = attr.ib() + tasks_waiting_write: int = attr.ib() + backend: Literal["epoll"] = attr.ib(init=False, default="epoll") + +else: # kqueue + @attr.frozen + class _IOStatistics: + tasks_waiting: int = attr.ib() + monitors: int = attr.ib() + backend: Literal["kqueue"] = attr.ib(init=False, default="kqueue") + +@attr.frozen +class RunStatistics: + tasks_living: int + tasks_runnable: int + seconds_to_next_deadline: float + io_statistics: _IOStatistics + run_sync_soon_queue_size: int + @final @attr.s(eq=False, hash=False, repr=False, slots=True) class Task(metaclass=ABCMeta): @@ -90,7 +123,7 @@ async def checkpoint() -> None: ... async def checkpoint_if_cancelled() -> None: ... def current_task() -> Task: ... def current_root_task() -> Task: ... -def current_statistics() -> _Statistics: ... +def current_statistics() -> RunStatistics: ... def current_clock() -> trio.abc.Clock: ... def current_trio_token() -> TrioToken: ... def reschedule(task: Task, next_send: outcome.Outcome[Any] = ...) -> None: ... @@ -161,6 +194,10 @@ async def temporarily_detach_coroutine_object( async def reattach_detached_coroutine_object(task: Task, yield_value: Any) -> None: ... # _core._parking_lot +@attr.s(frozen=True, slots=True) +class ParkingLotStatistics: + tasks_waiting: int = attr.ib() + @final @attr.s(eq=False, hash=False, slots=True) class ParkingLot(metaclass=ABCMeta): @@ -171,11 +208,17 @@ class ParkingLot(metaclass=ABCMeta): def unpark_all(self) -> Sequence[Task]: ... def repark(self, new_lot: ParkingLot, *, count: int = 1) -> None: ... def repark_all(self, new_lot: ParkingLot) -> None: ... - def statistics(self) -> _Statistics: ... + def statistics(self) -> ParkingLotStatistics: ... # _core._local -class _RunVarToken: - pass +class _NoValue: ... + +@final +@attr.s(eq=False, hash=False, slots=True) +class RunVarToken(Generic[_T], metaclass=ABCMeta): + _var: RunVar[_T] = attr.ib() + previous_value: _T | type[_NoValue] = attr.ib(default=_NoValue) + redeemed: bool = attr.ib(init=False) @final @attr.s(eq=False, hash=False, slots=True) @@ -183,8 +226,8 @@ class RunVar(Generic[_T], metaclass=ABCMeta): _name: str = attr.ib() _default: _T = attr.ib(default=cast(_T, object())) def get(self, default: _T = ...) -> _T: ... - def set(self, value: _T) -> _RunVarToken: ... - def reset(self, token: _RunVarToken) -> None: ... + def set(self, value: _T) -> RunVarToken[_T]: ... + def reset(self, token: RunVarToken[_T]) -> None: ... # _core._thread_cache def start_thread_soon( diff --git a/trio-stubs/socket.pyi b/trio-stubs/socket.pyi index fa0c001..6ac85f9 100644 --- a/trio-stubs/socket.pyi +++ b/trio-stubs/socket.pyi @@ -402,12 +402,16 @@ async def getaddrinfo( ]: ... class SocketType: - family: int - type: int - proto: int - did_shutdown_SHUT_WR: bool def __enter__(self: _T) -> _T: ... def __exit__(self, *args: Any) -> None: ... + @property + def did_shutdown_SHUT_WR(self) -> bool: ... + @property + def family(self) -> int: ... + @property + def type(self) -> int: ... + @property + def proto(self) -> int: ... def dup(self) -> SocketType: ... def close(self) -> None: ... async def bind(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: ... diff --git a/trio_typing/_tests/test-data/async_generator.test b/trio_typing/_tests/test-data/async_generator.test index c60cf90..c1b1852 100644 --- a/trio_typing/_tests/test-data/async_generator.test +++ b/trio_typing/_tests/test-data/async_generator.test @@ -61,7 +61,7 @@ async def dummy(): firstiter = agen.get_asyncgen_hooks().firstiter if firstiter is not None: - firstiter(iter([])) # E: Argument 1 has incompatible type "Iterator[]"; expected "AsyncGenerator[Any, Any]" + firstiter(iter([])) # E: Argument 1 has incompatible type "Iterator[Never]"; expected "AsyncGenerator[Any, Any]" reveal_type(firstiter(dummy())) # N: Revealed type is "Any" agen.set_asyncgen_hooks(firstiter) agen.set_asyncgen_hooks(firstiter, finalizer=firstiter) diff --git a/trio_typing/_tests/test-data/outcome.test b/trio_typing/_tests/test-data/outcome.test index eb9575e..86093d9 100644 --- a/trio_typing/_tests/test-data/outcome.test +++ b/trio_typing/_tests/test-data/outcome.test @@ -32,4 +32,4 @@ async def test() -> None: reveal_type(test_outcome.unwrap()) # N: Revealed type is "builtins.str" reveal_type(test_outcome.send(calc_lengths())) # N: Revealed type is "builtins.int" reveal_type(await test_outcome.asend(calc_lengths_async())) # N: Revealed type is "builtins.int" - test_outcome.send(wrong_type_gen()) # E: Argument 1 to "send" of "Value" has incompatible type "Iterator[int]"; expected "Generator[, str, Any]" # E: Argument 1 to "send" of "Error" has incompatible type "Iterator[int]"; expected "Generator[, Any, Any]" + test_outcome.send(wrong_type_gen()) # E: Argument 1 to "send" of "Value" has incompatible type "Iterator[int]"; expected "Generator[Never, str, Any]" # E: Argument 1 to "send" of "Error" has incompatible type "Iterator[int]"; expected "Generator[Never, Any, Any]" diff --git a/trio_typing/_tests/test-data/taskstatus.test b/trio_typing/_tests/test-data/taskstatus.test index 071b1a2..6400b20 100644 --- a/trio_typing/_tests/test-data/taskstatus.test +++ b/trio_typing/_tests/test-data/taskstatus.test @@ -21,10 +21,10 @@ async def parent() -> None: nursery.start_soon(child2) # E: Argument 1 to "start_soon" of "Nursery" has incompatible type "Callable[[int, DefaultNamedArg(TaskStatus[None], 'task_status')], Coroutine[Any, Any, None]]"; expected "Callable[[], Awaitable[Any]]" nursery.start_soon(child2, "hi") # E: Argument 1 to "start_soon" of "Nursery" has incompatible type "Callable[[int, DefaultNamedArg(TaskStatus[None], 'task_status')], Coroutine[Any, Any, None]]"; expected "Callable[[str], Awaitable[Any]]" nursery.start_soon(child2, 50) - await nursery.start(child) # E: Argument 1 to "start" of "Nursery" has incompatible type "Callable[[int, NamedArg(TaskStatus[int], 'task_status')], Coroutine[Any, Any, None]]"; expected "Callable[[NamedArg(TaskStatus[], 'task_status')], Awaitable[Any]]" + await nursery.start(child) # E: Argument 1 to "start" of "Nursery" has incompatible type "Callable[[int, NamedArg(TaskStatus[int], 'task_status')], Coroutine[Any, Any, None]]"; expected "Callable[[NamedArg(TaskStatus[int], 'task_status')], Awaitable[Any]]" await nursery.start(child, "hi") # E: Argument 1 to "start" of "Nursery" has incompatible type "Callable[[int, NamedArg(TaskStatus[int], 'task_status')], Coroutine[Any, Any, None]]"; expected "Callable[[str, NamedArg(TaskStatus[int], 'task_status')], Awaitable[Any]]" result = await nursery.start(child, 10) - result2 = await nursery.start(child2, 10) # E: Function does not return a value + result2 = await nursery.start(child2, 10) # E: Function does not return a value (it only ever returns None) await nursery.start(child2, 10) reveal_type(result) # N: Revealed type is "builtins.int" @@ -54,6 +54,6 @@ async def parent() -> None: await nursery.start(child) await nursery.start(child, "hi") result = await nursery.start(child, 10) - result2 = await nursery.start(child2, 10) # E: Function does not return a value + result2 = await nursery.start(child2, 10) # E: Function does not return a value (it only ever returns None) await nursery.start(child2, 10) reveal_type(result) # N: Revealed type is "builtins.int" diff --git a/trio_typing/_tests/test-data/trio-basic.test b/trio_typing/_tests/test-data/trio-basic.test index e57df21..fc12aae 100644 --- a/trio_typing/_tests/test-data/trio-basic.test +++ b/trio_typing/_tests/test-data/trio-basic.test @@ -26,7 +26,7 @@ reveal_type(val) # N: Revealed type is "builtins.list[builtins.float]" trio.run(sleep_sort, ["hi", "there"]) # E: Argument 1 to "run" has incompatible type "Callable[[Sequence[float]], Coroutine[Any, Any, List[float]]]"; expected "Callable[[List[str]], Awaitable[List[float]]]" -reveal_type(trio.Event().statistics().anything) # N: Revealed type is "Any" +reveal_type(trio.Event().statistics().tasks_waiting) # N: Revealed type is "builtins.int" [case testTrioBasic_NoPlugin] import trio @@ -55,7 +55,7 @@ val = trio.run(sleep_sort, (1, 3, 5, 2, 4), clock=trio.testing.MockClock(autojum reveal_type(val) # N: Revealed type is "builtins.list[builtins.float]" trio.run(sleep_sort, ["hi", "there"]) -reveal_type(trio.Event().statistics().anything) # N: Revealed type is "Any" +reveal_type(trio.Event().statistics().tasks_waiting) # N: Revealed type is "builtins.int" [case testExceptions] import trio @@ -75,11 +75,6 @@ def filter_exc(exc: BaseException): with trio.MultiError.catch(filter_exc): pass -try: - trio.run(trio.sleep, 3) -except trio.MultiError as ex: - reveal_type(ex.exceptions[0]) # N: Revealed type is "builtins.BaseException" - [case testOverloaded] from typing import overload, Any diff --git a/trio_typing/plugin.py b/trio_typing/plugin.py index ca88b1d..31f4887 100644 --- a/trio_typing/plugin.py +++ b/trio_typing/plugin.py @@ -39,11 +39,6 @@ class TrioPlugin(Plugin): def get_function_hook( self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: - if fullname in ( - "contextlib.asynccontextmanager", - "async_generator.asynccontextmanager", - ): - return args_invariant_decorator_callback if fullname == "trio_typing.takes_callable_and_args": return takes_callable_and_args_callback if fullname == "async_generator.async_generator": @@ -65,26 +60,6 @@ def get_function_hook( return super().get_function_hook(fullname) -def args_invariant_decorator_callback(ctx: FunctionContext) -> Type: - """Infer a better return type for @asynccontextmanager, - @async_generator, and other decorators that affect the return - type but not the argument types of the function they decorate. - """ - # (adapted from the @contextmanager support in mypy's builtin plugin) - if ctx.arg_types and len(ctx.arg_types[0]) == 1: - arg_type = get_proper_type(ctx.arg_types[0][0]) - ret_type = get_proper_type(ctx.default_return_type) - if isinstance(arg_type, CallableType) and isinstance(ret_type, CallableType): - return ret_type.copy_modified( - arg_types=arg_type.arg_types, - arg_kinds=arg_type.arg_kinds, - arg_names=arg_type.arg_names, - variables=arg_type.variables, - is_ellipsis_args=arg_type.is_ellipsis_args, - ) - return ctx.default_return_type - - def decode_agen_types_from_return_type( ctx: FunctionContext, original_async_return_type: Type ) -> Tuple[Type, Type, Type]: @@ -189,25 +164,23 @@ async def example() -> Union[str, YieldType[bool], SendType[int]]: YieldType[bool], SendType[int]]`` without the plugin. """ - # Apply the common logic to not change the arguments of the - # decorated function - new_return_type = args_invariant_decorator_callback(ctx) - if not isinstance(new_return_type, CallableType): - return new_return_type - agen_return_type = get_proper_type(new_return_type.ret_type) + decorator_return_type = ctx.default_return_type + if not isinstance(decorator_return_type, CallableType): + return decorator_return_type + agen_return_type = get_proper_type(decorator_return_type.ret_type) if ( isinstance(agen_return_type, Instance) and agen_return_type.type.fullname == "trio_typing.CompatAsyncGenerator" and len(agen_return_type.args) == 3 ): - return new_return_type.copy_modified( + return decorator_return_type.copy_modified( ret_type=agen_return_type.copy_modified( args=list( decode_agen_types_from_return_type(ctx, agen_return_type.args[2]) ) ) ) - return new_return_type + return decorator_return_type def decode_enclosing_agen_types(ctx: FunctionContext) -> Tuple[Type, Type]: @@ -524,7 +497,9 @@ def start_soon( def plugin(version: str) -> typing_Type[Plugin]: mypy_version = parse_version(version) - if mypy_version < parse_version("1.4"): + if mypy_version < parse_version("1.0"): + raise RuntimeError("This version of trio-typing requires at least mypy 1.0.") + elif mypy_version < parse_version("1.4"): return TrioPlugin13 else: return TrioPlugin