• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MrThearMan / undine / 16081423623

04 Jul 2025 09:57PM UTC coverage: 97.685%. First build
16081423623

Pull #33

github

web-flow
Merge 6eb57167c into 784a68391
Pull Request #33: Add Subscriptions

1798 of 1841 branches covered (97.66%)

Branch coverage included in aggregate %.

1009 of 1176 new or added lines in 36 files covered. (85.8%)

26853 of 27489 relevant lines covered (97.69%)

8.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

62.89
/undine/subscription.py
1
from __future__ import annotations
9✔
2

3
from abc import ABC, ABCMeta, abstractmethod
9✔
4
from types import MappingProxyType
9✔
5
from typing import TYPE_CHECKING, Any, ClassVar, Unpack
9✔
6

7
from graphql import DirectiveLocation, GraphQLInputField, Undefined
9✔
8

9
from undine.converters import convert_to_graphql_type
9✔
10
from undine.parsers import parse_class_attribute_docstrings, parse_return_annotation
9✔
11
from undine.settings import undine_settings
9✔
12
from undine.utils.graphql.type_registry import get_or_create_graphql_input_object_type
9✔
13
from undine.utils.graphql.utils import check_directives
9✔
14
from undine.utils.reflection import FunctionEqualityWrapper, get_members
9✔
15
from undine.utils.text import dotpath, get_docstring, to_schema_name
9✔
16

17
if TYPE_CHECKING:
18
    from collections.abc import AsyncGenerator
19

20
    from graphql import GraphQLInputObjectType, GraphQLInputType, GraphQLObjectType
21

22
    from undine.directives import Directive
23
    from undine.typing import DefaultValueType, GQLInfo, SubscriptionInputParams, SubscriptionTypeParams, TypeHint
24

25
__all__ = [
9✔
26
    "SubscriptionInput",
27
    "SubscriptionType",
28
]
29

30

31
class SubscriptionTypeMeta(ABCMeta):
9✔
32
    """A metaclass that modifies how a `SubscriptionType` is created."""
33

34
    # Set in '__new__'
35
    __input_map__: dict[str, SubscriptionInput]
9✔
36
    __schema_name__: str
9✔
37
    __directives__: list[Directive]
9✔
38
    __extensions__: dict[str, Any]
9✔
39
    __attribute_docstrings__: dict[str, str]
9✔
40

41
    def __new__(
9✔
42
        cls,
43
        _name: str,
44
        _bases: tuple[type, ...],
45
        _attrs: dict[str, Any],
46
        **kwargs: Unpack[SubscriptionTypeParams],
47
    ) -> SubscriptionTypeMeta:
48
        if _name == "SubscriptionType":  # Early return for the `SubscriptionType` class itself.
9✔
49
            return super().__new__(cls, _name, _bases, _attrs)
9✔
50

51
        subscription_type = super().__new__(cls, _name, _bases, _attrs)
9✔
52

53
        # Members should use `__dunder__` names to avoid name collisions with possible `SubscriptionInput` names.
54
        subscription_type.__input_map__ = get_members(subscription_type, SubscriptionInput)
9✔
55
        subscription_type.__schema_name__ = kwargs.get("schema_name", _name)
9✔
56
        subscription_type.__directives__ = kwargs.get("directives", [])
9✔
57
        subscription_type.__extensions__ = kwargs.get("extensions", {})
9✔
58
        subscription_type.__attribute_docstrings__ = parse_class_attribute_docstrings(subscription_type)
9✔
59

60
        check_directives(subscription_type.__directives__, location=DirectiveLocation.INPUT_OBJECT)
9✔
61
        subscription_type.__extensions__[undine_settings.SUBSCRIPTION_TYPE_EXTENSIONS_KEY] = subscription_type
9✔
62

63
        for name, sub_input in subscription_type.__input_map__.items():
9✔
NEW
64
            sub_input.__connect__(subscription_type, name)  # type: ignore[arg-type]
×
65

66
        return subscription_type
9✔
67

68
    def __str__(cls) -> str:
9✔
NEW
69
        return undine_settings.SDL_PRINTER.print_object_type(cls.__output_type__())
×
70

71
    def __input_type__(cls) -> GraphQLInputObjectType:
9✔
72
        """Creates the GraphQL `InputObjectType` for this `SubscriptionType`."""
NEW
73
        return get_or_create_graphql_input_object_type(
×
74
            name=cls.__schema_name__,
75
            fields=FunctionEqualityWrapper(cls.__input_fields__, context=cls),
76
            description=get_docstring(cls),
77
            extensions=cls.__extensions__,
78
        )
79

80
    def __input_fields__(cls) -> dict[str, GraphQLInputField]:
9✔
NEW
81
        return {
×
82
            sub_input.schema_name: sub_input.as_graphql_input_field()  # ...
83
            for sub_input in cls.__input_map__.values()
84
        }
85

86
    def __output_type__(cls) -> GraphQLObjectType:
9✔
87
        """Creates the GraphQL `ObjectType` for this `SubscriptionType`."""
88
        # TODO: What if type is in a `TYPE_CHECKING` block
NEW
89
        ann = parse_return_annotation(cls.__run__)  # type: ignore[attr-defined]
×
NEW
90
        return convert_to_graphql_type(ann)
×
91

92

93
class SubscriptionType(ABC, metaclass=SubscriptionTypeMeta):
    """
    A class for creating a new `SubscriptionType` with `SubscriptionInputs`.
    Represents a GraphQL `GraphQLObjectType` in the GraphQL schema.

    The following parameters can be passed in the class definition:

     `schema_name: str = <class name>`
        Override name for the `GraphQLObjectType` for this `SubscriptionType` in the GraphQL schema.

     `directives: list[Directive] = []`
        `Directives` to add to the created `GraphQLObjectType`.

     `extensions: dict[str, Any] = {}`
        GraphQL extensions for the created `GraphQLObjectType`.

    >>> class Countdown(SubscriptionType):
    >>>     async def __run__(self, root: Any, info: GQLInfo) -> AsyncGenerator[int, None]:
    >>>         for i in range(10):
    >>>             yield i
    """

    # Set in metaclass
    __input_map__: ClassVar[dict[str, SubscriptionInput]]
    __schema_name__: ClassVar[str]
    __directives__: ClassVar[list[Directive]]
    __extensions__: ClassVar[dict[str, Any]]
    __attribute_docstrings__: ClassVar[dict[str, str]]

    @abstractmethod
    async def __run__(self, root: Any, info: GQLInfo) -> AsyncGenerator[Any, None]:
        """Async generator for the subscription."""

    def __init__(self, **kwargs: Any) -> None:
        """
        Create a new instance holding a value for each `SubscriptionInput` of this `SubscriptionType`.

        :param kwargs: Input values by `SubscriptionInput` name. Inputs left out use their default value.
        :raises ValueError: An input without a usable value is missing, or an unknown input was given.
        """
        parameters: dict[str, Any] = {}

        for inpt in self.__input_map__.values():
            value = kwargs.pop(inpt.name, inpt.default_value)
            # `Undefined` here means that no value was given and the input has no default value.
            if value is Undefined:
                msg = f"Missing value for required input {inpt.name!r} of {type(self).__name__!r}."
                raise ValueError(msg)  # TODO: Custom error

            parameters[inpt.name] = value

        # Anything still left in 'kwargs' did not match any `SubscriptionInput`.
        if kwargs:
            msg = f"Unexpected inputs for {type(self).__name__!r}: {sorted(kwargs)}."
            raise ValueError(msg)  # TODO: Custom error

        # Read-only view, which `SubscriptionInput.__get__` reads the values from.
        self.__parameters__: MappingProxyType[str, Any] = MappingProxyType(parameters)
9✔
141

142

143
class SubscriptionInput:
9✔
144
    """
145
    A class for defining a possible input for a subscription.
146
    Represents an input field on a GraphQL `InputObjectType` for the `SubscriptionType` this is added to.
147

148
    >>> class Countdown(SubscriptionType):
149
    >>>     start = SubscriptionInput(int, default_value=0)
150
    >>>
151
    >>>     async def __run__(self, root: Any, info: GQLInfo) -> AsyncGenerator[int, None]:
152
    >>>         for i in range(self.start):
153
    >>>             yield i
154
    """
155

156
    def __init__(self, ref: TypeHint, **kwargs: Unpack[SubscriptionInputParams]) -> None:
9✔
157
        """
158
        Create a new `SubscriptionInput`.
159

160
        :param ref: The argument reference to use for the `SubscriptionInput`.
161
        :param default_value: The default value for the `SubscriptionInput`.
162
        :param description: Description for the `SubscriptionInput`.
163
        :param deprecation_reason: If the `SubscriptionInput` is deprecated, describes the reason for deprecation.
164
        :param schema_name: Actual name in the GraphQL schema. Only needed if argument name is a python keyword.
165
        :param directives: GraphQL directives for the `SubscriptionInput`.
166
        :param extensions: GraphQL extensions for the `SubscriptionInput`.
167
        """
NEW
168
        self.ref = ref
×
169

NEW
170
        self.default_value: DefaultValueType = kwargs.get("default_value", Undefined)
×
NEW
171
        self.description: str | None = kwargs.get("description", Undefined)  # type: ignore[assignment]
×
NEW
172
        self.deprecation_reason: str | None = kwargs.get("deprecation_reason")
×
NEW
173
        self.schema_name: str = kwargs.get("schema_name", Undefined)  # type: ignore[assignment]
×
NEW
174
        self.directives: list[Directive] = kwargs.get("directives", [])
×
NEW
175
        self.extensions: dict[str, Any] = kwargs.get("extensions", {})
×
176

NEW
177
        check_directives(self.directives, location=DirectiveLocation.INPUT_FIELD_DEFINITION)
×
NEW
178
        self.extensions[undine_settings.SUBSCRIPTION_INPUT_EXTENSIONS_KEY] = self
×
179

180
    def __connect__(self, subscription_type: type[SubscriptionType], name: str) -> None:
9✔
181
        """Connect this `SubscriptionInput` to the given `SubscriptionType` using the given name."""
NEW
182
        self.subscription_type = subscription_type
×
NEW
183
        self.name = name
×
NEW
184
        self.schema_name = self.schema_name or to_schema_name(name)
×
185

NEW
186
        if self.description is Undefined:
×
NEW
187
            self.description = self.subscription_type.__attribute_docstrings__.get(name)
×
188

189
    def __repr__(self) -> str:
9✔
NEW
190
        return f"<{dotpath(self.__class__)}(ref={self.ref!r})>"
×
191

192
    def __str__(self) -> str:
9✔
NEW
193
        inpt = self.as_graphql_input_field()
×
NEW
194
        return undine_settings.SDL_PRINTER.print_input_field(self.schema_name, inpt, indent=False)
×
195

196
    def __get__(self, instance: SubscriptionType | None, cls: type[SubscriptionType]) -> Any:
9✔
NEW
197
        if instance is None:
×
NEW
198
            return self
×
NEW
199
        return instance.__parameters__[self.name]
×
200

201
    def as_graphql_input_field(self) -> GraphQLInputField:
9✔
NEW
202
        return GraphQLInputField(
×
203
            type_=self.get_field_type(),
204
            default_value=self.default_value,
205
            description=self.description,
206
            deprecation_reason=self.deprecation_reason,
207
            out_name=self.name,
208
            extensions=self.extensions,
209
        )
210

211
    def get_field_type(self) -> GraphQLInputType:
9✔
NEW
212
        return convert_to_graphql_type(self.ref, is_input=True)  # type: ignore[return-value]
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc