• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

atlanticwave-sdx / sdx-controller / 20110929781

10 Dec 2025 07:33PM UTC coverage: 53.895% (-1.4%) from 55.33%
20110929781

Pull #498

github

web-flow
Merge d8fa3a974 into 51b6778fc
Pull Request #498: on restart, sync the topology VLAN/bandwidth state from the connections stored in the DB

11 of 81 new or added lines in 2 files covered. (13.58%)

1 existing line in 1 file now uncovered.

1266 of 2349 relevant lines covered (53.9%)

1.08 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

41.86
/sdx_controller/messaging/rpc_queue_consumer.py
1
#!/usr/bin/env python
2
import json
2✔
3
import logging
2✔
4
import os
2✔
5
import threading
2✔
6
import time
2✔
7
import traceback
2✔
8
from queue import Queue
2✔
9

10
import pika
2✔
11
from sdx_datamodel.constants import (
2✔
12
    Constants,
13
    DomainStatus,
14
    MessageQueueNames,
15
    MongoCollections,
16
)
17
from sdx_datamodel.models.topology import SDX_TOPOLOGY_ID_prefix
2✔
18
from sdx_pce.models import ConnectionPath, ConnectionRequest, ConnectionSolution
2✔
19
from sdx_pce.topology.manager import TopologyManager
2✔
20

21
from sdx_controller.handlers.connection_handler import (
2✔
22
    ConnectionHandler,
23
    connection_state_machine,
24
    get_connection_status,
25
    parse_conn_status,
26
)
27
from sdx_controller.handlers.lc_message_handler import LcMessageHandler
2✔
28
from sdx_controller.models import connection
2✔
29
from sdx_controller.utils.parse_helper import ParseHelper
2✔
30

31
# RabbitMQ connection settings, read once from the environment at import time.
MQ_HOST = os.getenv("MQ_HOST")
# Environment values are always strings; cast so the port has one type
# regardless of whether it came from the environment or the default.
MQ_PORT = int(os.getenv("MQ_PORT") or 5672)
MQ_USER = os.getenv("MQ_USER") or "guest"
MQ_PASS = os.getenv("MQ_PASS") or "guest"

# `or` (rather than a getenv default) also covers a variable that is set but
# empty, matching how the MQ_* settings above are read.
HEARTBEAT_INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL") or 10)  # seconds
HEARTBEAT_TOLERANCE = int(
    os.getenv("HEARTBEAT_TOLERANCE") or 3
)  # consecutive missed heartbeats allowed


# subscribe to the corresponding queue
SUB_QUEUE = MessageQueueNames.OXP_UPDATE

logger = logging.getLogger(__name__)

# Sets the collection name used below for stored connection solutions.
# NOTE(review): this assigns an attribute on a class imported from
# sdx_datamodel; presumably that package does not define it yet -- confirm.
MongoCollections.SOLUTIONS = "solutions"
2✔
47

48

49
class HeartbeatMonitor:
    """Track per-domain heartbeats and mirror UP/UNKNOWN changes into the DB.

    A domain is marked ``DomainStatus.UP`` whenever a heartbeat is recorded
    and ``DomainStatus.UNKNOWN`` once more than
    ``HEARTBEAT_TOLERANCE * HEARTBEAT_INTERVAL`` seconds have passed since its
    last recorded heartbeat. Status transitions are written to the domain
    dict stored under ``MongoCollections.DOMAINS`` / ``Constants.DOMAIN_DICT``.
    """

    def __init__(self, db_instance):
        """Create a monitor that persists status changes through *db_instance*.

        Args:
            db_instance: object providing ``get_value_from_db`` and
                ``add_key_value_pair_to_db``.
        """
        self.last_heartbeat = {}  # domain -> last heartbeat timestamp
        self.domain_status = {}  # domain -> current status (UP / UNKNOWN)
        self.lock = threading.Lock()
        self.monitoring = False
        self.db_instance = db_instance  # store DB instance

    def _persist_status(self, domain, status):
        """Write *status* for *domain* into the stored domain dict, if listed.

        Domains that are not already present in the stored dict are left
        untouched, as is the DB when no domain dict can be read at all.
        """
        domain_dict_from_db = self.db_instance.get_value_from_db(
            MongoCollections.DOMAINS, Constants.DOMAIN_DICT
        )
        # The lookup can come back empty (other code in this module checks
        # its truthiness before use); a bare `in` test on None would raise.
        if not domain_dict_from_db or domain not in domain_dict_from_db:
            return

        domain_dict_from_db[domain] = status
        self.db_instance.add_key_value_pair_to_db(
            MongoCollections.DOMAINS,
            Constants.DOMAIN_DICT,
            domain_dict_from_db,
        )

    def record_heartbeat(self, domain):
        """Record heartbeat from a domain and mark it as UP if previously UNKNOWN."""
        with self.lock:
            self.last_heartbeat[domain] = time.time()

            previous_status = self.domain_status.get(domain)
            self.domain_status[domain] = DomainStatus.UP

            # Update DB only if status changed from UNKNOWN -> UP
            if previous_status == DomainStatus.UNKNOWN:
                logger.info(
                    f"[HeartbeatMonitor] Domain {domain} is BACK UP after missed heartbeats."
                )
                self._persist_status(domain, DomainStatus.UP)

            logger.debug(f"[HeartbeatMonitor] Heartbeat recorded for {domain}")

    def check_status(self):
        """Mark domains as UNKNOWN if heartbeats are missing."""
        now = time.time()
        deadline = HEARTBEAT_TOLERANCE * HEARTBEAT_INTERVAL
        with self.lock:
            for domain, last_time in self.last_heartbeat.items():
                if now - last_time <= deadline:
                    continue
                # Already flagged: avoid repeating the warning and DB write.
                if self.domain_status.get(domain) == DomainStatus.UNKNOWN:
                    continue

                logger.warning(
                    f"[HeartbeatMonitor] Domain {domain} marked UNKNOWN (missed {HEARTBEAT_TOLERANCE} heartbeats)"
                )
                self.domain_status[domain] = DomainStatus.UNKNOWN
                self._persist_status(domain, DomainStatus.UNKNOWN)

    def get_status(self, domain):
        """Return the current status of a domain.

        Domains that have never sent a heartbeat yield the string "unknown".
        """
        with self.lock:
            return self.domain_status.get(domain, "unknown")

    def start_monitoring(self):
        """Start a background thread to monitor heartbeat status.

        Calling this again while monitoring is already active is a no-op.
        """
        if self.monitoring:
            return
        self.monitoring = True
        logger.info("[HeartbeatMonitor] Started monitoring heartbeats.")

        def monitor_loop():
            while self.monitoring:
                try:
                    self.check_status()
                except Exception:
                    # An unhandled error here would end this daemon thread
                    # silently and stop all further heartbeat checks.
                    logger.exception("[HeartbeatMonitor] Status check failed")
                time.sleep(HEARTBEAT_INTERVAL)

        t = threading.Thread(target=monitor_loop, daemon=True)
        t.start()
2✔
125

126

127
class RpcConsumer(object):
    """Consume messages from ``SUB_QUEUE`` and feed them to the LC handler.

    ``start_consumer`` runs the blocking pika consume loop; every delivery is
    put on the shared ``thread_queue`` and echoed back to the sender.
    ``start_sdx_consumer`` restores state from the DB and then drains that
    queue, dispatching heartbeats and LC messages.
    """

    # Seconds to wait on the hand-off queue before re-checking the exit event.
    _QUEUE_POLL_SECONDS = 1.0

    def __init__(self, thread_queue, exchange_name, te_manager):
        """Open a blocking connection and declare ``SUB_QUEUE``.

        Args:
            thread_queue: queue that received message bodies are put on.
            exchange_name: exchange used when publishing replies.
            te_manager: TE manager whose state ``start_sdx_consumer`` restores.
        """
        self.logger = logging.getLogger(__name__)

        self.logger.info(f"[MQ] Using amqp://{MQ_USER}@{MQ_HOST}:{MQ_PORT}")

        self.connection = pika.BlockingConnection(
            pika.ConnectionParameters(
                host=MQ_HOST,
                port=MQ_PORT,
                credentials=pika.PlainCredentials(username=MQ_USER, password=MQ_PASS),
            )
        )

        self.channel = self.connection.channel()
        self.exchange_name = exchange_name

        self.channel.queue_declare(queue=SUB_QUEUE)
        self._thread_queue = thread_queue

        self.te_manager = te_manager

        self._exit_event = threading.Event()

    def on_request(self, ch, method, props, message_body):
        """Handle one delivery: queue the body, echo it back, then ack.

        The reply is published on the delivering channel *ch*, so no new
        connection is opened here. Opening one per message left the old
        connection unclosed and replaced ``self.channel`` with a channel
        that is not the consuming one.
        """
        self._thread_queue.put(message_body)

        try:
            ch.basic_publish(
                exchange=self.exchange_name,
                routing_key=props.reply_to,
                properties=pika.BasicProperties(correlation_id=props.correlation_id),
                body=str(message_body),
            )
            ch.basic_ack(delivery_tag=method.delivery_tag)
        except Exception as err:
            # NOTE(review): the delivery is left un-acked on this path --
            # confirm whether a nack/requeue is wanted here.
            self.logger.error(f"[MQ] encountered error when publishing: {err}")

    def start_consumer(self):
        """Register ``on_request`` on ``SUB_QUEUE`` and block consuming."""
        self.channel.basic_qos(prefetch_count=1)
        self.channel.basic_consume(queue=SUB_QUEUE, on_message_callback=self.on_request)

        self.logger.info(" [MQ] Awaiting requests from queue: " + SUB_QUEUE)
        self.channel.start_consuming()

    @staticmethod
    def _find_graph_node(graph, node_id):
        """Return the first graph node whose ``id`` attribute equals *node_id*.

        Raises IndexError when no node matches.
        """
        matches = [key for key, data in graph.nodes(data=True) if data["id"] == node_id]
        return matches[0]

    def _restore_vlan_tables(self, db_instance, connections):
        """Re-mark the VLAN tags used by stored connections as taken.

        For each connection's breakdown, the ``uni_a`` and ``uni_z`` tag
        values are written into ``te_manager.vlan_tags_table``.

        Returns:
            None on success, or an ``("Error: ...", 410)`` tuple on the first
            failure.
        """
        if not connections:
            logger.info("No connection was found")
            return None

        for entry in connections:
            service_id = next(iter(entry))
            status = get_connection_status(db_instance, service_id)
            # The status lookup may be empty; only log it, never fail on it.
            logger.info(
                f"Restart: service_id: {service_id}, status: {(status or {}).get(service_id)}"
            )
            # 1. update the vlan tables in pce
            domain_breakdown = db_instance.get_value_from_db(
                MongoCollections.BREAKDOWNS, service_id
            )
            if not domain_breakdown:
                logger.warning(f"Could not find breakdown for {service_id}")
                continue
            try:
                vlan_tags_table = self.te_manager.vlan_tags_table
                for domain, segment in domain_breakdown.items():
                    logger.debug(f"domain:{domain};segment:{segment}")
                    domain_table = vlan_tags_table.get(domain)
                    for endpoint_key in ("uni_a", "uni_z"):
                        endpoint = segment.get(endpoint_key)
                        vlan_table = domain_table.get(endpoint.get("port_id"))
                        vlan_table[endpoint.get("tag").get("value")] = service_id
            except Exception as e:
                err = traceback.format_exc().replace("\n", ", ")
                logger.error(
                    f"Error when recovering breakdown vlan assignment: {e} - {err}"
                )
                return f"Error: {e}", 410
        return None

    def _restore_solutions(self, db_instance, connections, graph):
        """Rebuild ``te_manager.connectionSolution_list`` from stored solutions.

        Connections without a status, ``qos_metrics`` or a stored solution
        are skipped.

        Returns:
            None on success, or an ``("Error: ...", 410)`` tuple on the first
            failure.
        """
        solution_list = self.te_manager.connectionSolution_list
        if not connections:
            logger.info("No connection was found")
            return None

        for entry in connections:
            try:
                service_id = next(iter(entry))
                response = get_connection_status(db_instance, service_id)
                if not response:
                    continue
                qos_metrics = response[service_id].get("qos_metrics")
                if not qos_metrics:
                    continue
                min_bw = qos_metrics.get("min_bw", {"value": 0.0}).get("value", 0)
                logger.debug(f"service_id:{service_id}, {response}")
                solution_links = db_instance.get_value_from_db(
                    MongoCollections.SOLUTIONS, service_id
                )
                logger.debug(f"service_id:{service_id};solution:{solution_links}")
                if not solution_links:
                    logger.warning(f"Could not find solution in DB for {service_id}")
                    continue

                # Translate each stored link's ports into graph node keys.
                link_map = []
                for link in solution_links:
                    topology = self.te_manager.topology_manager.get_topology()
                    source_node = topology.get_node_by_port(link.get("source"))
                    destination_node = topology.get_node_by_port(
                        link.get("destination")
                    )
                    link_map.append(
                        ConnectionPath(
                            self._find_graph_node(graph, source_node.id),
                            self._find_graph_node(graph, destination_node.id),
                        )
                    )

                # rebuild solution object
                request = ConnectionRequest(
                    source=0,
                    destination=0,
                    required_bandwidth=min_bw,
                    required_latency=float("inf"),
                )
                solution_list.append(
                    ConnectionSolution(
                        connection_map={request: link_map},
                        cost=0,
                        request_id=service_id,
                    )
                )
            except Exception as e:
                err = traceback.format_exc().replace("\n", ", ")
                logger.error(f"Error when recovering solution list: {e} - {err}")
                return f"Error: {e}", 410
        return None

    def start_sdx_consumer(self, thread_queue, db_instance):
        """Restore state from the DB, then process queued messages until stopped.

        Starts a second ``RpcConsumer`` on a daemon thread to receive
        messages, starts a ``HeartbeatMonitor``, reloads topologies, VLAN
        assignments and connection solutions from *db_instance*, and finally
        loops over *thread_queue* until the exit event is set.

        Returns:
            None when the loop ends normally. If restoring VLAN assignments
            or solutions fails, an ``("Error: ...", 410)`` tuple is returned
            before the message loop starts.
            NOTE(review): that early return leaves queued messages
            unprocessed -- confirm this fail-fast behaviour is intended.
        """
        # Local import: only `Queue` is imported from `queue` at module level.
        from queue import Empty

        rpc = RpcConsumer(thread_queue, "", self.te_manager)
        t1 = threading.Thread(target=rpc.start_consumer, args=(), daemon=True)
        t1.start()

        lc_message_handler = LcMessageHandler(db_instance, self.te_manager)
        parse_helper = ParseHelper()

        heartbeat_monitor = HeartbeatMonitor(db_instance)
        heartbeat_monitor.start_monitoring()

        latest_topo = {}
        domain_dict = {}

        # This part reads from DB when SDX controller initially starts.
        # It looks for domain_dict, if already in DB,
        # Then use the existing ones from DB.
        domain_dict_from_db = db_instance.get_value_from_db(
            MongoCollections.DOMAINS, Constants.DOMAIN_DICT
        )
        latest_topo_from_db = db_instance.get_value_from_db(
            MongoCollections.TOPOLOGIES, Constants.LATEST_TOPOLOGY
        )

        if domain_dict_from_db:
            domain_dict = domain_dict_from_db
            logger.debug("Domain list already exists in db: ")
            logger.debug(domain_dict)

        residual_bw = {}
        if latest_topo_from_db:
            latest_topo = latest_topo_from_db
            logger.debug("Topology already exists in db: ")
            # A scratch manager is used only to read the residual bandwidth
            # out of the stored topology.
            update_topology_manager = TopologyManager()
            update_topology_manager.add_topology(latest_topo)
            residual_bw = update_topology_manager.get_residul_bandwidth()
            logger.debug(residual_bw)

        # If topologies already saved in db, use them to initialize te_manager
        if domain_dict:
            for domain in domain_dict:
                topology = db_instance.get_value_from_db(
                    MongoCollections.TOPOLOGIES, SDX_TOPOLOGY_ID_prefix + domain
                )
                if not topology:
                    continue

                self.te_manager.add_topology(topology)
                logger.debug(f"Read {domain}: {topology}")

            # update topology/pce state in TE Manager
            graph = self.te_manager.generate_graph_te()
            logger.debug(f"restart graph = {graph.nodes};{graph.edges}")

            connections = db_instance.get_all_entries_in_collection(
                MongoCollections.CONNECTIONS
            )
            error = self._restore_vlan_tables(db_instance, connections)
            if error:
                return error

            logger.debug(f"Restart: solutions for {connections}")
            # Fetched again on purpose, as in the original flow.
            # NOTE(review): presumably the result may be a one-shot iterable
            # -- confirm before merging the two fetches.
            connections = db_instance.get_all_entries_in_collection(
                MongoCollections.CONNECTIONS
            )
            error = self._restore_solutions(db_instance, connections, graph)
            if error:
                return error

            logger.debug("Restart: residul_bw")
            if residual_bw:
                self.te_manager.update_available_bw_in_topology(residual_bw)

        while not self._exit_event.is_set():
            # Poll with a timeout so a stop request is noticed even when no
            # further message ever arrives.
            try:
                msg = thread_queue.get(timeout=self._QUEUE_POLL_SECONDS)
            except Empty:
                continue
            logger.debug("MQ received message:" + str(msg))

            if not parse_helper.is_json(msg):
                logger.debug("Non JSON message, ignored")
                continue

            msg_json = json.loads(msg)
            # Valid JSON is not necessarily an object; only dicts have .get().
            if isinstance(msg_json, dict) and msg_json.get("type") == "Heart Beat":
                domain = msg_json.get("domain")
                heartbeat_monitor.record_heartbeat(domain)
                logger.debug(f"Heart beat received from {domain}")
                continue

            try:
                lc_message_handler.process_lc_json_msg(
                    msg,
                    latest_topo,
                    domain_dict,
                )
            except Exception as exc:
                err = traceback.format_exc().replace("\n", ", ")
                logger.error(f"Failed to process LC message: {exc} -- {err}")

    def stop_threads(self):
        """
        Signal threads that we're ready to stop.

        NOTE(review): this stops consuming on this instance's channel, while
        ``start_sdx_consumer`` runs the consume loop on a separate
        ``RpcConsumer`` -- confirm the intended target.
        """
        logger.info("[MQ] Stopping threads.")
        self.channel.stop_consuming()
        self._exit_event.set()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc