• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

atlanticwave-sdx / sdx-controller / 26188833755

20 May 2026 08:43PM UTC coverage: 47.136% (-2.4%) from 49.499%
26188833755

Pull #521

github

web-flow
Merge c74dead67 into 5a34ae56e
Pull Request #521: Recover error connection when SDX domain becomes UP

27 of 193 new or added lines in 3 files covered. (13.99%)

8 existing lines in 3 files now uncovered.

1399 of 2968 relevant lines covered (47.14%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

39.44
/sdx_controller/messaging/rpc_queue_consumer.py
1
#!/usr/bin/env python
2
import json
2✔
3
import logging
2✔
4
import os
2✔
5
import threading
2✔
6
import time
2✔
7
import traceback
2✔
8
from queue import Queue
2✔
9

10
import pika
2✔
11
from sdx_datamodel.constants import (
2✔
12
    Constants,
13
    DomainStatus,
14
    MessageQueueNames,
15
    MongoCollections,
16
)
17
from sdx_datamodel.models.topology import SDX_TOPOLOGY_ID_prefix
2✔
18
from sdx_pce.models import ConnectionPath, ConnectionRequest, ConnectionSolution
2✔
19
from sdx_pce.topology.manager import TopologyManager
2✔
20

21
from sdx_controller.handlers.connection_handler import (
2✔
22
    ConnectionHandler,
23
    connection_state_machine,
24
    get_connection_status,
25
    parse_conn_status,
26
)
27
from sdx_controller.handlers.lc_message_handler import LcMessageHandler
2✔
28
from sdx_controller.models import connection
2✔
29
from sdx_controller.utils.parse_helper import ParseHelper
2✔
30

31
# RabbitMQ connection settings, read from the environment at import time.
MQ_HOST = os.getenv("MQ_HOST")
# Coerce to int so the port has the same type whether or not MQ_PORT is set
# (os.getenv returns a str when the variable is present).
MQ_PORT = int(os.getenv("MQ_PORT") or 5672)
MQ_USER = os.getenv("MQ_USER") or "guest"
MQ_PASS = os.getenv("MQ_PASS") or "guest"
# `or` (rather than a getenv default) also covers a variable that is set but
# empty, which would otherwise make int("") raise at import time.
HEARTBEAT_INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL") or 30)  # seconds
HEARTBEAT_TOLERANCE = int(
    os.getenv("HEARTBEAT_TOLERANCE") or 3
)  # consecutive missed heartbeats allowed


# subscribe to the corresponding queue
SUB_QUEUE = MessageQueueNames.OXP_UPDATE

logger = logging.getLogger(__name__)

# Adds a SOLUTIONS attribute to MongoCollections at import time; it is used
# below as the collection name when reading stored solutions.
MongoCollections.SOLUTIONS = "solutions"
47

48

49
class HeartbeatMonitor:
    """Track the last heartbeat per domain and keep a UP/UNKNOWN status.

    Status changes are kept in memory and also written to the domain
    dictionary stored under ``MongoCollections.DOMAINS`` /
    ``Constants.DOMAIN_DICT`` through ``db_instance``.
    """

    def __init__(self, db_instance):
        """Create an idle monitor that uses ``db_instance`` for persistence."""
        self.last_heartbeat = {}  # domain -> last heartbeat timestamp
        self.domain_status = {}  # domain -> current status (UP / UNKNOWN)
        self.lock = threading.Lock()
        self.monitoring = False
        self.db_instance = db_instance  # store DB instance

    def record_heartbeat(self, domain):
        """Record heartbeat from a domain and mark it as UP if previously UNKNOWN.

        Returns:
            True when the domain's previous status was ``DomainStatus.UNKNOWN``
            (i.e. it just came back), False otherwise.
        """
        with self.lock:
            self.last_heartbeat[domain] = time.time()

            previous_status = self.domain_status.get(domain)
            domain_dict_from_db = None
            if previous_status is None:
                # No in-memory status for this domain yet: fall back to the
                # status stored in the DB, if any.
                domain_dict_from_db = self.db_instance.get_value_from_db(
                    MongoCollections.DOMAINS, Constants.DOMAIN_DICT
                )
                if domain_dict_from_db:
                    previous_status = domain_dict_from_db.get(domain)
            self.domain_status[domain] = DomainStatus.UP

            # Update DB if status changed from UNKNOWN -> UP
            if previous_status == DomainStatus.UNKNOWN:
                logger.info(
                    f"[HeartbeatMonitor] Domain {domain} is BACK UP after missed heartbeats."
                )
                if domain_dict_from_db is None:
                    domain_dict_from_db = self.db_instance.get_value_from_db(
                        MongoCollections.DOMAINS, Constants.DOMAIN_DICT
                    )
                if domain_dict_from_db and domain in domain_dict_from_db:
                    domain_dict_from_db[domain] = DomainStatus.UP
                    self.db_instance.add_key_value_pair_to_db(
                        MongoCollections.DOMAINS,
                        Constants.DOMAIN_DICT,
                        domain_dict_from_db,
                    )
                return True

            logger.debug(f"[HeartbeatMonitor] Heartbeat recorded for {domain}")
            return False

    def check_status(self):
        """Mark domains as UNKNOWN if heartbeats are missing."""
        now = time.time()
        with self.lock:
            for domain, last_time in self.last_heartbeat.items():
                if now - last_time > HEARTBEAT_TOLERANCE * HEARTBEAT_INTERVAL:
                    # Only act on the transition, not on every check.
                    if self.domain_status.get(domain) != DomainStatus.UNKNOWN:
                        logger.warning(
                            f"[HeartbeatMonitor] Domain {domain} marked UNKNOWN (missed {HEARTBEAT_TOLERANCE} heartbeats)"
                        )
                        self.domain_status[domain] = DomainStatus.UNKNOWN

                        domain_dict_from_db = self.db_instance.get_value_from_db(
                            MongoCollections.DOMAINS, Constants.DOMAIN_DICT
                        )
                        if domain_dict_from_db and domain in domain_dict_from_db:
                            domain_dict_from_db[domain] = DomainStatus.UNKNOWN
                            self.db_instance.add_key_value_pair_to_db(
                                MongoCollections.DOMAINS,
                                Constants.DOMAIN_DICT,
                                domain_dict_from_db,
                            )

    def get_status(self, domain):
        """Return the current status of a domain ("unknown" if never seen)."""
        with self.lock:
            return self.domain_status.get(domain, "unknown")

    def start_monitoring(self):
        """Start a background thread to monitor heartbeat status.

        Does nothing if monitoring has already been started.
        """
        if self.monitoring:
            return
        self.monitoring = True
        logger.info("[HeartbeatMonitor] Started monitoring heartbeats.")

        def monitor_loop():
            while self.monitoring:
                try:
                    self.check_status()
                except Exception:
                    # check_status talks to the DB; a single failing check
                    # must not end this thread, otherwise stale domains would
                    # never be marked UNKNOWN again.
                    logger.exception(
                        "[HeartbeatMonitor] Heartbeat status check failed."
                    )
                time.sleep(HEARTBEAT_INTERVAL)

        t = threading.Thread(target=monitor_loop, daemon=True)
        t.start()
135

136

137
class RpcConsumer(object):
2✔
138
    def __init__(self, thread_queue, exchange_name, te_manager, queue_name=SUB_QUEUE):
        """Store the consumer's settings and open the broker connection.

        Args:
            thread_queue: queue that received message bodies are put on.
            exchange_name: exchange used when publishing replies.
            te_manager: TE manager kept on the instance for later use.
            queue_name: name of the queue to declare and consume from.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info(f"[MQ] Using amqp://{MQ_USER}@{MQ_HOST}:{MQ_PORT}")

        # Plain attribute wiring; none of these assignments depend on each other.
        self._thread_queue = thread_queue
        self.te_manager = te_manager
        self.exchange_name = exchange_name
        self.queue_name = queue_name

        # Set by stop_threads() to make the consuming loops exit.
        self._exit_event = threading.Event()
        self._connect()
150

151
    def _connect(self):
        """Open a blocking connection and channel, then declare the queue."""
        credentials = pika.PlainCredentials(username=MQ_USER, password=MQ_PASS)
        parameters = pika.ConnectionParameters(
            host=MQ_HOST,
            port=MQ_PORT,
            credentials=credentials,
        )
        self.connection = pika.BlockingConnection(parameters)
        self.channel = self.connection.channel()

        # RabbitMQ no longer permits transient non-exclusive queues by default.
        # This shared controller queue should be durable so it remains compatible
        # with newer broker defaults.
        self.channel.queue_declare(queue=self.queue_name, durable=True)
165

166
    def on_request(self, ch, method, props, message_body):
        """Handle one delivered message.

        The body is put on the thread queue, echoed back to ``props.reply_to``
        (when set) and then acknowledged on the delivering channel ``ch``.

        Args:
            ch: channel the message was delivered on; used to publish and ack.
            method: delivery metadata; provides ``delivery_tag`` for the ack.
            props: message properties (``reply_to``, ``correlation_id``).
            message_body: raw message payload.
        """
        self._thread_queue.put(message_body)

        # NOTE: this callback used to open a new BlockingConnection for every
        # message and store it in self.connection / self.channel. That
        # connection was never used (publish and ack go through `ch`), so it
        # leaked one broker connection per message and left start_consumer()
        # checking and closing the wrong connection.
        try:
            if props.reply_to:
                ch.basic_publish(
                    exchange=self.exchange_name,
                    routing_key=props.reply_to,
                    properties=pika.BasicProperties(
                        correlation_id=props.correlation_id
                    ),
                    body=str(message_body),
                )
            ch.basic_ack(delivery_tag=method.delivery_tag)
        except Exception as err:
            # Best effort: a failed reply/ack is reported but does not
            # propagate into the consuming loop.
            self.logger.warning(f"[MQ] encountered error when publishing: {err}")
192

193
    def start_consumer(self):
        """Consume from ``self.queue_name`` until the exit event is set.

        Any exception raised while setting up or consuming is logged, the
        connection is closed on a best-effort basis, and a reconnect is
        attempted after a five second pause.
        """
        while True:
            if self._exit_event.is_set():
                break
            try:
                link_is_down = self.connection.is_closed or self.channel.is_closed
                if link_is_down:
                    self._connect()

                # Take one unacknowledged message at a time.
                self.channel.basic_qos(prefetch_count=1)
                self.channel.basic_consume(
                    on_message_callback=self.on_request,
                    queue=self.queue_name,
                )

                awaiting_msg = " [MQ] Awaiting requests from queue: " + self.queue_name
                self.logger.info(awaiting_msg)
                self.channel.start_consuming()
            except Exception as consume_err:
                self.logger.warning(
                    f"[MQ] Consumer for queue {self.queue_name} disconnected: {consume_err}"
                )
                # Closing is best effort: the connection may already be gone.
                try:
                    self.connection.close()
                except Exception:
                    pass
                time.sleep(5)
                try:
                    self._connect()
                except Exception as reconnect_err:
                    self.logger.warning(
                        f"[MQ] Reconnect failed for queue {self.queue_name}: {reconnect_err}"
                    )
                    time.sleep(5)
223

224
    def start_sdx_consumer(self, thread_queue, db_instance):
        """Start the consumers, restore state from the DB and dispatch messages.

        Steps, in order:
          1. Start one daemon consumer thread for the default queue and one
             for ``MessageQueueNames.HEARTBEATS``; both feed ``thread_queue``.
          2. Start a ``HeartbeatMonitor`` backed by ``db_instance``.
          3. If a domain dictionary exists in the DB, reload topologies, VLAN
             assignments, connection solutions and residual bandwidth into
             ``self.te_manager``.
          4. Loop over messages from ``thread_queue`` until the exit event is
             set, handling heartbeats here and passing everything else to
             ``LcMessageHandler.process_lc_json_msg``.

        Args:
            thread_queue: queue shared with the consumer threads.
            db_instance: DB accessor used for all reads during restore.

        Returns:
            ``("Error: ...", 410)`` if restoring VLAN assignments or solutions
            fails; otherwise ``None`` once the message loop ends.
        """
        rpc = RpcConsumer(thread_queue, "", self.te_manager)
        t1 = threading.Thread(target=rpc.start_consumer, args=(), daemon=True)
        t1.start()
        heartbeat_rpc = RpcConsumer(
            thread_queue, "", self.te_manager, queue_name=MessageQueueNames.HEARTBEATS
        )
        heartbeat_thread = threading.Thread(
            target=heartbeat_rpc.start_consumer, args=(), daemon=True
        )
        heartbeat_thread.start()

        lc_message_handler = LcMessageHandler(db_instance, self.te_manager)
        parse_helper = ParseHelper()

        heartbeat_monitor = HeartbeatMonitor(db_instance)
        heartbeat_monitor.start_monitoring()

        latest_topo = {}
        domain_dict = {}

        # This part reads from DB when SDX controller initially starts.
        # It looks for domain_dict, if already in DB,
        # Then use the existing ones from DB.
        domain_dict_from_db = db_instance.get_value_from_db(
            MongoCollections.DOMAINS, Constants.DOMAIN_DICT
        )
        latest_topo_from_db = db_instance.get_value_from_db(
            MongoCollections.TOPOLOGIES, Constants.LATEST_TOPOLOGY
        )

        if domain_dict_from_db:
            domain_dict = domain_dict_from_db
            logger.debug("Domain list already exists in db: ")
            logger.debug(domain_dict)

        residul_bw = {}
        if latest_topo_from_db:
            latest_topo = latest_topo_from_db
            logger.debug("Topology already exists in db: ")
            # A throwaway TopologyManager is used only to read the residual
            # bandwidth out of the stored topology.
            update_topology_manager = TopologyManager()
            update_topology_manager.add_topology(latest_topo)
            residul_bw = update_topology_manager.get_residul_bandwidth()
            logger.debug(residul_bw)

        # If topologies already saved in db, use them to initialize te_manager
        if domain_dict:
            for domain in domain_dict.keys():
                topology = db_instance.get_value_from_db(
                    MongoCollections.TOPOLOGIES, SDX_TOPOLOGY_ID_prefix + domain
                )

                if not topology:
                    continue

                # Get the actual thing minus the Mongo ObjectID.
                self.te_manager.add_topology(topology)
                logger.debug(f"Read {domain}: {topology}")
            # update topology/pce state in TE Manager

            graph = self.te_manager.generate_graph_te()
            logger.debug(f"restart graph = {graph.nodes};{graph.edges}")
            connections = db_instance.get_all_entries_in_collection(
                MongoCollections.CONNECTIONS
            )
            if not connections:
                logger.info("No connection was found")
            else:
                # `connection_entry` (not `connection`) so the imported
                # `connection` module is not shadowed inside this method.
                for connection_entry in connections:
                    service_id = next(iter(connection_entry))
                    status = get_connection_status(db_instance, service_id)
                    # A falsy status is only logged here; the solution restore
                    # below skips such entries explicitly.
                    service_status = status.get(service_id) if status else None
                    logger.info(
                        f"Restart: service_id: {service_id}, status: {service_status}"
                    )
                    # 1. update the vlan tables in pce
                    domain_breakdown = db_instance.get_value_from_db(
                        MongoCollections.BREAKDOWNS, service_id
                    )
                    if not domain_breakdown:
                        logger.warning(f"Could not find breakdown for {service_id}")
                        continue
                    try:
                        vlan_tags_table = self.te_manager.vlan_tags_table
                        for domain, segment in domain_breakdown.items():
                            logger.debug(f"domain:{domain};segment:{segment}")
                            # The table is looked up by the part of the
                            # breakdown key before the first "__".
                            domain_table = vlan_tags_table.get(domain.split("__", 1)[0])
                            uni_a = segment.get("uni_a")
                            vlan_table = domain_table.get(uni_a.get("port_id"))
                            vlan_table[uni_a.get("tag").get("value")] = service_id
                            uni_z = segment.get("uni_z")
                            vlan_table = domain_table.get(uni_z.get("port_id"))
                            vlan_table[uni_z.get("tag").get("value")] = service_id
                    except Exception as e:
                        err = traceback.format_exc().replace("\n", ", ")
                        logger.error(
                            f"Error when recovering breakdown vlan assignment: {e} - {err}"
                        )
                        return f"Error: {e}", 410
            logger.debug(f"Restart: solutions for {connections}")
            connectionSolution_list = self.te_manager.connectionSolution_list
            connections = db_instance.get_all_entries_in_collection(
                MongoCollections.CONNECTIONS
            )
            if not connections:
                logger.info("No connection was found")
            else:
                for connection_entry in connections:
                    try:
                        service_id = next(iter(connection_entry))
                        response = get_connection_status(db_instance, service_id)
                        if not response:
                            continue
                        qos_metrics = response[service_id].get("qos_metrics")
                        if not qos_metrics:
                            continue
                        min_bw = qos_metrics.get("min_bw", {"value": 0.0}).get(
                            "value", 0
                        )
                        logger.debug(f"service_id:{service_id}, {response}")
                        solution_links = db_instance.get_value_from_db(
                            MongoCollections.SOLUTIONS, service_id
                        )
                        logger.debug(
                            f"service_id:{service_id};solution:{solution_links}"
                        )
                        if not solution_links:
                            logger.warning(
                                f"Could not find solution in DB for {service_id}"
                            )
                            continue
                        links = []
                        for link in solution_links:
                            # Map each stored port to its node, then to the
                            # matching node key in the TE graph.
                            source_node = self.te_manager.topology_manager.get_topology().get_node_by_port(
                                link.get("source")
                            )
                            destination_node = self.te_manager.topology_manager.get_topology().get_node_by_port(
                                link.get("destination")
                            )
                            source = [
                                x
                                for x, y in graph.nodes(data=True)
                                if y["id"] == source_node.id
                            ]

                            destination = [
                                x
                                for x, y in graph.nodes(data=True)
                                if y["id"] == destination_node.id
                            ]
                            links.append(
                                {"source": source[0], "destination": destination[0]}
                            )
                        # rebuild solution object
                        request = ConnectionRequest(
                            source=0,
                            destination=0,
                            required_bandwidth=min_bw,
                            required_latency=float("inf"),
                        )
                        link_map = [
                            ConnectionPath(link.get("source"), link.get("destination"))
                            for link in links
                        ]
                        solution = ConnectionSolution(
                            connection_map={request: link_map},
                            cost=0,
                            request_id=service_id,
                        )
                        connectionSolution_list.append(solution)
                    except Exception as e:
                        err = traceback.format_exc().replace("\n", ", ")
                        logger.error(
                            f"Error when recovering solution list: {e} - {err}"
                        )
                        return f"Error: {e}", 410
            logger.debug("Restart: residul_bw")
            if residul_bw:
                self.te_manager.update_available_bw_in_topology(residul_bw)

        while not self._exit_event.is_set():
            msg = thread_queue.get()
            logger.debug("MQ received message:" + str(msg))

            if not parse_helper.is_json(msg):
                logger.debug("Non JSON message, ignored")
                continue

            msg_json = json.loads(msg)
            # Only a JSON object can be a heartbeat. The isinstance check keeps
            # scalar payloads (e.g. `5`) from raising on the membership test.
            if isinstance(msg_json, dict) and msg_json.get("type") == "Heart Beat":
                domain = msg_json.get("domain")
                if not domain:
                    continue
                # Guarded like the LC-message branch below: a DB or recovery
                # failure must not terminate the dispatch loop.
                try:
                    recovered = heartbeat_monitor.record_heartbeat(domain)
                    logger.debug(f"Heart beat received from {domain}")
                    if recovered:
                        recovered_count = lc_message_handler.connection_handler.recover_domain_connections(
                            domain
                        )
                        logger.info(
                            f"Recovered LC domain {domain}; republished {recovered_count} affected L2VPN breakdowns."
                        )
                except Exception as exc:
                    err = traceback.format_exc().replace("\n", ", ")
                    logger.error(
                        f"Failed to process heartbeat from {domain}: {exc} -- {err}"
                    )
                continue

            try:
                lc_message_handler.process_lc_json_msg(
                    msg,
                    latest_topo,
                    domain_dict,
                )
            except Exception as exc:
                err = traceback.format_exc().replace("\n", ", ")
                logger.error(f"Failed to process LC message: {exc} -- {err}")
437

438
    def stop_threads(self):
        """
        Signal threads that we're ready to stop.

        Stops consuming on this instance's channel and sets the exit event.
        The event is set even if ``stop_consuming()`` raises, so loops that
        poll it still terminate; the exception itself still propagates.
        """
        logger.info("[MQ] Stopping threads.")
        try:
            self.channel.stop_consuming()
        finally:
            self._exit_event.set()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc