• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

apowers313 / roc / 10580582352

26 Aug 2024 02:45PM UTC coverage: 93.116% (-0.7%) from 93.864%
10580582352

push

github

apowers313
fix weakref bug in fe_list

6 of 6 new or added lines in 1 file covered. (100.0%)

37 existing lines in 5 files now uncovered.

1853 of 1990 relevant lines covered (93.12%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.45
/roc/graphdb.py
1
"""This module is a wrapper around a graph database and abstracts away all the
1✔
2
database-specific features as various classes (GraphDB, Node, Edge, etc)
3
"""
4

5
from __future__ import annotations
1✔
6

7
import warnings
1✔
8
from collections.abc import Iterator, Mapping, MutableSet
1✔
9
from itertools import islice
1✔
10
from typing import Any, Callable, Generic, Literal, NewType, TypeVar, cast
1✔
11

12
import mgclient
1✔
13
import networkx as nx
1✔
14
from cachetools import LRUCache
1✔
15
from pydantic import BaseModel, Field, field_validator
1✔
16
from typing_extensions import Self
1✔
17

18
from .config import Config
1✔
19
from .logger import logger
1✔
20

21
RecordFn = Callable[[str, Iterator[Any]], None]
1✔
22
CacheType = TypeVar("CacheType")
1✔
23
CacheId = TypeVar("CacheId")
1✔
24
EdgeId = NewType("EdgeId", int)
1✔
25
NodeId = NewType("NodeId", int)
1✔
26
next_new_edge: EdgeId = cast(EdgeId, -1)
1✔
27
next_new_node: NodeId = cast(NodeId, -1)
1✔
28

29

30
def true_filter(_: Any) -> bool:
    """Accept any value and unconditionally return True.

    Handy as the default predicate wherever a filter function is optional.
    """
    return True
1✔
35

36

37
def no_callback(_: Any) -> None:
    """Accept any value, do nothing with it, and return None.

    Handy as the default wherever a callback function is optional.
    """
    return None
×
42

43

44
class ErrorSavingDuringDelWarning(Warning):
    """Warning category for an error raised while saving a Node during __del__."""
1✔
48

49

50
class GraphDBInternalError(Exception):
    """A generic exception for unexpected errors."""
1✔
54

55

56
#########
57
# GRAPHDB
58
#########
59
graph_db_singleton: GraphDB | None = None
1✔
60

61

62
class GraphDB:
    """A graph database singleton. Settings for the graph database come from the config module."""

    def __init__(self) -> None:
        """Read the connection settings from Config and open the connection."""
        settings = Config.get()
        self.host = settings.db_host
        self.port = settings.db_port
        self.encrypted = settings.db_conn_encrypted
        self.username = settings.db_username
        self.password = settings.db_password
        self.lazy = settings.db_lazy
        self.client_name = "roc-graphdb-client"
        self.db_conn = self.connect()
        self.closed = False

    def raw_fetch(
        self, query: str, *, params: dict[str, Any] | None = None
    ) -> Iterator[dict[str, Any]]:
        """Executes a Cypher query and returns the results as an iterator of
        dictionaries. Used for any query that has a 'RETURN' clause.

        Args:
            query (str): The Cypher query to execute
            params (dict[str, Any] | None, optional): Any parameters to pass to
                the query. Defaults to None. See also: https://memgraph.com/docs/querying/expressions#parameters

        Yields:
            Iterator[dict[str, Any]]: An iterator of the results from the database.
        """
        params = params or {}
        logger.trace(f"raw_fetch: '{query}' *** with params: *** '{params}'")

        cursor = self.db_conn.cursor()
        cursor.execute(query, params)
        while True:
            row = cursor.fetchone()
            if row is None:
                break
            # key each value by the column name reported in the cursor description
            yield {dsc.name: row[index] for index, dsc in enumerate(cursor.description)}

    def raw_execute(self, query: str, *, params: dict[str, Any] | None = None) -> None:
        """Executes a query with no return value. Used for 'SET', 'DELETE' or
        other queries without a 'RETURN' clause.

        Args:
            query (str): The Cypher query to execute
            params (dict[str, Any] | None, optional): Any parameters to pass to
                the query. Defaults to None. See also: https://memgraph.com/docs/querying/expressions#parameters
        """
        params = params or {}
        logger.trace(f"raw_execute: '{query}' *** with params: *** '{params}'")

        cursor = self.db_conn.cursor()
        cursor.execute(query, params)
        # the result set is fetched and discarded
        cursor.fetchall()

    def connected(self) -> bool:
        """Returns True if the database is connected, False otherwise"""
        return self.db_conn is not None and self.db_conn.status == mgclient.CONN_STATUS_READY

    def connect(self) -> mgclient.Connection:
        """Connects to the database and returns a Connection object.

        The connection is opened with the settings stored on this object and
        has autocommit enabled.
        """
        sslmode = mgclient.MG_SSLMODE_REQUIRE if self.encrypted else mgclient.MG_SSLMODE_DISABLE
        connection = mgclient.connect(
            host=self.host,
            port=self.port,
            username=self.username,
            password=self.password,
            sslmode=sslmode,
            lazy=self.lazy,
            client_name=self.client_name,
        )
        connection.autocommit = True
        return connection

    def close(self) -> None:
        """Closes the connection to the database and marks this object as closed"""
        self.db_conn.close()
        self.closed = True

    @classmethod
    def singleton(cls) -> GraphDB:
        """This returns a singleton object for the graph database. If the
        singleton isn't created yet, it creates it.

        Raises:
            AssertionError: if the existing singleton has already been closed
        """
        global graph_db_singleton
        if not graph_db_singleton:
            graph_db_singleton = GraphDB()

        assert graph_db_singleton.closed is False
        return graph_db_singleton

    @staticmethod
    def to_networkx(
        db: GraphDB | None = None,
        node_ids: set[NodeId] | None = None,
        filter: NodeFilterFn | None = None,
    ) -> nx.DiGraph:
        """Converts the entire graph database (and local cache of objects) into
        a NetworkX graph

        Args:
            db (GraphDB | None, optional): The database to convert to NetworkX.
                Defaults to the GraphDB singleton if not specified.
            node_ids (set[NodeId] | None, optional): The NodeIDs to add to the
                NetworkX graph. Defaults to all IDs if not specified.
            filter (NodeFilterFn | None, optional): A Node filter to filter out
                nodes before adding them to the NetworkX graph. Also useful for a
                callback that can be used for progress updates. Defaults to None.

        Returns:
            nx.DiGraph: A directed graph with one NetworkX node per accepted
            Node (keyed by node id) and one NetworkX edge per source edge of
            those nodes.
        """
        db = db or GraphDB.singleton()
        node_ids = node_ids or Node.all_ids(db=db)
        node_filter = filter or true_filter
        G = nx.DiGraph()

        def nx_add(n: Node) -> None:
            n_data = Node.to_dict(n, include_labels=True)

            # TODO: this converts labels to a string, but maybe there's a better
            # way to preserve the list so that it can be used for filtering in
            # external programs
            if "labels" in n_data and isinstance(n_data["labels"], set):
                n_data["labels"] = ", ".join(n_data["labels"])

            G.add_node(n.id, **n_data)

            for e in n.src_edges:
                e_data = Edge.to_dict(e, include_type=True)
                G.add_edge(e.src_id, e.dst_id, **e_data)

        # iterate all specified node_ids, adding all of them to the nx graph
        def nx_add_many(nodes: list[Node]) -> None:
            for n in nodes:
                if node_filter(n):
                    nx_add(n)

        # forward the resolved db so a non-singleton database is actually used
        Node.get_many(node_ids, db=db, load_edges=True, progress_callback=nx_add_many)

        return G
×
204

205

206
#######
207
# CACHE
208
#######
209
CacheKey = TypeVar("CacheKey")
1✔
210
CacheValue = TypeVar("CacheValue")
1✔
211
CacheDefault = TypeVar("CacheDefault")
1✔
212

213

214
class GraphCache(LRUCache[CacheKey, CacheValue], Generic[CacheKey, CacheValue]):
    """A generic cache that is used for both the Node cache and the Edge cache.

    On top of the LRU behavior it counts the hits and misses seen by `get`.
    """

    def __init__(self, maxsize: int):
        super().__init__(maxsize=maxsize)
        self.hits = 0
        self.misses = 0

    def __str__(self) -> str:
        return f"Size: {self.currsize}/{self.maxsize} ({self.currsize/self.maxsize*100:1.2f}%), Hits: {self.hits}, Misses: {self.misses}"

    def get(  # type: ignore [override]
        self,
        key: CacheKey,
        /,
        default: CacheValue | None = None,
    ) -> CacheValue | None:
        """Uses the specified CacheKey to fetch an object from the cache.

        Args:
            key (CacheKey): The key to use to fetch the object
            default (CacheValue | None, optional): If the object isn't found,
                the default value to return. Defaults to None.

        Returns:
            CacheValue | None: The object from the cache, or `default` if not found.
        """
        # a private sentinel tells "key absent" apart from a cached value that
        # merely happens to be falsy
        missing = object()
        v = super().get(key, missing)  # type: ignore
        if v is missing:
            self.misses = self.misses + 1
            if self.currsize == self.maxsize:
                logger.warning(
                    f"Cache miss and cache is full ({self.currsize}/{self.maxsize}). Cache may start thrashing and performance may be impaired."
                )
            return default

        self.hits = self.hits + 1
        return cast(CacheValue, v)

    def clear(self) -> None:
        """Clears out all items from the cache and resets the cache
        statistics
        """
        super().clear()
        self.hits = 0
        self.misses = 0
1✔
259

260

261
#######
262
# EDGE
263
#######
264
class EdgeNotFound(Exception):
    """An exception raised when an Edge ID can't be found."""
1✔
266

267

268
class EdgeCreateFailed(Exception):
    """An exception raised when creating a new Edge in the graph database fails."""
1✔
270

271

272
def get_next_new_edge_id() -> EdgeId:
    """Return the current temporary edge ID and decrement the module-level
    counter so that the next call hands out the next lower value.
    """
    global next_new_edge
    allocated = next_new_edge
    next_new_edge = cast(EdgeId, allocated - 1)
    return allocated
1✔
278

279

280
class Edge(BaseModel, extra="allow"):
    """An edge (a.k.a. Relationship or Connection) between two Nodes. An edge object automatically
    implements all phases of CRUD in the underlying graph database. This is a directional
    relationship with a "source" and "destination". The source and destination properties
    are dynamically loaded through property getters when they are called, and may trigger
    a graph database query if they don't already exist in the node cache.
    """

    id: EdgeId = Field(exclude=True)
    # XXX: type, src_id, and dst_id used to be pydantic literals, but updating
    # the pydantic version broke them
    type: str = Field(exclude=True)
    src_id: NodeId = Field(exclude=True)
    dst_id: NodeId = Field(exclude=True)
    # when True, create/update return without writing to the database
    _no_save = False
    # True while the edge only has a temporary (negative) ID
    _new = False
    # set by Edge.delete
    _deleted = False

    @field_validator("id", mode="before")
    def default_id(cls, id: EdgeId | None) -> EdgeId:
        """Keep an integer ID as-is; otherwise allocate the next temporary edge ID."""
        if isinstance(id, int):
            return id

        return get_next_new_edge_id()

    @property
    def src(self) -> Node:
        """The source Node, looked up through Node.get"""
        return Node.get(self.src_id)

    @property
    def dst(self) -> Node:
        """The destination Node, looked up through Node.get"""
        return Node.get(self.dst_id)

    @property
    def new(self) -> bool:
        """True if the edge has not been created in the database yet"""
        return self._new

    def __init__(
        self,
        src_id: NodeId,
        dst_id: NodeId,
        type: str,
        *,
        id: EdgeId | None = None,
        data: dict[Any, Any] | None = None,
    ):
        """Creates an Edge object.

        Args:
            src_id (NodeId): ID of the source node
            dst_id (NodeId): ID of the destination node
            type (str): the type of the edge
            id (EdgeId | None): the ID of the edge, or None to have a temporary
                (negative) ID assigned
            data (dict[Any, Any] | None): extra properties to store on the edge
        """
        data = data or {}
        super().__init__(
            src_id=src_id,
            dst_id=dst_id,
            type=type,
            id=id,
            **data,
        )

        # negative IDs mark edges that are not in the database yet; those are
        # cached immediately
        if self.id < 0:
            self._new = True
            Edge.get_cache()[self.id] = self

    def __del__(self) -> None:
        # Mirror Node.__del__: a failed save during garbage collection is
        # reported as a warning instead of an unraisable exception.
        try:
            Edge.save(self)
        except Exception as err:
            warnings.warn(f"error saving during del: {err}", ErrorSavingDuringDelWarning)

    def __repr__(self) -> str:
        return f"Edge({self.id} [{self.src_id}>>{self.dst_id}])"

    @classmethod
    def get_cache(cls) -> EdgeCache:
        """Returns the module-level edge cache, creating it on first use with the
        size from the config settings.
        """
        global edge_cache
        if edge_cache is None:
            settings = Config.get()
            edge_cache = EdgeCache(maxsize=settings.edge_cache_size)

        return edge_cache

    @classmethod
    def get(cls, id: EdgeId, *, db: GraphDB | None = None) -> Self:
        """Looks up an Edge based on its ID. If the Edge is cached, the cached edge is returned;
        otherwise the Edge is queried from the graph database based on the ID provided and a new
        Edge is returned and cached.

        Args:
            id (EdgeId): the unique identifier for the Edge
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: returns the Edge requested by the id
        """
        cache = Edge.get_cache()
        e = cache.get(id)
        if not e:
            e = cls.load(id, db=db)
            cache[id] = e

        return cast(Self, e)

    @classmethod
    def load(cls, id: EdgeId, *, db: GraphDB | None = None) -> Self:
        """Loads an Edge from the graph database without attempting to check if the Edge
        already exists in the cache. Typically this is only called by Edge.get()

        Args:
            id (EdgeId): the unique identifier of the Edge to fetch
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Raises:
            EdgeNotFound: if the specified ID does not exist in the database

        Returns:
            Self: returns the Edge requested by the id
        """
        db = db or GraphDB.singleton()
        edge_list = list(db.raw_fetch(f"MATCH (n)-[e]-(m) WHERE id(e) = {id} RETURN e LIMIT 1"))
        if len(edge_list) != 1:
            raise EdgeNotFound(f"Couldn't find edge ID: {id}")

        e = edge_list[0]["e"]
        props = None
        if hasattr(e, "properties"):
            props = e.properties
        return cls(
            e.start_id,
            e.end_id,
            id=id,
            data=props,
            type=e.type,
        )

    @classmethod
    def save(cls, e: Self, *, db: GraphDB | None = None) -> Self:
        """Saves the edge to the database. Calls Edge.create if the edge is new, or Edge.update if
        edge already exists in the database.

        Args:
            e (Self): The edge to save
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: The same edge that was passed in, for convenience. The Edge may be updated with a
            new identifier if it was newly created in the database.
        """
        if e._new:
            return cls.create(e, db=db)
        else:
            return cls.update(e, db=db)

    @classmethod
    def create(cls, e: Self, *, db: GraphDB | None = None) -> Self:
        """Creates a new edge in the database. Typically only called by Edge.save

        Args:
            e (Self): The edge to create
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Raises:
            EdgeCreateFailed: Failed to write the edge to the database, for example
                if the ID is wrong.

        Returns:
            Self: the edge that was created, with an updated identifier and other changed attributes
        """
        if e._no_save or e.src._no_save or e.dst._no_save:
            return e

        db = db or GraphDB.singleton()
        old_id = e.id

        # both endpoints must exist in the database before the edge can be
        # created; they are saved to the same db the edge is written to
        if e.src._new:
            Node.save(e.src, db=db)

        if e.dst._new:
            Node.save(e.dst, db=db)

        params = {"props": Edge.to_dict(e)}

        ret = list(
            db.raw_fetch(
                f"""
                MATCH (src), (dst)
                WHERE id(src) = {e.src_id} AND id(dst) = {e.dst_id}
                CREATE (src)-[e:{e.type} $props]->(dst)
                RETURN id(e) as e_id
                """,
                params=params,
            )
        )

        if len(ret) != 1:
            raise EdgeCreateFailed("failed to create new edge")

        e.id = ret[0]["e_id"]
        e._new = False
        # update the cache; if being called during __del__ then the cache entry may not exist
        try:
            cache = Edge.get_cache()
            del cache[old_id]
            cache[e.id] = e
        except KeyError:
            pass
        # update references to edge id
        e.src.src_edges.replace(old_id, e.id)
        e.dst.dst_edges.replace(old_id, e.id)

        return e

    @classmethod
    def update(cls, e: Self, *, db: GraphDB | None = None) -> Self:
        """Updates the edge in the database. Typically only called by Edge.save

        Args:
            e (Self): The edge to update
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: The same edge that was passed in, for convenience
        """
        if e._no_save:
            return e

        db = db or GraphDB.singleton()

        params = {"props": Edge.to_dict(e)}

        db.raw_execute(f"MATCH ()-[e]->() WHERE id(e) = {e.id} SET e = $props", params=params)

        return e

    @staticmethod
    def delete(e: Edge, *, db: GraphDB | None = None) -> None:
        """Deletes the specified edge from the database. If the edge has not already been persisted
        to the database, this marks the edge as deleted and returns.

        Args:
            e (Edge): The edge to delete
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
        """
        e._deleted = True
        e._no_save = True
        db = db or GraphDB.singleton()

        # remove e from src and dst nodes
        e.src.src_edges.discard(e)
        e.dst.dst_edges.discard(e)

        # remove from cache
        cache = Edge.get_cache()
        if e.id in cache:
            del cache[e.id]

        # delete from db
        if not e._new:
            db.raw_execute(f"MATCH ()-[e]->() WHERE id(e) = {e.id} DELETE e")

    @staticmethod
    def to_dict(e: Edge, include_type: bool = False) -> dict[str, Any]:
        """Convert an Edge to a Python dictionary, optionally adding its type under "type" """
        ret = e.model_dump()
        if include_type and hasattr(e, "type"):
            ret["type"] = e.type
        return ret

    @staticmethod
    def to_id(e: Edge | EdgeId) -> EdgeId:
        """Return the ID of an Edge; an ID passed in directly is returned unchanged"""
        if isinstance(e, Edge):
            return e.id
        else:
            return e
1✔
547

548

549
EdgeCache = GraphCache[EdgeId, Edge]
1✔
550
edge_cache: EdgeCache | None = None
1✔
551

552

553
#######
554
# EDGE LIST
555
#######
556
class EdgeFetchIterator:
    """Iterator used internally by EdgeList: walks a list of edge IDs and
    resolves each one to an Edge through Edge.get as it is reached.
    """

    def __init__(self, edge_list: list[EdgeId]):
        self.__edge_list = edge_list
        # position of the next ID to resolve
        self.cur = 0

    def __iter__(self) -> EdgeFetchIterator:
        return self

    def __next__(self) -> Edge:
        position = self.cur
        if not position < len(self.__edge_list):
            raise StopIteration

        # advance before resolving, so the position moves on even if the
        # lookup raises
        self.cur = position + 1
        return Edge.get(self.__edge_list[position])
1✔
575

576

577
EdgeFilter = Callable[[Edge], bool] | str | EdgeId | None
1✔
578

579

580
class EdgeList(MutableSet[Edge | EdgeId], Mapping[int, Edge]):
    """A list of Edges that is used by Node for keeping track of the connections it has.
    Implements interfaces for both a MutableSet (i.e. set()) and a Mapping (i.e. read-only list()).

    Only edge IDs are stored; Edge objects are resolved through Edge.get when
    the list is iterated or indexed.
    """

    def __init__(self, ids: list[EdgeId] | set[EdgeId]):
        self.__edges: list[EdgeId] = list(ids)

    def __iter__(self) -> EdgeFetchIterator:
        return EdgeFetchIterator(self.__edges)

    def __getitem__(self, key: int) -> Edge:
        return Edge.get(self.__edges[key])

    def __len__(self) -> int:
        return len(self.__edges)

    def __contains__(self, e: Any) -> bool:
        if not isinstance(e, (Edge, int)):
            return False

        return Edge.to_id(e) in self.__edges  # type: ignore

    def add(self, e: Edge | EdgeId) -> None:
        """Adds a new Edge to the list; an edge that is already present is not added twice"""
        e_id = Edge.to_id(e)

        if e_id in self.__edges:
            return

        self.__edges.append(e_id)

    def discard(self, e: Edge | EdgeId) -> None:
        """Removes an edge from the list. Does nothing if the edge is not in
        the list, as the MutableSet.discard contract requires.
        """
        e_id = Edge.to_id(e)

        try:
            self.__edges.remove(e_id)
        except ValueError:
            # not present: discard must not raise
            pass

    def replace(self, old: Edge | EdgeId, new: Edge | EdgeId) -> None:
        """Replaces all instances of an old Edge with a new Edge. Useful for when an Edge is
        persisted to the graph database and its permanent ID is assigned
        """
        old_id = Edge.to_id(old)
        new_id = Edge.to_id(new)
        for i, current in enumerate(self.__edges):
            if current == old_id:
                self.__edges[i] = new_id

    def count(self, f: EdgeFilter = None) -> int:
        """Returns the number of edges matching the filter (see get_edges)"""
        return len(self.get_edges(f))

    def get_edges(self, f: EdgeFilter = None) -> list[Edge]:
        """Returns the edges in the list that match a filter.

        Args:
            f (EdgeFilter): None to return every edge, a str to match the edge
                type, an int to match the edge ID, or a callable predicate
                taking an Edge.

        Returns:
            list[Edge]: the matching edges
        """
        # compare against None explicitly: edge ID 0 and the empty string are
        # falsy but are still real filters
        if f is None:
            return list(self.__iter__())

        if isinstance(f, str):
            wanted_type = f

            def predicate(e: Edge) -> bool:
                return e.type == wanted_type

        elif isinstance(f, int):
            wanted_id = f

            def predicate(e: Edge) -> bool:
                return e.id == wanted_id

        else:
            predicate = f

        return [e for e in self.__iter__() if predicate(e)]
1✔
646

647

648
#######
649
# NODE
650
#######
651
class NodeNotFound(Exception):
    """An exception raised when trying to retrieve a Node that doesn't exist."""
1✔
655

656

657
class NodeCreationFailed(Exception):
    """An exception raised when creating a Node in the graph database fails."""
1✔
661

662

663
def get_next_new_node_id() -> NodeId:
    """Return the current temporary node ID and decrement the module-level
    counter so that the next call hands out the next lower value.
    """
    global next_new_node
    allocated = next_new_node
    next_new_node = cast(NodeId, allocated - 1)
    return allocated
1✔
668

669

670
class Node(BaseModel, extra="allow"):
1✔
671
    """An graph database node that automatically handles CRUD for the underlying graph database objects"""
1✔
672

673
    _id: NodeId
1✔
674
    labels: set[str] = Field(exclude=True, default_factory=lambda: set())
1✔
675
    _orig_labels: set[str]
1✔
676
    _src_edges: EdgeList
1✔
677
    _dst_edges: EdgeList
1✔
678
    _db: GraphDB
1✔
679
    _new = False
1✔
680
    _no_save = False
1✔
681
    _deleted = False
1✔
682

683
    @property
1✔
684
    def id(self) -> NodeId:
1✔
685
        return self._id
1✔
686

687
    @property
1✔
688
    def src_edges(self) -> EdgeList:
1✔
689
        return self._src_edges
1✔
690

691
    @property
1✔
692
    def dst_edges(self) -> EdgeList:
1✔
693
        return self._dst_edges
1✔
694

695
    @property
1✔
696
    def new(self) -> bool:
1✔
697
        return self._new
1✔
698

699
    def __init__(
1✔
700
        self,
701
        **kwargs: Any,
702
    ):
703
        super().__init__(**kwargs)
1✔
704

705
        # set passed-in private values or their defaults
706
        self._db = kwargs["_db"] if "_db" in kwargs else GraphDB.singleton()
1✔
707
        self._id = kwargs["_id"] if "_id" in kwargs else get_next_new_node_id()
1✔
708
        self._src_edges = kwargs["_src_edges"] if "_src_edges" in kwargs else EdgeList([])
1✔
709
        self._dst_edges = kwargs["_dst_edges"] if "_dst_edges" in kwargs else EdgeList([])
1✔
710

711
        if self.id < 0:
1✔
712
            self._new = True  # TODO: derived?
1✔
713
            Node.get_cache()[self.id] = self
1✔
714

715
        self._orig_labels = self.labels.copy()
1✔
716

717
    def __del__(self) -> None:
1✔
718
        # print("Node.__del__:", self)
719
        try:
1✔
720
            self.__class__.save(self, db=self._db)
1✔
721
        except Exception as e:
1✔
722
            err_msg = f"error saving during del: {e}"
1✔
723
            # logger.warning(err_msg)
724
            warnings.warn(err_msg, ErrorSavingDuringDelWarning)
1✔
725

726
    def __repr__(self) -> str:
1✔
727
        return f"Node({self.id})"
1✔
728

729
    def __str__(self) -> str:
1✔
UNCOV
730
        return f"Node({self.id}, labels={self.labels})"
×
731

732
    @classmethod
1✔
733
    def load(cls, id: NodeId, *, db: GraphDB | None = None) -> Self:
1✔
734
        """Loads a node from the database. Use `Node.get` or other methods instead.
735

736
        Args:
737
            id (NodeId): The identifier of the node to fetch
738
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
739

740
        Raises:
741
            NodeNotFound: The node specified by the identifier does not exist in the database
742
            GraphDBInternalError: If the requested ID returns multiple nodes
743

744
        Returns:
745
            Self: The node from the database
746
        """
747
        res = cls.load_many(
1✔
748
            {
749
                id,
750
            },
751
            db=db,
752
        )
753

754
        # print("RES", res)
755

756
        if len(res) < 1:
1✔
757
            raise NodeNotFound(f"Couldn't find node ID: {id}")
×
758

759
        if len(res) > 1:
1✔
760
            raise GraphDBInternalError(
×
761
                f"Too many nodes returned while trying to load single node: {id}"
762
            )
763

764
        return res[0]
1✔
765

766
    @classmethod
    def load_many(
        cls, node_set: set[NodeId], db: GraphDB | None = None, load_edges: bool = False
    ) -> list[Self]:
        """Loads a set of nodes from the database with a single query, without
        consulting or updating the node cache.

        Args:
            node_set (set[NodeId]): The identifiers of the nodes to fetch
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
            load_edges (bool): If True, the full edges connected to each node
                are returned by the query and any that are not already in the
                edge cache are created and cached. If False, only the edge IDs
                and endpoints are fetched.

        Raises:
            NodeNotFound: If the query returns a row without a node

        Returns:
            list[Self]: One node per row returned by the query
        """
        db = db or GraphDB.singleton()
        node_ids = ",".join(map(str, node_set))

        if load_edges:
            edge_fmt = "e"
        else:
            edge_fmt = "{id: id(e), start: id(startNode(e)), end: id(endNode(e))}"
        res_iter = db.raw_fetch(
            f"""
                MATCH (n)-[e]-(m) WHERE id(n) IN [{node_ids}]
                RETURN n, collect({edge_fmt}) AS edges
                """,
        )

        ret_list = []
        for r in res_iter:
            n = r["n"]
            if n is None:
                # report the IDs that were requested
                raise NodeNotFound(f"Couldn't find node ID: {node_ids}")

            if load_edges:
                # XXX: memgraph converts edges to Relationship objects if you
                # return the whole edge
                src_edges = []
                dst_edges = []
                edge_cache = Edge.get_cache()
                for e in r["edges"]:
                    # add edge_id to the right list for the node creation below
                    if n.id == e.start_id:
                        src_edges.append(e.id)
                    else:
                        dst_edges.append(e.id)

                    # edge already loaded, continue to next one
                    if e.id in edge_cache:
                        continue

                    # create a new edge
                    props = None
                    if hasattr(e, "properties"):
                        props = e.properties
                    new_edge = Edge(
                        e.start_id,
                        e.end_id,
                        id=e.id,
                        data=props,
                        type=e.type,
                    )
                    edge_cache[e.id] = new_edge
            else:
                src_edges = [e["id"] for e in r["edges"] if e["start"] == n.id]
                dst_edges = [e["id"] for e in r["edges"] if e["end"] == n.id]
            new_node = cls(
                _id=n.id,
                _src_edges=EdgeList(src_edges),
                _dst_edges=EdgeList(dst_edges),
                labels=n.labels,
                **n.properties,
            )
            ret_list.append(new_node)

        return ret_list
1✔
838

839
    @classmethod
    def get_many(
        cls,
        node_ids: set[NodeId],
        *,
        batch_size: int = 128,
        db: GraphDB | None = None,
        load_edges: bool = False,
        return_nodes: bool = False,
        progress_callback: ProgressFn | None = None,
    ) -> list[Node]:
        """Returns the nodes for a set of IDs, using cached nodes where possible
        and loading the remainder from the database in batches.

        Args:
            node_ids (set[NodeId]): The identifiers of the nodes to get
            batch_size (int): The maximum number of nodes to load from the
                database per query. Must be at least 1.
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
            load_edges (bool): Passed through to Node.load_many
            return_nodes (bool): Not used by this method
            progress_callback (ProgressFn | None): If set, called once with the
                requested nodes found in the cache, and once per batch with the
                nodes loaded for that batch.

        Raises:
            ValueError: If batch_size is less than 1
            GraphDBInternalError: If more nodes are requested than fit in the node cache

        Returns:
            list[Node]: The requested nodes, cached ones first
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        db = db or GraphDB.singleton()

        c = Node.get_cache()
        if len(node_ids) > c.maxsize:
            raise GraphDBInternalError(
                f"get_many attempting to load more nodes than cache size ({len(node_ids)} > {c.maxsize})"
            )

        cache_ids = set(c.keys())
        # only the requested nodes are taken from the cache, not everything
        # the cache happens to hold
        ret_list = [c[nid] for nid in node_ids & cache_ids]
        if progress_callback:
            progress_callback(ret_list)

        # materialize once so each batch is a plain slice
        fetch_ids = list(node_ids - cache_ids)
        for start in range(0, len(fetch_ids), batch_size):
            id_set = set(fetch_ids[start : start + batch_size])

            res = cls.load_many(id_set, db=db, load_edges=load_edges)
            for n in res:
                c[n.id] = n

            if progress_callback:
                progress_callback(res)

            ret_list.extend(res)

        assert len(ret_list) == len(node_ids)
        return ret_list
1✔
886

887
    @classmethod
    def get_cache(cls) -> NodeCache:
        """Returns the module-wide node cache, creating it on first use.

        The cache is sized with the `node_cache_size` value of the settings returned by
        `Config.get()`.

        Returns:
            NodeCache: the shared node cache
        """
        global node_cache
        if node_cache is not None:
            return node_cache

        node_cache = NodeCache(Config.get().node_cache_size)
        return node_cache
895

896
    @classmethod
    def get(cls, id: NodeId, *, db: GraphDB | None = None) -> Self:
        """Looks a node up by id, preferring the node cache over the database.

        On a cache miss the node is loaded with `load` and stored in the cache before it is
        returned.

        Args:
            id (NodeId): The unique identifier of the node to fetch
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: the cached or newly retrieved node
        """
        cache = Node.get_cache()
        cached = cache.get(id)
        if cached:
            return cast(Self, cached)

        loaded = cls.load(id, db=db)
        cache[id] = loaded
        return cast(Self, loaded)
916

917
    @classmethod
    def save(cls, n: Self, *, db: GraphDB | None = None) -> Self:
        """Persist a node to the GraphDB

        A node whose `_new` flag is set is handed to `create`; any other node is handed to
        `update`. Both of those return the node untouched when its `_no_save` flag is True,
        so saving such a node is silently ignored.

        Args:
            n (Self): The Node to be saved
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: The node that was stored, returned as a convenience. This may be useful
            since the id of the node may change if it was created in the database.
        """
        persist = cls.create if n._new else cls.update
        return persist(n, db=db)
939

940
    @classmethod
    def update(cls, n: Self, *, db: GraphDB | None = None) -> Self:
        """Write the current state of an already-stored node to the GraphDB.

        The node's properties are replaced with the output of `to_dict`. Labels present now
        but missing from `_orig_labels` are added, and labels in `_orig_labels` that are no
        longer present are removed. Nothing is written if the node's `_no_save` flag is True.

        Prefer `save`, which picks between `create` and `update` for the caller.

        Args:
            n (Self): The node to be updated
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: The node that was passed in, for convenience
        """
        if n._no_save:
            return n

        db = db or GraphDB.singleton()

        before = n._orig_labels
        after = set(n.labels)
        added = Node.mklabels(after - before)
        removed = Node.mklabels(before - after)

        # each clause is only emitted when there is a label to add / remove
        set_query = f"SET n{added}, n = $props" if added else "SET n = $props"
        rm_query = f"REMOVE n{removed}" if removed else ""

        params = {"props": Node.to_dict(n)}
        query = f"MATCH (n) WHERE id(n) = {n.id} {set_query} {rm_query}"
        db.raw_execute(query, params=params)

        return n
979

980
    @classmethod
    def create(cls, n: Self, *, db: GraphDB | None = None) -> Self:
        """Creates the specified node in the GraphDB.

        Calling `save` is preferred to using this method so that the caller doesn't need to know the
        state of the node.

        On success the node takes the id returned by the database, its `_new` flag is
        cleared, its cache entry (if any) is moved to the new id, and every edge in its
        `src_edges` / `dst_edges` is pointed at the new id. Nothing happens if the node's
        `_no_save` flag is True.

        Args:
            n (Self): the node to be created
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Raises:
            NodeCreationFailed: if the database returned no id for the created node

        Returns:
            Self: the node that was passed in, albeit with a new `id` and potentially other new
            fields
        """
        if n._no_save:
            return n

        db = db or GraphDB.singleton()
        old_id = n.id

        label_str = Node.mklabels(n.labels)
        params = {"props": Node.to_dict(n)}

        res = list(db.raw_fetch(f"CREATE (n{label_str} $props) RETURN id(n) as id", params=params))

        if not res:
            # report the node's own id (the bare name `id` is the Python builtin here)
            raise NodeCreationFailed(f"Couldn't create node ID: {old_id}")

        new_id = res[0]["id"]
        n._id = new_id
        n._new = False

        # Re-key the cache entry. The entry may legitimately be missing (presumably when
        # the node is saved while it is being torn down); in that case the node is
        # deliberately not put back into the cache.
        cache = Node.get_cache()
        try:
            del cache[old_id]
        except KeyError:
            pass
        else:
            cache[new_id] = n

        # edges still refer to the temporary id; point them at the stored node
        for e in n.src_edges:
            assert e.src_id == old_id
            e.src_id = new_id

        for e in n.dst_edges:
            assert e.dst_id == old_id
            e.dst_id = new_id

        return n
1032

1033
    @classmethod
    def connect(
        cls,
        src: NodeId | Self,
        dst: NodeId | Self,
        type: str,
        *,
        db: GraphDB | None = None,
    ) -> Edge:
        """Creates an Edge between two nodes and registers it on both of them.

        Both endpoints are resolved with `get`; the new edge is then added to the source
        node's `src_edges` and to the destination node's `dst_edges`.

        Args:
            src (NodeId | Node): the node the edge starts at, or its id
            dst (NodeId | Node): the node the edge ends at, or its id
            type (str): passed to `Edge` as its third positional argument
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Edge: the newly created edge
        """
        src_id = src.id if isinstance(src, Node) else src
        dst_id = dst.id if isinstance(dst, Node) else dst

        edge = Edge(src_id, dst_id, type)

        # resolve both endpoints before touching either edge list
        tail = cls.get(src_id, db=db)
        head = cls.get(dst_id, db=db)
        tail.src_edges.add(edge)
        head.dst_edges.add(edge)
        return edge
1069

1070
    @staticmethod
    def delete(n: Node, *, db: GraphDB | None = None) -> None:
        """Deletes a node together with all of its edges.

        Every edge in the node's `src_edges` and `dst_edges` is removed with `Edge.delete`,
        the node is dropped from the node cache, and, unless the node is still new (never
        stored), it is deleted from the database. Afterwards the node is flagged as deleted
        and as not to be saved.

        Args:
            n (Node): the node to delete
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
        """
        db = db or GraphDB.singleton()

        # Remove edges. Iterate over snapshots: Edge.delete is not visible from here and may
        # well detach the edge from this node's edge lists, and mutating a collection while
        # iterating over it can skip entries.
        for e in list(n.src_edges):
            Edge.delete(e)

        for e in list(n.dst_edges):
            Edge.delete(e)

        # remove from cache
        cache = Node.get_cache()
        if n.id in cache:
            del cache[n.id]

        # a node that was never stored has nothing to delete in the database
        if not n._new:
            db.raw_execute(f"MATCH (n) WHERE id(n) = {n.id} DELETE n")

        n._deleted = True
        n._no_save = True
1091

1092
    @staticmethod
1✔
1093
    def to_dict(n: Node, include_labels: bool = False) -> dict[str, Any]:
1✔
1094
        """Convert a Node to a Python dictionary"""
1095
        # XXX: the excluded fields below shouldn't have been included in the
1096
        # first place because Pythonic should exclude fields with underscores
1097
        ret = n.model_dump(exclude={"_id", "_src_edges", "_dst_edges"})
1✔
1098

1099
        if include_labels and hasattr(n, "labels"):
1✔
1100
            ret["labels"] = n.labels
1✔
1101

1102
        return ret
1✔
1103

1104
    @staticmethod
1✔
1105
    def mklabels(labels: set[str]) -> str:
1✔
1106
        """Converts a list of strings into proper Cypher syntax for a graph database query"""
1107
        labels_list = [i for i in labels]
1✔
1108
        labels_list.sort()
1✔
1109
        label_str = ":".join(labels_list)
1✔
1110
        if len(label_str) > 0:
1✔
1111
            label_str = ":" + label_str
1✔
1112
        return label_str
1✔
1113

1114
    @staticmethod
    def all_ids(db: GraphDB | None = None) -> set[NodeId]:
        """Returns the Set of every NodeId found in either the graph database or the
        NodeCache (the union of the two)
        """
        db = db or GraphDB.singleton()

        # ids that are currently cached
        cached_ids = set(Node.get_cache().keys())

        # ids that are stored in the database
        db_ids = {row["id"] for row in db.raw_fetch("MATCH (n) RETURN id(n) as id")}

        return db_ids | cached_ids
1130

1131
    @staticmethod
    def walk(
        n: Node,
        *,
        mode: WalkMode = "both",
        edge_filter: EdgeFilterFn | None = None,
        node_filter: NodeFilterFn | None = None,
        node_callback: NodeCallbackFn | None = None,
        _walk_history: set[int] | None = None,
    ) -> None:
        """Recursively visits the nodes reachable from `n`, calling `node_callback` on each.

        A node is visited at most once per walk. If `node_filter` rejects a node, the
        callback is not called for it and the walk does not continue through it. Edges
        rejected by `edge_filter` are not followed.

        Args:
            n (Node): the node to start from
            mode (WalkMode): "src" follows the edges in `src_edges` to their `dst` node,
                "dst" follows the edges in `dst_edges` to their `src` node, "both" does both
            edge_filter (EdgeFilterFn | None): returns True for edges that should be
                followed; None follows every edge
            node_filter (NodeFilterFn | None): returns True for nodes that should be
                visited; None visits every node
            node_callback (NodeCallbackFn | None): called with each visited node; None
                does nothing
            _walk_history (set[int] | None): ids of the nodes seen so far; used by the
                recursion and updated in place
        """
        # `is None` rather than truthiness, so that an empty set supplied by the caller is
        # used (and filled) instead of being silently replaced
        if _walk_history is None:
            _walk_history = set()

        # if we have walked this node before, just return
        if n.id in _walk_history:
            return
        _walk_history.add(n.id)

        edge_filter = edge_filter or true_filter
        node_filter = node_filter or true_filter
        node_callback = node_callback or no_callback

        # a filtered node is neither reported nor walked through
        if not node_filter(n):
            return
        node_callback(n)

        if mode in ("src", "both"):
            for e in n.src_edges:
                if edge_filter(e):
                    Node.walk(
                        e.dst,
                        mode=mode,
                        edge_filter=edge_filter,
                        node_filter=node_filter,
                        node_callback=node_callback,
                        _walk_history=_walk_history,
                    )

        if mode in ("dst", "both"):
            for e in n.dst_edges:
                if edge_filter(e):
                    Node.walk(
                        e.src,
                        mode=mode,
                        edge_filter=edge_filter,
                        node_filter=node_filter,
                        node_callback=node_callback,
                        _walk_history=_walk_history,
                    )
1184

1185

1186
# which edge collections Node.walk follows
WalkMode = Literal["src", "dst", "both"]

# predicates: return True to keep a node / edge
NodeFilterFn = Callable[[Node], bool]
EdgeFilterFn = Callable[[Edge], bool]

# callbacks: receive a node, an edge, or a list of nodes
NodeCallbackFn = Callable[[Node], None]
EdgeCallbackFn = Callable[[Edge], None]
ProgressFn = Callable[[list[Node]], None]

# the shared node cache; created lazily by Node.get_cache()
NodeCache = GraphCache[NodeId, Node]
node_cache: NodeCache | None = None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc