• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

atlanticwave-sdx / sdx-controller / 20010355409

07 Dec 2025 09:08PM UTC coverage: 54.809% (-0.5%) from 55.33%
20010355409

Pull #498

github

web-flow
Merge 5168cd797 into 51b6778fc
Pull Request #498: when restart, sync the toplogy vlan/bw state from connections out of db

10 of 40 new or added lines in 2 files covered. (25.0%)

68 existing lines in 1 file now uncovered.

1265 of 2308 relevant lines covered (54.81%)

1.1 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

50.85
/sdx_controller/messaging/rpc_queue_consumer.py
1
#!/usr/bin/env python
2
import json
2✔
3
import logging
2✔
4
import os
2✔
5
import threading
2✔
6
import time
2✔
7
import traceback
2✔
8
from queue import Queue
2✔
9

10
import pika
2✔
11
from sdx_datamodel.constants import (
2✔
12
    Constants,
13
    DomainStatus,
14
    MessageQueueNames,
15
    MongoCollections,
16
)
17
from sdx_datamodel.models.topology import SDX_TOPOLOGY_ID_prefix
2✔
18
from sdx_pce.models import ConnectionPath, ConnectionRequest, ConnectionSolution
2✔
19
from sdx_pce.topology.manager import TopologyManager
2✔
20

21
from sdx_controller.handlers.connection_handler import (
2✔
22
    ConnectionHandler,
23
    connection_state_machine,
24
    get_connection_status,
25
    parse_conn_status,
26
)
27
from sdx_controller.handlers.lc_message_handler import LcMessageHandler
2✔
28
from sdx_controller.models import connection
2✔
29
from sdx_controller.utils.parse_helper import ParseHelper
2✔
30

31
MQ_HOST = os.getenv("MQ_HOST")
2✔
32
MQ_PORT = os.getenv("MQ_PORT") or 5672
2✔
33
MQ_USER = os.getenv("MQ_USER") or "guest"
2✔
34
MQ_PASS = os.getenv("MQ_PASS") or "guest"
2✔
35
HEARTBEAT_INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", 10))  # seconds
2✔
36
HEARTBEAT_TOLERANCE = int(
2✔
37
    os.getenv("HEARTBEAT_TOLERANCE", 3)
38
)  # consecutive missed heartbeats allowed
39

40

41
# subscribe to the corresponding queue
42
SUB_QUEUE = MessageQueueNames.OXP_UPDATE
2✔
43

44
logger = logging.getLogger(__name__)
2✔
45

46
MongoCollections.SOLUTIONS = "solutions"
2✔
47

48

49
class HeartbeatMonitor:
2✔
50
    def __init__(self, db_instance):
2✔
51
        self.last_heartbeat = {}  # domain -> last heartbeat timestamp
2✔
52
        self.domain_status = {}  # domain -> current status (UP / UNKNOWN)
2✔
53
        self.lock = threading.Lock()
2✔
54
        self.monitoring = False
2✔
55
        self.db_instance = db_instance  # store DB instance
2✔
56

57
    def record_heartbeat(self, domain):
2✔
58
        """Record heartbeat from a domain and mark it as UP if previously UNKNOWN."""
59
        with self.lock:
×
60
            self.last_heartbeat[domain] = time.time()
×
61

62
            previous_status = self.domain_status.get(domain)
×
63
            self.domain_status[domain] = DomainStatus.UP
×
64

65
            # Update DB if status changed from UNKNOWN -> UP
66
            if previous_status == DomainStatus.UNKNOWN:
×
67
                logger.info(
×
68
                    f"[HeartbeatMonitor] Domain {domain} is BACK UP after missed heartbeats."
69
                )
70
                domain_dict_from_db = self.db_instance.get_value_from_db(
×
71
                    MongoCollections.DOMAINS, Constants.DOMAIN_DICT
72
                )
73
                if domain in domain_dict_from_db:
×
74
                    domain_dict_from_db[domain] = DomainStatus.UP
×
75
                    self.db_instance.add_key_value_pair_to_db(
×
76
                        MongoCollections.DOMAINS,
77
                        Constants.DOMAIN_DICT,
78
                        domain_dict_from_db,
79
                    )
80

81
            logger.debug(f"[HeartbeatMonitor] Heartbeat recorded for {domain}")
×
82

83
    def check_status(self):
2✔
84
        """Mark domains as UNKNOWN if heartbeats are missing."""
85
        now = time.time()
2✔
86
        with self.lock:
2✔
87
            for domain, last_time in self.last_heartbeat.items():
2✔
88
                if now - last_time > HEARTBEAT_TOLERANCE * HEARTBEAT_INTERVAL:
×
89
                    if self.domain_status.get(domain) != DomainStatus.UNKNOWN:
×
90
                        logger.warning(
×
91
                            f"[HeartbeatMonitor] Domain {domain} marked UNKNOWN (missed {HEARTBEAT_TOLERANCE} heartbeats)"
92
                        )
93
                        self.domain_status[domain] = DomainStatus.UNKNOWN
×
94

95
                        domain_dict_from_db = self.db_instance.get_value_from_db(
×
96
                            MongoCollections.DOMAINS, Constants.DOMAIN_DICT
97
                        )
98
                        if domain in domain_dict_from_db:
×
99
                            domain_dict_from_db[domain] = DomainStatus.UNKNOWN
×
100
                            self.db_instance.add_key_value_pair_to_db(
×
101
                                MongoCollections.DOMAINS,
102
                                Constants.DOMAIN_DICT,
103
                                domain_dict_from_db,
104
                            )
105

106
    def get_status(self, domain):
2✔
107
        """Return the current status of a domain."""
108
        with self.lock:
×
109
            return self.domain_status.get(domain, "unknown")
×
110

111
    def start_monitoring(self):
2✔
112
        """Start a background thread to monitor heartbeat status."""
113
        if self.monitoring:
2✔
114
            return
×
115
        self.monitoring = True
2✔
116
        logger.info("[HeartbeatMonitor] Started monitoring heartbeats.")
2✔
117

118
        def monitor_loop():
2✔
119
            while self.monitoring:
2✔
120
                self.check_status()
2✔
121
                time.sleep(HEARTBEAT_INTERVAL)
2✔
122

123
        t = threading.Thread(target=monitor_loop, daemon=True)
2✔
124
        t.start()
2✔
125

126

127
class RpcConsumer(object):
2✔
128
    def __init__(self, thread_queue, exchange_name, te_manager):
2✔
129
        self.logger = logging.getLogger(__name__)
2✔
130

131
        self.logger.info(f"[MQ] Using amqp://{MQ_USER}@{MQ_HOST}:{MQ_PORT}")
2✔
132

133
        self.connection = pika.BlockingConnection(
2✔
134
            pika.ConnectionParameters(
135
                host=MQ_HOST,
136
                port=MQ_PORT,
137
                credentials=pika.PlainCredentials(username=MQ_USER, password=MQ_PASS),
138
            )
139
        )
140

141
        self.channel = self.connection.channel()
2✔
142
        self.exchange_name = exchange_name
2✔
143

144
        self.channel.queue_declare(queue=SUB_QUEUE)
2✔
145
        self._thread_queue = thread_queue
2✔
146

147
        self.te_manager = te_manager
2✔
148

149
        self._exit_event = threading.Event()
2✔
150

151
    def on_request(self, ch, method, props, message_body):
2✔
152
        response = message_body
×
153
        self._thread_queue.put(message_body)
×
154

155
        self.connection = pika.BlockingConnection(
×
156
            pika.ConnectionParameters(
157
                host=MQ_HOST,
158
                port=MQ_PORT,
159
                credentials=pika.PlainCredentials(username=MQ_USER, password=MQ_PASS),
160
            )
161
        )
162
        self.channel = self.connection.channel()
×
163

164
        try:
×
165
            ch.basic_publish(
×
166
                exchange=self.exchange_name,
167
                routing_key=props.reply_to,
168
                properties=pika.BasicProperties(correlation_id=props.correlation_id),
169
                body=str(response),
170
            )
171
            ch.basic_ack(delivery_tag=method.delivery_tag)
×
172
        except Exception as err:
×
173
            self.logger.info(f"[MQ] encountered error when publishing: {err}")
×
174

175
    def start_consumer(self):
2✔
176
        self.channel.basic_qos(prefetch_count=1)
2✔
177
        self.channel.basic_consume(queue=SUB_QUEUE, on_message_callback=self.on_request)
2✔
178

179
        self.logger.info(" [MQ] Awaiting requests from queue: " + SUB_QUEUE)
2✔
180
        self.channel.start_consuming()
2✔
181

182
    def start_sdx_consumer(self, thread_queue, db_instance):
2✔
183
        rpc = RpcConsumer(thread_queue, "", self.te_manager)
2✔
184
        t1 = threading.Thread(target=rpc.start_consumer, args=(), daemon=True)
2✔
185
        t1.start()
2✔
186

187
        lc_message_handler = LcMessageHandler(db_instance, self.te_manager)
2✔
188
        parse_helper = ParseHelper()
2✔
189

190
        heartbeat_monitor = HeartbeatMonitor(db_instance)
2✔
191
        heartbeat_monitor.start_monitoring()
2✔
192

193
        latest_topo = {}
2✔
194
        domain_dict = {}
2✔
195

196
        # This part reads from DB when SDX controller initially starts.
197
        # It looks for domain_dict, if already in DB,
198
        # Then use the existing ones from DB.
199
        domain_dict_from_db = db_instance.get_value_from_db(
2✔
200
            MongoCollections.DOMAINS, Constants.DOMAIN_DICT
201
        )
202
        latest_topo_from_db = db_instance.get_value_from_db(
2✔
203
            MongoCollections.TOPOLOGIES, Constants.LATEST_TOPOLOGY
204
        )
205

206
        if domain_dict_from_db:
2✔
207
            domain_dict = domain_dict_from_db
×
208
            logger.debug("Domain list already exists in db: ")
×
209
            logger.debug(domain_dict)
×
210

211
        residul_bw = {}
2✔
212
        if latest_topo_from_db:
2✔
213
            latest_topo = latest_topo_from_db
2✔
214
            logger.debug("Topology already exists in db: ")
2✔
215
            # logger.debug(latest_topo)
216
            update_topology_manager = TopologyManager()
2✔
217
            update_topology_manager.add_topology(latest_topo)
2✔
218
            residul_bw = update_topology_manager.get_residul_bandwidth()
2✔
219
            logger.debug(residul_bw)
2✔
220

221
        # If topologies already saved in db, use them to initialize te_manager
222
        if domain_dict:
2✔
223
            for domain in domain_dict.keys():
×
224
                topology = db_instance.get_value_from_db(
×
225
                    MongoCollections.TOPOLOGIES, SDX_TOPOLOGY_ID_prefix + domain
226
                )
227

228
                if not topology:
×
229
                    continue
×
230

231
                # Get the actual thing minus the Mongo ObjectID.
232
                self.te_manager.add_topology(topology)
×
233
                logger.debug(f"Read {domain}: {topology}")
×
234
            # update topology/pce state in TE Manager
235

NEW
236
            graph = self.te_manager.generate_graph_te()
×
NEW
237
            connections = db_instance.get_all_entries_in_collection(
×
238
                MongoCollections.CONNECTIONS
239
            )
NEW
240
            if not connections:
×
NEW
241
                logger.info("No connection was found")
×
242
            else:
NEW
243
                for connection in connections:
×
NEW
244
                    service_id = next(iter(connection))
×
NEW
245
                    status = get_connection_status(db_instance, service_id)
×
NEW
246
                    logger.info(
×
247
                        f"Restart: service_id: {service_id}, status: {status.get(service_id)}"
248
                    )
249
                    # 1. update the vlan tables in pce
NEW
250
                    domain_breakdown = db_instance.get_value_from_db(
×
251
                        MongoCollections.BREAKDOWNS, service_id
252
                    )
NEW
253
                    if not domain_breakdown:
×
NEW
254
                        logger.warning(f"Could not find breakdown for {service_id}")
×
NEW
255
                        continue
×
NEW
256
                    try:
×
NEW
257
                        vlan_tags_table = self.te_manager.vlan_tags_table
×
NEW
258
                        for domain, segment in domain_breakdown.items():
×
NEW
259
                            logger.debug(f"domain:{domain};segment:{segment}")
×
NEW
260
                            domain_table = vlan_tags_table.get(domain)
×
NEW
261
                            uni_a = segment.get("uni_a")
×
NEW
262
                            vlan_table = domain_table.get(uni_a.get("port_id"))
×
NEW
263
                            vlan_table[uni_a.get("tag").get("value")] = service_id
×
NEW
264
                            uni_z = segment.get("uni_z")
×
NEW
265
                            vlan_table = domain_table.get(uni_z.get("port_id"))
×
NEW
266
                            vlan_table[uni_z.get("tag").get("value")] = service_id
×
NEW
267
                    except Exception as e:
×
NEW
268
                        err = traceback.format_exc().replace("\n", ", ")
×
NEW
269
                        logger.error(
×
270
                            f"Error when recovering breakdown vlan assignment: {e} - {err}"
271
                        )
NEW
272
                        return f"Error: {e}", 410
×
NEW
273
            if residul_bw:
×
NEW
274
                self.te_manager.update_available_bw_in_topology(residul_bw)
×
275

276
        while not self._exit_event.is_set():
2✔
277
            msg = thread_queue.get()
2✔
278
            logger.debug("MQ received message:" + str(msg))
×
279

280
            if not parse_helper.is_json(msg):
×
281
                logger.debug("Non JSON message, ignored")
×
282
                continue
×
283

284
            msg_json = json.loads(msg)
×
285
            if "type" in msg_json and msg_json.get("type") == "Heart Beat":
×
286
                domain = msg_json.get("domain")
×
287
                heartbeat_monitor.record_heartbeat(domain)
×
288
                logger.debug(f"Heart beat received from {domain}")
×
289
                continue
×
290

291
            try:
×
292
                lc_message_handler.process_lc_json_msg(
×
293
                    msg,
294
                    latest_topo,
295
                    domain_dict,
296
                )
297
            except Exception as exc:
×
298
                err = traceback.format_exc().replace("\n", ", ")
×
299
                logger.error(f"Failed to process LC message: {exc} -- {err}")
×
300

301
    def stop_threads(self):
2✔
302
        """
303
        Signal threads that we're ready to stop.
304
        """
305
        logger.info("[MQ] Stopping threads.")
×
306
        self.channel.stop_consuming()
×
307
        self._exit_event.set()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc