Skip to content

Commit

Permalink
Annotated syntax support
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalik committed Sep 10, 2023
1 parent 9f9b567 commit a451939
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 46 deletions.
7 changes: 6 additions & 1 deletion WHATSNEW_V1.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
- CSRF changes
- Auth async support
- Schema.Meta
- pagination request in paginate_queryset
- pagination: request in paginate_queryset
- decorators
- openapi docs plugable
- add_router supports strings

TODO:
- async pagination

Backwards incompatible stuff
- resolve_xxx(self, ...)
- pydantic v1
17 changes: 2 additions & 15 deletions ninja/compatibility/util.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,8 @@
from typing import Any, Callable, Optional, Union
from typing import Callable, Union

import django

__all__ = ["get_origin", "get_args", "async_to_sync"]

# python3.8+ get_origin, get_args
try:
from typing import get_args, get_origin # type: ignore
except ImportError: # pragma: no coverage

def get_origin(tp: Any) -> Optional[Any]:
"typing.get_origin introduced in python3.8"
return getattr(tp, "__origin__", None)

def get_args(tp: Any) -> Optional[Any]:
"typing.get_args introduced in python3.8"
return getattr(tp, "__args__", None)
__all__ = ["async_to_sync", "UNION_TYPES"]


# python3.10+ syntax of creating a union or optional type (with str | int)
Expand Down
2 changes: 1 addition & 1 deletion ninja/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from django.db.models import QuerySet
from django.http import HttpRequest
from django.utils.module_loading import import_string
from typing_extensions import get_args as get_collection_args # type: ignore

from ninja import Field, Query, Router, Schema
from ninja.compatibility.util import get_args as get_collection_args
from ninja.conf import settings
from ninja.constants import NOT_SET
from ninja.errors import ConfigError
Expand Down
14 changes: 7 additions & 7 deletions ninja/params_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def Path( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -48,7 +48,7 @@ def Path( # noqa: N802


def Query( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -87,7 +87,7 @@ def Query( # noqa: N802


def Header( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -126,7 +126,7 @@ def Header( # noqa: N802


def Cookie( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -165,7 +165,7 @@ def Cookie( # noqa: N802


def Body( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -204,7 +204,7 @@ def Body( # noqa: N802


def Form( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -243,7 +243,7 @@ def Form( # noqa: N802


def File( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down
48 changes: 26 additions & 22 deletions ninja/signature/details.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,10 @@
from django.http import HttpResponse
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import Annotated, get_args, get_origin # type: ignore

from ninja import UploadedFile, params
from ninja.compatibility.util import (
UNION_TYPES,
get_args,
)
from ninja.compatibility.util import (
get_origin as get_collection_origin,
)
from ninja.compatibility.util import UNION_TYPES
from ninja.errors import ConfigError
from ninja.params import Body, File, Form, _MultiPartBody
from ninja.params_models import TModel, TModels
Expand Down Expand Up @@ -205,15 +200,24 @@ def _model_flatten_map(self, model: TModel, prefix: str) -> Generator:
def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
# _EMPTY = self.signature.empty
annotation = arg.annotation
default = arg.default

if get_origin(annotation) is Annotated:
args = get_args(annotation)
if isinstance(args[1], params.Param):
prev_default = default
annotation, default = args
if prev_default != self.signature.empty:
default.default = prev_default

if annotation == self.signature.empty:
if arg.default == self.signature.empty:
if default == self.signature.empty:
annotation = str
else:
if isinstance(arg.default, params.Param):
annotation = type(arg.default.default)
if isinstance(default, params.Param):
annotation = type(default.default)
else:
annotation = type(arg.default)
annotation = type(default)

if annotation == PydanticUndefined.__class__:
# TODO: ^ check why is that so
Expand All @@ -228,34 +232,34 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
is_collection and annotation.__args__[0] == UploadedFile
):
# People often forgot to mark UploadedFile as a File, so we better assign it automatically
if arg.default == self.signature.empty or arg.default is None:
default = arg.default == self.signature.empty and ... or arg.default
if default == self.signature.empty or default is None:
default = default == self.signature.empty and ... or default
return FuncParam(name, name, File(default), annotation, is_collection)

# 1) if type of the param is defined as one of the Param's subclasses - we just use that definition
if isinstance(arg.default, params.Param):
param_source = arg.default
if isinstance(default, params.Param):
param_source = default

# 2) if param name is a part of the path parameter
elif name in self.path_params_names:
assert (
arg.default == self.signature.empty
default == self.signature.empty
), f"'{name}' is a path param, default not allowed"
param_source = params.Path(...)

# 3) if param is a collection, or annotation is part of pydantic model:
elif is_collection or is_pydantic_model(annotation):
if arg.default == self.signature.empty:
if default == self.signature.empty:
param_source = params.Body(...)
else:
param_source = params.Body(arg.default)
param_source = params.Body(default)

# 4) the last case is query param
else:
if arg.default == self.signature.empty:
if default == self.signature.empty:
param_source = params.Query(...)
else:
param_source = params.Query(arg.default)
param_source = params.Query(default)

return FuncParam(
name, param_source.alias or name, param_source, annotation, is_collection
Expand All @@ -264,15 +268,15 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:

def is_pydantic_model(cls: Any) -> bool:
try:
if get_collection_origin(cls) in UNION_TYPES:
if get_origin(cls) in UNION_TYPES:
return any(issubclass(arg, pydantic.BaseModel) for arg in get_args(cls))
return issubclass(cls, pydantic.BaseModel)
except TypeError:
return False


def is_collection_type(annotation: Any) -> bool:
origin = get_collection_origin(annotation)
origin = get_origin(annotation)
collection_types = (List, list, set, tuple)
if origin is None:
return (
Expand Down
Loading

0 comments on commit a451939

Please sign in to comment.