• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnFlow / 13105848089

16 Dec 2024 09:21AM UTC coverage: 96.786%. Remained the same
13105848089

push

github

web-flow
raise error on `append` and `extend` connections (#128)

* add __len__

* raise error on append/extend

* update functionfuture as well

* test connections

* assert type

49 of 51 new or added lines in 2 files covered. (96.08%)

10 existing lines in 4 files now uncovered.

2710 of 2800 relevant lines covered (96.79%)

3.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.2
/znflow/base.py
1
from __future__ import annotations
4✔
2

3
import contextlib
4✔
4
import dataclasses
4✔
5
import typing
4✔
6
from typing import Any
4✔
7
from uuid import UUID
4✔
8

9
from znflow import exceptions
4✔
10

11
if typing.TYPE_CHECKING:
4✔
12
    from znflow.graph import DiGraph
×
13

14

15
@contextlib.contextmanager
def disable_graph(*args, **kwargs):
    """Temporarily make ``empty_graph`` the active graph.

    The graph that was active on entry is put back on exit, also when the
    body raises. This can be useful, if you e.g. want to use 'get_attribute'.

    Any positional or keyword arguments are accepted and ignored.
    """
    previous_graph = get_graph()
    set_graph(empty_graph)
    try:
        yield
    finally:
        # Always restore, so an exception in the body cannot leak the empty graph.
        set_graph(previous_graph)
27

28

29
class Property:
    """Custom Property with disabled graph.

    Behaves like the builtin ``property`` descriptor, except that every
    accessor (getter, setter, deleter) is wrapped with ``disable_graph`` and
    therefore runs while ``empty_graph`` is the active graph.

    References
    ----------
    Adapted from https://docs.python.org/3/howto/descriptor.html#properties
    """

    def __init__(self, fget=None, fset=None, fdel=None, doc=None):
        # Missing accessors must stay ``None``. Wrapping ``None`` with
        # ``disable_graph()`` yields a (non-None) wrapper function, which made
        # the "has no getter/setter/deleter" checks below unreachable and
        # surfaced as "'NoneType' object is not callable" instead.
        self.fget = self._wrap(fget)
        self.fset = self._wrap(fset)
        self.fdel = self._wrap(fdel)
        if doc is None and fget is not None:
            doc = fget.__doc__
        self.__doc__ = doc
        self._name = ""

    @staticmethod
    def _wrap(func):
        """Wrap ``func`` with ``disable_graph``; pass ``None`` through unchanged."""
        if func is None:
            return None
        return disable_graph()(func)

    def __set_name__(self, owner, name):
        """Remember the attribute name for error messages."""
        self._name = name

    def __get__(self, obj, objtype=None):
        """Return the property value, or the descriptor itself on class access.

        Raises
        ------
        AttributeError
            If no getter was provided.
        """
        if obj is None:
            return self
        if self.fget is None:
            raise AttributeError(f"property '{self._name}' has no getter")
        return self.fget(obj)

    def __set__(self, obj, value):
        """Call the setter.

        Raises
        ------
        AttributeError
            If no setter was provided.
        """
        if self.fset is None:
            raise AttributeError(f"property '{self._name}' has no setter")
        self.fset(obj, value)

    def __delete__(self, obj):
        """Call the deleter.

        Raises
        ------
        AttributeError
            If no deleter was provided.
        """
        if self.fdel is None:
            raise AttributeError(f"property '{self._name}' has no deleter")
        self.fdel(obj)

    # The copy methods below hand the already-wrapped accessors to a new
    # instance, which wraps them again; nesting ``disable_graph`` only saves
    # and restores the same graph twice, so this is harmless.

    def getter(self, fget):
        """Return a copy of this property with a different getter."""
        prop = type(self)(fget, self.fset, self.fdel, self.__doc__)
        prop._name = self._name
        return prop

    def setter(self, fset):
        """Return a copy of this property with a different setter."""
        prop = type(self)(self.fget, fset, self.fdel, self.__doc__)
        prop._name = self._name
        return prop

    def deleter(self, fdel):
        """Return a copy of this property with a different deleter."""
        prop = type(self)(self.fget, self.fset, fdel, self.__doc__)
        prop._name = self._name
        return prop
80

81

82
@dataclasses.dataclass(frozen=True)
class EmptyGraph:
    """Stateless, immutable marker type meaning "no graph is active".

    All instances compare equal. The module-level ``empty_graph`` instance is
    the default value of ``NodeBaseMixin._graph_``.
    """


# Shared marker instance used wherever "no active graph" must be expressed.
empty_graph: EmptyGraph = EmptyGraph()
88

89

90
class NodeBaseMixin:
    """A Parent for all Nodes.

    Class-level attributes defined here are shared by every subclass, which
    is how ``get_graph`` / ``set_graph`` read and swap the active graph for
    all nodes at once.

    Attributes
    ----------
        _graph_ : DiGraph
            The graph this node belongs to.
            This is only available within the graph context.
        uuid : UUID
            The unique identifier of this node.
        _external_ : bool
            If true, the node is allowed to be created outside of a graph context.
            In this case connections can be created to this node, otherwise
            an exception is raised.
    """

    # Shared defaults; ``set_graph`` rebinds ``_graph_`` on this class.
    _graph_ = empty_graph
    _external_ = False
    _uuid: UUID = None
    _znflow_resolved: bool = False

    # Attribute names given special treatment by znflow (the last two are
    # pydantic model attributes). NOTE(review): consumers live outside this
    # class; confirm the exact semantics there.
    _protected_ = [
        "_graph_",
        "uuid",
        "_uuid",
        "model_fields",  # pydantic
        "model_computed_fields",  # pydantic
    ]

    @property
    def uuid(self):
        """Unique identifier of the node; ``None`` until it is assigned."""
        return self._uuid

    @uuid.setter
    def uuid(self, value):
        # Write-once: the first assignment wins, later ones are rejected.
        if self._uuid is None:
            self._uuid = value
            return
        raise ValueError("uuid is already set")

    def run(self):
        """Execute the node. Concrete node types must override this."""
        raise NotImplementedError
133

134

135
def get_graph() -> typing.Union[DiGraph, EmptyGraph]:
    """Return the currently active graph.

    Returns
    -------
    DiGraph | EmptyGraph
        The graph most recently installed via ``set_graph``. Outside of a graph
        context this is ``empty_graph``, the default of ``NodeBaseMixin._graph_``.
    """
    return NodeBaseMixin._graph_
137

138

139
def set_graph(value):
    """Install ``value`` as the active graph for every node.

    The value is stored as the class attribute ``NodeBaseMixin._graph_``, so it
    is visible through all subclasses and through instances that do not
    define their own ``_graph_``.
    """
    setattr(NodeBaseMixin, "_graph_", value)
141

142

143
# Private sentinel so that ``None`` remains a valid explicit ``default``.
_get_attribute_none = object()


def get_attribute(obj, name, default=_get_attribute_none):
    """Get the real value of the attribute and not a znflow.Connection.

    The lookup runs inside ``disable_graph``, i.e. while ``empty_graph`` is
    the active graph.

    Parameters
    ----------
    obj : object
        Object to read from.
    name : str
        Attribute name.
    default : object, optional
        Returned when the attribute does not exist. If omitted, a missing
        attribute raises.

    Raises
    ------
    AttributeError
        If the attribute is missing and no ``default`` was given.
    """
    with disable_graph():
        if default is not _get_attribute_none:
            return getattr(obj, name, default)
        return getattr(obj, name)
152

153

154
@dataclasses.dataclass(frozen=True)
class Connection:
    """A Connector for Nodes.

    A lazy reference to (a part of) the value of a node attribute or of a
    ``FunctionFuture``; the value is read through :attr:`result`.

    Attributes
    ----------
        instance: Node|FunctionFuture
            the object this connection points to
        attribute: str
            Node.attribute
            or FunctionFuture.result
            or None if the class is passed and not an attribute
        item: any
            any slice or list index to be applied to the result
    """

    instance: Any = dataclasses.field(repr=False)
    attribute: str
    item: Any = None

    def __post_init__(self):
        """Reject connections to private (underscore-prefixed) attributes.

        Raises
        ------
        ValueError
            If ``attribute`` starts with ``"_"``.
        """
        if self.attribute is not None and self.attribute.startswith("_"):
            raise ValueError("Private attributes are not allowed.")

    def __getitem__(self, item):
        """Return a new Connection that applies ``item`` to this one's result."""
        return dataclasses.replace(self, instance=self, attribute=None, item=item)

    def __add__(
        self, other: typing.Union[Connection, FunctionFuture, CombinedConnections]
    ) -> CombinedConnections:
        """Combine with another connection-like object.

        Raises
        ------
        TypeError
            If ``other`` is not a Connection, FunctionFuture or
            CombinedConnections.
        """
        if isinstance(other, (Connection, FunctionFuture, CombinedConnections)):
            return CombinedConnections(connections=[self, other])
        raise TypeError(f"Can not add {type(other)} to {type(self)}.")

    def __radd__(self, other):
        """Enable 'sum([a, b], [])'"""
        return self if other == [] else self.__add__(other)

    @property
    def uuid(self):
        """UUID of the referenced instance."""
        return self.instance.uuid

    @property
    def _external_(self):
        """``_external_`` flag of the referenced instance."""
        return self.instance._external_

    @property
    def result(self):
        """Value of the referenced object, with ``item`` applied if set.

        ``item`` is compared against ``None`` rather than tested for
        truthiness: index ``0`` must select the first element, and array-like
        indices have no unambiguous truth value.
        """
        if self.attribute:
            result = getattr(self.instance, self.attribute)
        elif isinstance(self.instance, (FunctionFuture, self.__class__)):
            result = self.instance.result
        else:
            result = self.instance
        return result if self.item is None else result[self.item]

    def __getattribute__(self, __name: str) -> Any:
        # Unknown attributes are reported as ConnectionAttributeError, because
        # attribute access on the (not yet available) result is not supported.
        try:
            return super().__getattribute__(__name)
        except AttributeError as e:
            raise exceptions.ConnectionAttributeError(
                "Connection does not support further attributes to its result."
            ) from e

    def __eq__(self, other) -> bool:
        """Overwrite for dynamic break points.

        NOTE(review): two Connections compare equal when they point at the same
        ``instance``, regardless of ``attribute``/``item``, while the hash that
        ``dataclass(frozen=True)`` generates covers all fields. Equal
        connections may therefore hash differently; confirm this is intended.
        """
        from znflow import resolve, get_graph, empty_graph

        if isinstance(other, (Connection)):
            return self.instance == other.instance
        if isinstance(other, (FunctionFuture)):
            return False

        # Outside a graph, fall back to the default comparison; inside a graph
        # the connection is resolved to its value first.
        if get_graph() is empty_graph:
            return super().__eq__(other)
        return resolve(self).__eq__(other)

    def __lt__(self, other) -> bool:
        """Overwrite for dynamic break points."""
        from znflow import resolve, get_graph, empty_graph

        if get_graph() is empty_graph:
            return super().__lt__(other)
        return resolve(self).__lt__(other)

    def __le__(self, other) -> bool:
        """Overwrite for dynamic break points."""
        from znflow import resolve, get_graph, empty_graph

        if get_graph() is empty_graph:
            return super().__le__(other)
        return resolve(self).__le__(other)

    def __gt__(self, other) -> bool:
        """Overwrite for dynamic break points."""
        from znflow import resolve, get_graph, empty_graph

        if get_graph() is empty_graph:
            return super().__gt__(other)
        return resolve(self).__gt__(other)

    def __ge__(self, other) -> bool:
        """Overwrite for dynamic break points."""
        from znflow import resolve, get_graph, empty_graph

        if get_graph() is empty_graph:
            return super().__ge__(other)
        return resolve(self).__ge__(other)

    def __iter__(self):
        """Iterate over the resolved value.

        Raises
        ------
        TypeError
            If the resolved value has no ``__iter__``.
        """
        from znflow import resolve

        try:
            return resolve(self).__iter__()
        except AttributeError:
            raise TypeError(f"'{self}' object is not iterable")

    def extend(self, *args) -> None:
        """
        Raises
        ------
        TypeError
            If the method is called.
        """
        raise TypeError("Connections can not be extended. Use 'self += other' instead.")

    def append(self, *args) -> None:
        """
        Raises
        ------
        TypeError
            If the method is called.
        """
        raise TypeError("Connections can not be appended.")
288

289

290
@dataclasses.dataclass(frozen=True)
class CombinedConnections:
    """Combine multiple Connections into one.

    This class allows to 'add' Connections and/or FunctionFutures.
    This only works if the Connection or FunctionFuture points to a 'list'.
    A new entry of 'CombinedConnections' will be created for every time a new
    item is added.

    Examples
    --------

    >>> import znflow
    >>> @znflow.nodify
    ... def add(size) -> list:
    ...     return list(range(size))
    >>> with znflow.DiGraph() as graph:
    ...     outs = add(2) + add(3)
    >>> graph.run()
    >>> assert outs.result == [0, 1, 0, 1, 2]

    Attributes
    ----------
    connections : list[Connection|FunctionFuture|CombinedConnections]
        The List of items to be added.
    item : any
        Any slice to be applied to the result.
    """

    connections: typing.List[
        typing.Union[Connection, FunctionFuture, CombinedConnections]
    ]
    item: Any = None

    def __add__(
        self, other: typing.Union[Connection, FunctionFuture, CombinedConnections]
    ) -> CombinedConnections:
        """Implement add for CombinedConnections.

        Raises
        ------
        ValueError
            If  self.item is set, we can not add another item.
        TypeError
            If other is not a Connection, FunctionFuture or CombinedConnections.
        """
        if self.item is not None:
            raise ValueError("Can not combine multiple slices")
        if isinstance(other, CombinedConnections):
            if other.item is None:
                return dataclasses.replace(
                    self, connections=self.connections + other.connections
                )
            # Flattening would silently drop ``other.item``; keep the sliced
            # combination as a single entry so its slice is still applied.
            return dataclasses.replace(self, connections=self.connections + [other])
        if isinstance(other, (Connection, FunctionFuture)):
            return dataclasses.replace(self, connections=self.connections + [other])
        raise TypeError(f"Can not add {type(other)} to {type(self)}.")

    def __radd__(self, other):
        """Enable 'sum([a, b], [])'"""
        return self if other == [] else self.__add__(other)

    def __getitem__(self, item):
        """Return a copy that applies ``item`` to the combined result."""
        return dataclasses.replace(self, item=item)

    def __len__(self) -> int:
        """Number of combined entries (not the length of the result)."""
        return len(self.connections)

    def __iter__(self):
        """
        Raises
        ------
        TypeError
            Always; iterate over ``result`` instead.
        """
        raise TypeError(f"Can not iterate over {self}.")

    @property
    def result(self):
        """Concatenation of all entries' results, with ``item`` applied if set.

        Raises
        ------
        TypeError
            If an entry's result cannot be used to extend a list.
        """
        results = []
        for connection in self.connections:
            value = connection.result
            # Only the extend is guarded, so the error names the offending
            # entry and errors from ``result`` or indexing are not misreported.
            try:
                results.extend(value)
            except TypeError as err:
                raise TypeError(
                    f"The value {value} is of type {type(value)}. The"
                    f" only supported type is list. Please change {connection}"
                ) from err
        # ``None`` check (not truthiness) so that index 0 selects an element.
        return results if self.item is None else results[self.item]

    def extend(self, *args) -> None:
        """
        Raises
        ------
        TypeError
            If the method is called.
        """
        raise TypeError(
            "CombinedConnections can not be extended. Use 'self += other' instead."
        )

    def append(self, *args) -> None:
        """
        Raises
        ------
        TypeError
            If the method is called.
        """
        raise TypeError("CombinedConnections can not be appended.")
390

391

392
@dataclasses.dataclass
class FunctionFuture(NodeBaseMixin):
    """Deferred call ``function(*args, **kwargs)``.

    :meth:`run` performs the call and stores the return value in ``result``.
    Indexing (``future[i]``) creates a :class:`Connection` to part of the
    result, and ``+`` combines it with other connection-like objects into a
    :class:`CombinedConnections`.

    Attributes
    ----------
    function : Callable
        The function to call.
    args : tuple
        Positional arguments passed to ``function``.
    kwargs : dict
        Keyword arguments passed to ``function``.
    item : Any
        Compared in ``__eq__``; not otherwise used by this class.
    result : Any
        Return value of ``function``; ``None`` until :meth:`run` is called.
    """

    function: typing.Callable
    args: typing.Tuple
    kwargs: typing.Dict
    item: Any = None

    result: Any = dataclasses.field(default=None, init=False, repr=True)

    _protected_ = NodeBaseMixin._protected_ + ["function", "args", "kwargs"]

    def run(self):
        """Call ``function`` with the stored arguments and store the result."""
        self.result = self.function(*self.args, **self.kwargs)

    def __getitem__(self, item):
        """Return a Connection to ``self.result[item]``."""
        return Connection(instance=self, attribute=None, item=item)

    def __add__(
        self, other: typing.Union[Connection, FunctionFuture, CombinedConnections]
    ) -> CombinedConnections:
        """Combine with another connection-like object.

        Raises
        ------
        TypeError
            If ``other`` is not a Connection, FunctionFuture or
            CombinedConnections.
        """
        if isinstance(other, (Connection, FunctionFuture, CombinedConnections)):
            return CombinedConnections(connections=[self, other])
        raise TypeError(f"Can not add {type(other)} to {type(self)}.")

    def __radd__(self, other):
        """Enable 'sum([a, b], [])'"""
        return self if other == [] else self.__add__(other)

    def __eq__(self, other) -> bool:
        """Overwrite for dynamic break points."""
        from znflow import resolve, get_graph, empty_graph

        # Never equal to a Connection (mirrors Connection.__eq__).
        if isinstance(other, (Connection)):
            return False
        if isinstance(other, (FunctionFuture)):
            return (
                self.function == other.function
                and self.args == other.args
                and self.kwargs == other.kwargs
                and self.item == other.item
            )

        # Outside a graph, use the default comparison; inside a graph the
        # future is resolved to its value first.
        if get_graph() is empty_graph:
            return super().__eq__(other)
        return resolve(self).__eq__(other)

    def __lt__(self, other) -> bool:
        """Overwrite for dynamic break points."""
        from znflow import resolve, get_graph, empty_graph

        if get_graph() is empty_graph:
            return super().__lt__(other)
        return resolve(self).__lt__(other)

    def __le__(self, other) -> bool:
        """Overwrite for dynamic break points."""
        from znflow import resolve, get_graph, empty_graph

        if get_graph() is empty_graph:
            return super().__le__(other)
        return resolve(self).__le__(other)

    def __gt__(self, other) -> bool:
        """Overwrite for dynamic break points."""
        from znflow import resolve, get_graph, empty_graph

        if get_graph() is empty_graph:
            return super().__gt__(other)
        return resolve(self).__gt__(other)

    def __ge__(self, other) -> bool:
        """Overwrite for dynamic break points."""
        from znflow import resolve, get_graph, empty_graph

        if get_graph() is empty_graph:
            return super().__ge__(other)
        return resolve(self).__ge__(other)

    def __iter__(self):
        """Iterate over the resolved value.

        Raises
        ------
        TypeError
            If the resolved value has no ``__iter__``.
        """
        from znflow import resolve

        try:
            return resolve(self).__iter__()
        except AttributeError:
            raise TypeError(f"'{self}' object is not iterable")

    def extend(self, *args) -> None:
        """
        Raises
        ------
        TypeError
            If the method is called.
        """
        raise TypeError(
            "FunctionFuture can not be extended. Use 'self += other' instead."
        )

    def append(self, *args) -> None:
        """
        Raises
        ------
        TypeError
            If the method is called.
        """
        raise TypeError("FunctionFuture can not be appended.")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc