• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnFlow / 11546904243

18 Oct 2024 03:41PM UTC coverage: 96.804% (-0.04%) from 96.844%
11546904243

push

github

web-flow
add automatic break points based on magic method detection (#113)

* add automatic break points based on magic method detection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tests but undo magic methods

* fix tests

* support all magic methods

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* pre-commit fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

219 of 229 new or added lines in 5 files covered. (95.63%)

4 existing lines in 3 files now uncovered.

2635 of 2722 relevant lines covered (96.8%)

3.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.36
/znflow/base.py
1
from __future__ import annotations
4✔
2

3
import contextlib
4✔
4
import dataclasses
4✔
5
import typing
4✔
6
from typing import Any
4✔
7
from uuid import UUID
4✔
8

9
from znflow import exceptions
4✔
10

11
if typing.TYPE_CHECKING:
4✔
12
    from znflow.graph import DiGraph
×
13

14

15
@contextlib.contextmanager
4✔
16
def disable_graph(*args, **kwargs):
4✔
17
    """Temporarily disable set the graph to empty.
18

19
    This can be useful, if you e.g. want to use 'get_attribute'.
20
    """
21
    graph = get_graph()
4✔
22
    set_graph(empty_graph)
4✔
23
    try:
4✔
24
        yield
4✔
25
    finally:
26
        set_graph(graph)
4✔
27

28

29
class Property:
4✔
30
    """Custom Property with disabled graph.
4✔
31

32
    References
33
    ----------
34
    Adapted from https://docs.python.org/3/howto/descriptor.html#properties
35
    """
36

37
    def __init__(self, fget=None, fset=None, fdel=None, doc=None):
4✔
38
        self.fget = disable_graph()(fget)
4✔
39
        self.fset = disable_graph()(fset)
4✔
40
        self.fdel = disable_graph()(fdel)
4✔
41
        if doc is None and fget is not None:
4✔
42
            doc = fget.__doc__
4✔
43
        self.__doc__ = doc
4✔
44
        self._name = ""
4✔
45

46
    def __set_name__(self, owner, name):
4✔
47
        self._name = name
4✔
48

49
    def __get__(self, obj, objtype=None):
4✔
50
        if obj is None:
4✔
51
            return self
4✔
52
        if self.fget is None:
4✔
53
            raise AttributeError(f"property '{self._name}' has no getter")
×
54
        return self.fget(obj)
4✔
55

56
    def __set__(self, obj, value):
4✔
57
        if self.fset is None:
4✔
58
            raise AttributeError(f"property '{self._name}' has no setter")
×
59
        self.fset(obj, value)
4✔
60

61
    def __delete__(self, obj):
4✔
62
        if self.fdel is None:
×
63
            raise AttributeError(f"property '{self._name}' has no deleter")
×
64
        self.fdel(obj)
×
65

66
    def getter(self, fget):
4✔
67
        prop = type(self)(fget, self.fset, self.fdel, self.__doc__)
×
68
        prop._name = self._name
×
69
        return prop
×
70

71
    def setter(self, fset):
4✔
72
        prop = type(self)(self.fget, fset, self.fdel, self.__doc__)
4✔
73
        prop._name = self._name
4✔
74
        return prop
4✔
75

76
    def deleter(self, fdel):
4✔
77
        prop = type(self)(self.fget, self.fset, fdel, self.__doc__)
×
78
        prop._name = self._name
×
79
        return prop
×
80

81

82
@dataclasses.dataclass(frozen=True)
4✔
83
class EmptyGraph:
4✔
84
    """An empty class used as a default value for _graph_."""
4✔
85

86

87
empty_graph = EmptyGraph()
4✔
88

89

90
class NodeBaseMixin:
4✔
91
    """A Parent for all Nodes.
4✔
92

93
    This class is used to globally access and change all classes that inherit from it.
94

95
    Attributes
96
    ----------
97
        _graph_ : DiGraph
98
            The graph this node belongs to.
99
            This is only available within the graph context.
100
        uuid : UUID
101
            The unique identifier of this node.
102
        _external_ : bool
103
            If true, the node is allowed to be created outside of a graph context.
104
            In this case connections can be created to this node, otherwise
105
            an exception is raised.
106
    """
107

108
    _graph_ = empty_graph
4✔
109
    _external_ = False
4✔
110
    _uuid: UUID = None
4✔
111

112
    _protected_ = [
4✔
113
        "_graph_",
114
        "uuid",
115
        "_uuid",
116
        "model_fields",  # pydantic
117
        "model_computed_fields",  # pydantic
118
    ]
119

120
    @property
4✔
121
    def uuid(self):
4✔
122
        return self._uuid
4✔
123

124
    @uuid.setter
4✔
125
    def uuid(self, value):
4✔
126
        if self._uuid is not None:
4✔
127
            raise ValueError("uuid is already set")
×
128
        self._uuid = value
4✔
129

130
    def run(self):
4✔
131
        raise NotImplementedError
132

133

134
def get_graph() -> DiGraph:
4✔
135
    return NodeBaseMixin._graph_
4✔
136

137

138
def set_graph(value):
4✔
139
    NodeBaseMixin._graph_ = value
4✔
140

141

142
_get_attribute_none = object()
4✔
143

144

145
def get_attribute(obj, name, default=_get_attribute_none):
4✔
146
    """Get the real value of the attribute and not a znflow.Connection."""
147
    with disable_graph():
4✔
148
        if default is _get_attribute_none:
4✔
149
            return getattr(obj, name)
4✔
150
        return getattr(obj, name, default)
4✔
151

152

153
@dataclasses.dataclass(frozen=True)
4✔
154
class Connection:
4✔
155
    """A Connector for Nodes.
4✔
156

157
    Attributes
158
    ----------
159
        instance: Node|FunctionFuture
160
            the object this connection points to
161
        attribute: str
162
            Node.attribute
163
            or FunctionFuture.result
164
            or None if the class is passed and not an attribute
165
        item: any
166
            any slice or list index to be applied to the result
167
    """
168

169
    instance: any
4✔
170
    attribute: str
4✔
171
    item: any = None
4✔
172

173
    def __post_init__(self):
4✔
174
        if self.attribute is not None and self.attribute.startswith("_"):
4✔
175
            raise ValueError("Private attributes are not allowed.")
×
176

177
    def __getitem__(self, item):
4✔
178
        return dataclasses.replace(self, instance=self, attribute=None, item=item)
4✔
179

180
    def __add__(
4✔
181
        self, other: typing.Union[Connection, FunctionFuture, CombinedConnections]
182
    ) -> CombinedConnections:
183
        if isinstance(other, (Connection, FunctionFuture, CombinedConnections)):
4✔
184
            return CombinedConnections(connections=[self, other])
4✔
185
        raise TypeError(f"Can not add {type(other)} to {type(self)}.")
4✔
186

187
    def __radd__(self, other):
4✔
188
        """Enable 'sum([a, b], [])'"""
189
        return self if other == [] else self.__add__(other)
4✔
190

191
    @property
4✔
192
    def uuid(self):
4✔
193
        return self.instance.uuid
4✔
194

195
    @property
4✔
196
    def _external_(self):
4✔
197
        return self.instance._external_
4✔
198

199
    @property
4✔
200
    def result(self):
4✔
201
        if self.attribute:
4✔
202
            result = getattr(self.instance, self.attribute)
4✔
203
        elif isinstance(self.instance, (FunctionFuture, self.__class__)):
4✔
204
            result = self.instance.result
4✔
205
        else:
206
            result = self.instance
4✔
207
        return result[self.item] if self.item else result
4✔
208

209
    def __getattribute__(self, __name: str) -> Any:
4✔
210
        try:
4✔
211
            return super().__getattribute__(__name)
4✔
212
        except AttributeError as e:
4✔
213
            raise exceptions.ConnectionAttributeError(
4✔
214
                "Connection does not support further attributes to its result."
215
            ) from e
216

217
    def __eq__(self, other) -> bool:
4✔
218
        """Overwrite for dynamic break points."""
219
        from znflow import resolve, get_graph, empty_graph
4✔
220

221
        if isinstance(other, (Connection)):
4✔
222
            return self.instance == other.instance
4✔
223
        if isinstance(other, (FunctionFuture)):
4✔
NEW
224
            return False
×
225

226
        if get_graph() is empty_graph:
4✔
227
            return super().__eq__(other)
4✔
228
        return resolve(self).__eq__(other)
4✔
229

230
    def __lt__(self, other) -> bool:
4✔
231
        """Overwrite for dynamic break points."""
232
        from znflow import resolve, get_graph, empty_graph
4✔
233

234
        if get_graph() is empty_graph:
4✔
NEW
235
            return super().__lt__(other)
×
236
        return resolve(self).__lt__(other)
4✔
237

238
    def __le__(self, other) -> bool:
4✔
239
        """Overwrite for dynamic break points."""
240
        from znflow import resolve, get_graph, empty_graph
4✔
241

242
        if get_graph() is empty_graph:
4✔
NEW
243
            return super().__le__(other)
×
244
        return resolve(self).__le__(other)
4✔
245

246
    def __gt__(self, other) -> bool:
4✔
247
        """Overwrite for dynamic break points."""
248
        from znflow import resolve, get_graph, empty_graph
4✔
249

250
        if get_graph() is empty_graph:
4✔
NEW
251
            return super().__gt__(other)
×
252
        return resolve(self).__gt__(other)
4✔
253

254
    def __ge__(self, other) -> bool:
4✔
255
        """Overwrite for dynamic break points."""
256
        from znflow import resolve, get_graph, empty_graph
4✔
257

258
        if get_graph() is empty_graph:
4✔
NEW
259
            return super().__ge__(other)
×
260
        return resolve(self).__ge__(other)
4✔
261

262
    def __iter__(self):
4✔
263
        from znflow import resolve
4✔
264

265
        try:
4✔
266
            return resolve(self).__iter__()
4✔
267
        except AttributeError:
4✔
268
            raise TypeError(f"'{self}' object is not iterable")
4✔
269

270

271
@dataclasses.dataclass(frozen=True)
4✔
272
class CombinedConnections:
4✔
273
    """Combine multiple Connections into one.
4✔
274

275
    This class allows to 'add' Connections and/or FunctionFutures.
276
    This only works if the Connection or FunctionFuture points to a 'list'.
277
    A new entry of 'CombinedConnections' will be created for every time a new
278
    item is added.
279

280
    Examples
281
    --------
282

283
    >>> import znflow
284
    >>> @znflow.nodfiy
285
    >>> def add(size) -> list:
286
    >>>     return list(range(size))
287
    >>> with znflow.DiGraph() as graph:
288
    >>>     outs = add(2) + add(3)
289
    >>> graph.run()
290
    >>> assert outs.result == [0, 1, 0, 1, 2]
291

292
    Attributes
293
    ----------
294
    connections : list[Connection|FunctionFuture|AddedConnections]
295
        The List of items to be added.
296
    item : any
297
        Any slice to be applied to the result.
298
    """
299

300
    connections: typing.List[Connection]
4✔
301
    item: any = None
4✔
302

303
    def __add__(
4✔
304
        self, other: typing.Union[Connection, FunctionFuture, CombinedConnections]
305
    ) -> CombinedConnections:
306
        """Implement add for AddedConnections.
307

308
        Raises
309
        ------
310
        ValueError
311
            If  self.item is set, we can not add another item.
312
        TypeError
313
            If other is not a Connection, FunctionFuture or AddedConnections.
314
        """
315
        if self.item is not None:
4✔
316
            raise ValueError("Can not combine multiple slices")
4✔
317
        if isinstance(other, (Connection, FunctionFuture)):
4✔
318
            return dataclasses.replace(self, connections=self.connections + [other])
4✔
319
        elif isinstance(other, CombinedConnections):
4✔
320
            return dataclasses.replace(
4✔
321
                self, connections=self.connections + other.connections
322
            )
323
        else:
324
            raise TypeError(f"Can not add {type(other)} to {type(self)}.")
4✔
325

326
    def __radd__(self, other):
4✔
327
        """Enable 'sum([a, b], [])'"""
328
        return self if other == [] else self.__add__(other)
4✔
329

330
    def __getitem__(self, item):
4✔
331
        return dataclasses.replace(self, item=item)
4✔
332

333
    def __iter__(self):
4✔
334
        raise TypeError(f"Can not iterate over {self}.")
×
335

336
    @property
4✔
337
    def result(self):
4✔
338
        try:
4✔
339
            results = []
4✔
340
            for connection in self.connections:
4✔
341
                results.extend(connection.result)
4✔
342
            return results[self.item] if self.item else results
4✔
343
        except TypeError as err:
4✔
344
            raise TypeError(
4✔
345
                f"The value {connection.result} is of type {type(connection.result)}. The"
346
                f" only supported type is list. Please change {connection}"
347
            ) from err
348

349

350
@dataclasses.dataclass
4✔
351
class FunctionFuture(NodeBaseMixin):
4✔
352
    function: typing.Callable
4✔
353
    args: typing.Tuple
4✔
354
    kwargs: typing.Dict
4✔
355
    item: any = None
4✔
356

357
    result: any = dataclasses.field(default=None, init=False, repr=True)
4✔
358

359
    _protected_ = NodeBaseMixin._protected_ + ["function", "args", "kwargs"]
4✔
360

361
    def run(self):
4✔
362
        self.result = self.function(*self.args, **self.kwargs)
4✔
363

364
    def __getitem__(self, item):
4✔
365
        return Connection(instance=self, attribute=None, item=item)
4✔
366

367
    def __add__(
4✔
368
        self, other: typing.Union[Connection, FunctionFuture, CombinedConnections]
369
    ) -> CombinedConnections:
370
        if isinstance(other, (Connection, FunctionFuture, CombinedConnections)):
4✔
371
            return CombinedConnections(connections=[self, other])
4✔
372
        raise TypeError(f"Can not add {type(other)} to {type(self)}.")
4✔
373

374
    def __radd__(self, other):
4✔
375
        """Enable 'sum([a, b], [])'"""
376
        return self if other == [] else self.__add__(other)
4✔
377

378
    def __eq__(self, other) -> bool:
4✔
379
        """Overwrite for dynamic break points."""
380
        from znflow import resolve, get_graph, empty_graph
4✔
381

382
        if isinstance(other, (Connection)):
4✔
NEW
383
            return False
×
384
        if isinstance(other, (FunctionFuture)):
4✔
385
            return (
4✔
386
                self.function == other.function
387
                and self.args == other.args
388
                and self.kwargs == other.kwargs
389
                and self.item == other.item
390
            )
391

392
        if get_graph() is empty_graph:
4✔
393
            return super().__eq__(other)
4✔
394
        return resolve(self).__eq__(other)
4✔
395

396
    def __lt__(self, other) -> bool:
4✔
397
        """Overwrite for dynamic break points."""
398
        from znflow import resolve, get_graph, empty_graph
4✔
399

400
        if get_graph() is empty_graph:
4✔
NEW
401
            return super().__lt__(other)
×
402
        return resolve(self).__lt__(other)
4✔
403

404
    def __le__(self, other) -> bool:
4✔
405
        """Overwrite for dynamic break points."""
406
        from znflow import resolve, get_graph, empty_graph
4✔
407

408
        if get_graph() is empty_graph:
4✔
NEW
409
            return super().__le__(other)
×
410
        return resolve(self).__le__(other)
4✔
411

412
    def __gt__(self, other) -> bool:
4✔
413
        """Overwrite for dynamic break points."""
414
        from znflow import resolve, get_graph, empty_graph
4✔
415

416
        if get_graph() is empty_graph:
4✔
NEW
417
            return super().__gt__(other)
×
418
        return resolve(self).__gt__(other)
4✔
419

420
    def __ge__(self, other) -> bool:
4✔
421
        """Overwrite for dynamic break points."""
422
        from znflow import resolve, get_graph, empty_graph
4✔
423

424
        if get_graph() is empty_graph:
4✔
NEW
425
            return super().__ge__(other)
×
426
        return resolve(self).__ge__(other)
4✔
427

428
    def __iter__(self):
4✔
429
        from znflow import resolve
4✔
430

431
        try:
4✔
432
            return resolve(self).__iter__()
4✔
433
        except AttributeError:
4✔
434
            raise TypeError(f"'{self}' object is not iterable")
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc